Compare commits

..

7 Commits

Author SHA1 Message Date
crn4
5e2d938a9c Merge branch 'main' into nmap/cleanup 2026-04-09 10:42:42 +02:00
crn4
da41a94ead Merge branch 'main' into nmap/cleanup 2026-04-03 19:00:46 +02:00
crn4
74a9842866 removed experimental cached network map code 2026-03-12 16:39:58 +01:00
crn4
86de658b3a Merge branch 'main' into nmap/cleanup 2026-03-12 15:40:43 +01:00
crn4
5e097beec0 linter 2026-03-10 14:26:00 +01:00
crn4
710f6a9b76 add more tests 2026-03-10 13:26:12 +01:00
crn4
30dbf21735 [management] removed legacy network map code 2026-03-10 12:32:06 +01:00
194 changed files with 2614 additions and 13759 deletions

View File

@@ -1,62 +0,0 @@
name: Proto Version Check
on:
pull_request:
paths:
- "**/*.pb.go"
jobs:
check-proto-versions:
runs-on: ubuntu-latest
steps:
- name: Check for proto tool version changes
uses: actions/github-script@v7
with:
script: |
const files = await github.paginate(github.rest.pulls.listFiles, {
owner: context.repo.owner,
repo: context.repo.repo,
pull_number: context.issue.number,
per_page: 100,
});
const pbFiles = files.filter(f => f.filename.endsWith('.pb.go'));
const missingPatch = pbFiles.filter(f => !f.patch).map(f => f.filename);
if (missingPatch.length > 0) {
core.setFailed(
`Cannot inspect patch data for:\n` +
missingPatch.map(f => `- ${f}`).join('\n') +
`\nThis can happen with very large PRs. Verify proto versions manually.`
);
return;
}
const versionPattern = /^[+-]\s*\/\/\s+protoc(?:-gen-go)?\s+v[\d.]+/;
const violations = [];
for (const file of pbFiles) {
const changed = file.patch
.split('\n')
.filter(line => versionPattern.test(line));
if (changed.length > 0) {
violations.push({
file: file.filename,
lines: changed,
});
}
}
if (violations.length > 0) {
const details = violations.map(v =>
`${v.file}:\n${v.lines.map(l => ' ' + l).join('\n')}`
).join('\n\n');
core.setFailed(
`Proto version strings changed in generated files.\n` +
`This usually means the wrong protoc or protoc-gen-go version was used.\n` +
`Regenerate with the matching tool versions.\n\n` +
details
);
return;
}
console.log('No proto version string changes detected');

View File

@@ -9,7 +9,7 @@ on:
pull_request:
env:
SIGN_PIPE_VER: "v0.1.2"
SIGN_PIPE_VER: "v0.1.1"
GORELEASER_VER: "v2.14.3"
PRODUCT_NAME: "NetBird"
COPYRIGHT: "NetBird GmbH"

View File

@@ -5,7 +5,7 @@ GOLANGCI_LINT := $(shell pwd)/bin/golangci-lint
$(GOLANGCI_LINT):
@echo "Installing golangci-lint..."
@mkdir -p ./bin
@GOBIN=$(shell pwd)/bin go install github.com/golangci/golangci-lint/v2/cmd/golangci-lint@latest
@GOBIN=$(shell pwd)/bin go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
# Lint only changed files (fast, for pre-push)
lint: $(GOLANGCI_LINT)

View File

@@ -8,7 +8,6 @@ import (
"os"
"slices"
"sync"
"time"
"golang.org/x/exp/maps"
@@ -16,7 +15,6 @@ import (
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/debug"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer"
@@ -28,7 +26,6 @@ import (
"github.com/netbirdio/netbird/formatter"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
types "github.com/netbirdio/netbird/upload-server/types"
)
// ConnectionListener export internal Listener for mobile
@@ -71,30 +68,7 @@ type Client struct {
uiVersion string
networkChangeListener listener.NetworkChangeListener
stateMu sync.RWMutex
connectClient *internal.ConnectClient
config *profilemanager.Config
cacheDir string
}
func (c *Client) setState(cfg *profilemanager.Config, cacheDir string, cc *internal.ConnectClient) {
c.stateMu.Lock()
defer c.stateMu.Unlock()
c.config = cfg
c.cacheDir = cacheDir
c.connectClient = cc
}
func (c *Client) stateSnapshot() (*profilemanager.Config, string, *internal.ConnectClient) {
c.stateMu.RLock()
defer c.stateMu.RUnlock()
return c.config, c.cacheDir, c.connectClient
}
func (c *Client) getConnectClient() *internal.ConnectClient {
c.stateMu.RLock()
defer c.stateMu.RUnlock()
return c.connectClient
}
// NewClient instantiate a new Client
@@ -119,7 +93,6 @@ func (c *Client) Run(platformFiles PlatformFiles, urlOpener URLOpener, isAndroid
cfgFile := platformFiles.ConfigurationFilePath()
stateFile := platformFiles.StateFilePath()
cacheDir := platformFiles.CacheDir()
log.Infof("Starting client with config: %s, state: %s", cfgFile, stateFile)
@@ -151,9 +124,8 @@ func (c *Client) Run(platformFiles PlatformFiles, urlOpener URLOpener, isAndroid
// todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx)
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
c.setState(cfg, cacheDir, connectClient)
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile, cacheDir)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile)
}
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
@@ -163,7 +135,6 @@ func (c *Client) RunWithoutLogin(platformFiles PlatformFiles, dns *DNSList, dnsR
cfgFile := platformFiles.ConfigurationFilePath()
stateFile := platformFiles.StateFilePath()
cacheDir := platformFiles.CacheDir()
log.Infof("Starting client without login with config: %s, state: %s", cfgFile, stateFile)
@@ -186,9 +157,8 @@ func (c *Client) RunWithoutLogin(platformFiles PlatformFiles, dns *DNSList, dnsR
// todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx)
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
c.setState(cfg, cacheDir, connectClient)
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile, cacheDir)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile)
}
// Stop the internal client and free the resources
@@ -203,12 +173,11 @@ func (c *Client) Stop() {
}
func (c *Client) RenewTun(fd int) error {
cc := c.getConnectClient()
if cc == nil {
if c.connectClient == nil {
return fmt.Errorf("engine not running")
}
e := cc.Engine()
e := c.connectClient.Engine()
if e == nil {
return fmt.Errorf("engine not initialized")
}
@@ -216,73 +185,6 @@ func (c *Client) RenewTun(fd int) error {
return e.RenewTun(fd)
}
// DebugBundle generates a debug bundle, uploads it, and returns the upload key.
// It works both with and without a running engine.
func (c *Client) DebugBundle(platformFiles PlatformFiles, anonymize bool) (string, error) {
cfg, cacheDir, cc := c.stateSnapshot()
// If the engine hasn't been started, load config from disk
if cfg == nil {
var err error
cfg, err = profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
ConfigPath: platformFiles.ConfigurationFilePath(),
})
if err != nil {
return "", fmt.Errorf("load config: %w", err)
}
cacheDir = platformFiles.CacheDir()
}
deps := debug.GeneratorDependencies{
InternalConfig: cfg,
StatusRecorder: c.recorder,
TempDir: cacheDir,
}
if cc != nil {
resp, err := cc.GetLatestSyncResponse()
if err != nil {
log.Warnf("get latest sync response: %v", err)
}
deps.SyncResponse = resp
if e := cc.Engine(); e != nil {
if cm := e.GetClientMetrics(); cm != nil {
deps.ClientMetrics = cm
}
}
}
bundleGenerator := debug.NewBundleGenerator(
deps,
debug.BundleConfig{
Anonymize: anonymize,
IncludeSystemInfo: true,
},
)
path, err := bundleGenerator.Generate()
if err != nil {
return "", fmt.Errorf("generate debug bundle: %w", err)
}
defer func() {
if err := os.Remove(path); err != nil {
log.Errorf("failed to remove debug bundle file: %v", err)
}
}()
uploadCtx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
key, err := debug.UploadDebugBundle(uploadCtx, types.DefaultBundleURL, cfg.ManagementURL.String(), path)
if err != nil {
return "", fmt.Errorf("upload debug bundle: %w", err)
}
log.Infof("debug bundle uploaded with key %s", key)
return key, nil
}
// SetTraceLogLevel configure the logger to trace level
func (c *Client) SetTraceLogLevel() {
log.SetLevel(log.TraceLevel)
@@ -312,13 +214,12 @@ func (c *Client) PeersList() *PeerInfoArray {
}
func (c *Client) Networks() *NetworkArray {
cc := c.getConnectClient()
if cc == nil {
if c.connectClient == nil {
log.Error("not connected")
return nil
}
engine := cc.Engine()
engine := c.connectClient.Engine()
if engine == nil {
log.Error("could not get engine")
return nil
@@ -399,7 +300,7 @@ func (c *Client) toggleRoute(command routeCommand) error {
}
func (c *Client) getRouteManager() (routemanager.Manager, error) {
client := c.getConnectClient()
client := c.connectClient
if client == nil {
return nil, fmt.Errorf("not connected")
}

View File

@@ -7,5 +7,4 @@ package android
type PlatformFiles interface {
ConfigurationFilePath() string
StateFilePath() string
CacheDir() string
}

View File

@@ -75,7 +75,6 @@ var (
mtu uint16
profilesDisabled bool
updateSettingsDisabled bool
networksDisabled bool
rootCmd = &cobra.Command{
Use: "netbird",

View File

@@ -44,13 +44,10 @@ func init() {
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd, svcStatusCmd, installCmd, uninstallCmd, reconfigureCmd, resetParamsCmd)
serviceCmd.PersistentFlags().BoolVar(&profilesDisabled, "disable-profiles", false, "Disables profiles feature. If enabled, the client will not be able to change or edit any profile. To persist this setting, use: netbird service install --disable-profiles")
serviceCmd.PersistentFlags().BoolVar(&updateSettingsDisabled, "disable-update-settings", false, "Disables update settings feature. If enabled, the client will not be able to change or edit any settings. To persist this setting, use: netbird service install --disable-update-settings")
serviceCmd.PersistentFlags().BoolVar(&networksDisabled, "disable-networks", false, "Disables network selection. If enabled, the client will not allow listing, selecting, or deselecting networks. To persist, use: netbird service install --disable-networks")
rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name")
serviceEnvDesc := `Sets extra environment variables for the service. ` +
`You can specify a comma-separated list of KEY=VALUE pairs. ` +
`New keys are merged with previously saved env vars; existing keys are overwritten. ` +
`Use --service-env "" to clear all saved env vars. ` +
`E.g. --service-env NB_LOG_LEVEL=debug,CUSTOM_VAR=value`
installCmd.Flags().StringSliceVar(&serviceEnvVars, "service-env", nil, serviceEnvDesc)

View File

@@ -61,7 +61,7 @@ func (p *program) Start(svc service.Service) error {
}
}
serverInstance := server.New(p.ctx, util.FindFirstLogPath(logFiles), configPath, profilesDisabled, updateSettingsDisabled, networksDisabled)
serverInstance := server.New(p.ctx, util.FindFirstLogPath(logFiles), configPath, profilesDisabled, updateSettingsDisabled)
if err := serverInstance.Start(); err != nil {
log.Fatalf("failed to start daemon: %v", err)
}

View File

@@ -59,10 +59,6 @@ func buildServiceArguments() []string {
args = append(args, "--disable-update-settings")
}
if networksDisabled {
args = append(args, "--disable-networks")
}
return args
}

View File

@@ -28,7 +28,6 @@ type serviceParams struct {
LogFiles []string `json:"log_files,omitempty"`
DisableProfiles bool `json:"disable_profiles,omitempty"`
DisableUpdateSettings bool `json:"disable_update_settings,omitempty"`
DisableNetworks bool `json:"disable_networks,omitempty"`
ServiceEnvVars map[string]string `json:"service_env_vars,omitempty"`
}
@@ -79,12 +78,11 @@ func currentServiceParams() *serviceParams {
LogFiles: logFiles,
DisableProfiles: profilesDisabled,
DisableUpdateSettings: updateSettingsDisabled,
DisableNetworks: networksDisabled,
}
if len(serviceEnvVars) > 0 {
parsed, err := parseServiceEnvVars(serviceEnvVars)
if err == nil {
if err == nil && len(parsed) > 0 {
params.ServiceEnvVars = parsed
}
}
@@ -144,46 +142,31 @@ func applyServiceParams(cmd *cobra.Command, params *serviceParams) {
updateSettingsDisabled = params.DisableUpdateSettings
}
if !serviceCmd.PersistentFlags().Changed("disable-networks") {
networksDisabled = params.DisableNetworks
}
applyServiceEnvParams(cmd, params)
}
// applyServiceEnvParams merges saved service environment variables.
// If --service-env was explicitly set with values, explicit values win on key
// conflict but saved keys not in the explicit set are carried over.
// If --service-env was explicitly set to empty, all saved env vars are cleared.
// If --service-env was explicitly set, explicit values win on key conflict
// but saved keys not in the explicit set are carried over.
// If --service-env was not set, saved env vars are used entirely.
func applyServiceEnvParams(cmd *cobra.Command, params *serviceParams) {
if !cmd.Flags().Changed("service-env") {
if len(params.ServiceEnvVars) > 0 {
// No explicit env vars: rebuild serviceEnvVars from saved params.
serviceEnvVars = envMapToSlice(params.ServiceEnvVars)
}
if len(params.ServiceEnvVars) == 0 {
return
}
// Flag was explicitly set: parse what the user provided.
if !cmd.Flags().Changed("service-env") {
// No explicit env vars: rebuild serviceEnvVars from saved params.
serviceEnvVars = envMapToSlice(params.ServiceEnvVars)
return
}
// Explicit env vars were provided: merge saved values underneath.
explicit, err := parseServiceEnvVars(serviceEnvVars)
if err != nil {
cmd.PrintErrf("Warning: parse explicit service env vars for merge: %v\n", err)
return
}
// If the user passed an empty value (e.g. --service-env ""), clear all
// saved env vars rather than merging.
if len(explicit) == 0 {
serviceEnvVars = nil
return
}
if len(params.ServiceEnvVars) == 0 {
return
}
// Merge saved values underneath explicit ones.
merged := make(map[string]string, len(params.ServiceEnvVars)+len(explicit))
maps.Copy(merged, params.ServiceEnvVars)
maps.Copy(merged, explicit) // explicit wins on conflict

View File

@@ -327,41 +327,6 @@ func TestApplyServiceEnvParams_NotChanged(t *testing.T) {
assert.Equal(t, map[string]string{"FROM_SAVED": "val"}, result)
}
func TestApplyServiceEnvParams_ExplicitEmptyClears(t *testing.T) {
origServiceEnvVars := serviceEnvVars
t.Cleanup(func() { serviceEnvVars = origServiceEnvVars })
// Simulate --service-env "" which produces [""] in the slice.
serviceEnvVars = []string{""}
cmd := &cobra.Command{}
cmd.Flags().StringSlice("service-env", nil, "")
require.NoError(t, cmd.Flags().Set("service-env", ""))
saved := &serviceParams{
ServiceEnvVars: map[string]string{"OLD_VAR": "should_be_cleared"},
}
applyServiceEnvParams(cmd, saved)
assert.Nil(t, serviceEnvVars, "explicit empty --service-env should clear all saved env vars")
}
func TestCurrentServiceParams_EmptyEnvVarsAfterParse(t *testing.T) {
origServiceEnvVars := serviceEnvVars
t.Cleanup(func() { serviceEnvVars = origServiceEnvVars })
// Simulate --service-env "" which produces [""] in the slice.
serviceEnvVars = []string{""}
params := currentServiceParams()
// After parsing, the empty string is skipped, resulting in an empty map.
// The map should still be set (not nil) so it overwrites saved values.
assert.NotNil(t, params.ServiceEnvVars, "empty env vars should produce empty map, not nil")
assert.Empty(t, params.ServiceEnvVars, "no valid env vars should be parsed from empty string")
}
// TestServiceParams_FieldsCoveredInFunctions ensures that all serviceParams fields are
// referenced in both currentServiceParams() and applyServiceParams(). If a new field is
// added to serviceParams but not wired into these functions, this test fails.
@@ -535,7 +500,6 @@ func fieldToGlobalVar(field string) string {
"LogFiles": "logFiles",
"DisableProfiles": "profilesDisabled",
"DisableUpdateSettings": "updateSettingsDisabled",
"DisableNetworks": "networksDisabled",
"ServiceEnvVars": "serviceEnvVars",
}
if v, ok := m[field]; ok {

View File

@@ -13,8 +13,6 @@ import (
"github.com/netbirdio/management-integrations/integrations"
nbcache "github.com/netbirdio/netbird/management/server/cache"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers"
@@ -102,16 +100,9 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
jobManager := job.NewJobManager(nil, store, peersmanager)
ctx := context.Background()
iv, _ := integrations.NewIntegratedValidator(context.Background(), peersmanager, settingsManagerMock, eventStore)
cacheStore, err := nbcache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100)
if err != nil {
t.Fatal(err)
}
iv, _ := integrations.NewIntegratedValidator(ctx, peersmanager, settingsManagerMock, eventStore, cacheStore)
metrics, err := telemetry.NewDefaultAppMetrics(ctx)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
settingsMockManager := settings.NewMockManager(ctrl)
@@ -122,11 +113,12 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
Return(&types.Settings{}, nil).
AnyTimes()
ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersmanager), config)
accountManager, err := mgmt.BuildManager(ctx, config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false, cacheStore)
accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
if err != nil {
t.Fatal(err)
}
@@ -160,7 +152,7 @@ func startClientDaemon(
s := grpc.NewServer()
server := client.New(ctx,
"", "", false, false, false)
"", "", false, false)
if err := server.Start(); err != nil {
t.Fatal(err)
}

View File

@@ -6,7 +6,6 @@ import (
"errors"
"fmt"
"os"
"strconv"
"github.com/coreos/go-iptables/iptables"
"github.com/google/nftables"
@@ -36,34 +35,20 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
type FWType int
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool, mtu uint16) (firewall.Manager, error) {
// We run in userspace mode and force userspace firewall was requested. We don't attempt native firewall.
if iface.IsUserspaceBind() && forceUserspaceFirewall() {
log.Info("forcing userspace firewall")
return createUserspaceFirewall(iface, nil, disableServerRoutes, flowLogger, mtu)
}
// Use native firewall for either kernel or userspace, the interface appears identical to netfilter
// on the linux system we try to user nftables or iptables
// in any case, because we need to allow netbird interface traffic
// so we use AllowNetbird traffic from these firewall managers
// for the userspace packet filtering firewall
fm, err := createNativeFirewall(iface, stateManager, disableServerRoutes, mtu)
// Kernel cannot fall back to anything else, need to return error
if !iface.IsUserspaceBind() {
return fm, err
}
// Fall back to the userspace packet filter if native is unavailable
if err != nil {
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
return createUserspaceFirewall(iface, nil, disableServerRoutes, flowLogger, mtu)
}
// Native firewall handles packet filtering, but the userspace WireGuard bind
// needs a device filter for DNS interception hooks. Install a minimal
// hooks-only filter that passes all traffic through to the kernel firewall.
if err := iface.SetFilter(&uspfilter.HooksFilter{}); err != nil {
log.Warnf("failed to set hooks filter, DNS via memory hooks will not work: %v", err)
}
return fm, nil
return createUserspaceFirewall(iface, fm, disableServerRoutes, flowLogger, mtu)
}
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool, mtu uint16) (firewall.Manager, error) {
@@ -175,17 +160,3 @@ func isIptablesClientAvailable(client *iptables.IPTables) bool {
_, err := client.ListChains("filter")
return err == nil
}
func forceUserspaceFirewall() bool {
val := os.Getenv(EnvForceUserspaceFirewall)
if val == "" {
return false
}
force, err := strconv.ParseBool(val)
if err != nil {
log.Warnf("failed to parse %s: %v", EnvForceUserspaceFirewall, err)
return false
}
return force
}

View File

@@ -7,12 +7,6 @@ import (
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
// EnvForceUserspaceFirewall forces the use of the userspace packet filter even when
// native iptables/nftables is available. This only applies when the WireGuard interface
// runs in userspace mode. When set, peer ACLs are handled by USPFilter instead of
// kernel netfilter rules.
const EnvForceUserspaceFirewall = "NB_FORCE_USERSPACE_FIREWALL"
// IFaceMapper defines subset methods of interface required for manager
type IFaceMapper interface {
Name() string

View File

@@ -21,10 +21,6 @@ const (
// rules chains contains the effective ACL rules
chainNameInputRules = "NETBIRD-ACL-INPUT"
// mangleFwdKey is the entries map key for mangle FORWARD guard rules that prevent
// external DNAT from bypassing ACL rules.
mangleFwdKey = "MANGLE-FORWARD"
)
type aclEntries map[string][][]string
@@ -278,12 +274,6 @@ func (m *aclManager) cleanChains() error {
}
}
for _, rule := range m.entries[mangleFwdKey] {
if err := m.iptablesClient.DeleteIfExists(tableMangle, chainFORWARD, rule...); err != nil {
log.Errorf("failed to delete mangle FORWARD guard rule: %v, %s", rule, err)
}
}
for _, ipsetName := range m.ipsetStore.ipsetNames() {
if err := m.flushIPSet(ipsetName); err != nil {
if errors.Is(err, ipset.ErrSetNotExist) {
@@ -313,10 +303,6 @@ func (m *aclManager) createDefaultChains() error {
}
for chainName, rules := range m.entries {
// mangle FORWARD guard rules are handled separately below
if chainName == mangleFwdKey {
continue
}
for _, rule := range rules {
if err := m.iptablesClient.InsertUnique(tableName, chainName, 1, rule...); err != nil {
log.Debugf("failed to create input chain jump rule: %s", err)
@@ -336,13 +322,6 @@ func (m *aclManager) createDefaultChains() error {
}
clear(m.optionalEntries)
// Insert mangle FORWARD guard rules to prevent external DNAT bypass.
for _, rule := range m.entries[mangleFwdKey] {
if err := m.iptablesClient.AppendUnique(tableMangle, chainFORWARD, rule...); err != nil {
log.Errorf("failed to add mangle FORWARD guard rule: %v", err)
}
}
return nil
}
@@ -364,22 +343,6 @@ func (m *aclManager) seedInitialEntries() {
m.appendToEntries("FORWARD", []string{"-o", m.wgIface.Name(), "-j", chainRTFWDOUT})
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", chainRTFWDIN})
// Mangle FORWARD guard: when external DNAT redirects traffic from the wg interface, it
// traverses FORWARD instead of INPUT, bypassing ACL rules. ACCEPT rules in filter FORWARD
// can be inserted above ours. Mangle runs before filter, so these guard rules enforce the
// ACL mark check where it cannot be overridden.
m.appendToEntries(mangleFwdKey, []string{
"-i", m.wgIface.Name(),
"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED",
"-j", "ACCEPT",
})
m.appendToEntries(mangleFwdKey, []string{
"-i", m.wgIface.Name(),
"-m", "conntrack", "--ctstate", "DNAT",
"-m", "mark", "!", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected),
"-j", "DROP",
})
}
func (m *aclManager) seedInitialOptionalEntries() {

View File

@@ -33,6 +33,7 @@ type Manager struct {
type iFaceMapper interface {
Name() string
Address() wgaddr.Address
IsUserspaceBind() bool
}
// Create iptables firewall manager
@@ -63,9 +64,10 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
func (m *Manager) Init(stateManager *statemanager.Manager) error {
state := &ShutdownState{
InterfaceState: &InterfaceState{
NameStr: m.wgIface.Name(),
WGAddress: m.wgIface.Address(),
MTU: m.router.mtu,
NameStr: m.wgIface.Name(),
WGAddress: m.wgIface.Address(),
UserspaceBind: m.wgIface.IsUserspaceBind(),
MTU: m.router.mtu,
},
}
stateManager.RegisterState(state)
@@ -201,10 +203,12 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
return nberrors.FormatErrorOrNil(merr)
}
// AllowNetbird allows netbird interface traffic.
// This is called when USPFilter wraps the native firewall, adding blanket accept
// rules so that packet filtering is handled in userspace instead of by netfilter.
// AllowNetbird allows netbird interface traffic
func (m *Manager) AllowNetbird() error {
if !m.wgIface.IsUserspaceBind() {
return nil
}
_, err := m.AddPeerFiltering(
nil,
net.IP{0, 0, 0, 0},

View File

@@ -47,6 +47,8 @@ func (i *iFaceMock) Address() wgaddr.Address {
panic("AddressFunc is not set")
}
func (i *iFaceMock) IsUserspaceBind() bool { return false }
func TestIptablesManager(t *testing.T) {
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err)

View File

@@ -9,9 +9,10 @@ import (
)
type InterfaceState struct {
NameStr string `json:"name"`
WGAddress wgaddr.Address `json:"wg_address"`
MTU uint16 `json:"mtu"`
NameStr string `json:"name"`
WGAddress wgaddr.Address `json:"wg_address"`
UserspaceBind bool `json:"userspace_bind"`
MTU uint16 `json:"mtu"`
}
func (i *InterfaceState) Name() string {
@@ -22,6 +23,10 @@ func (i *InterfaceState) Address() wgaddr.Address {
return i.WGAddress
}
func (i *InterfaceState) IsUserspaceBind() bool {
return i.UserspaceBind
}
type ShutdownState struct {
sync.Mutex

View File

@@ -40,6 +40,7 @@ func getTableName() string {
type iFaceMapper interface {
Name() string
Address() wgaddr.Address
IsUserspaceBind() bool
}
// Manager of iptables firewall
@@ -105,9 +106,10 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
// cleanup using Close() without needing to store specific rules.
if err := stateManager.UpdateState(&ShutdownState{
InterfaceState: &InterfaceState{
NameStr: m.wgIface.Name(),
WGAddress: m.wgIface.Address(),
MTU: m.router.mtu,
NameStr: m.wgIface.Name(),
WGAddress: m.wgIface.Address(),
UserspaceBind: m.wgIface.IsUserspaceBind(),
MTU: m.router.mtu,
},
}); err != nil {
log.Errorf("failed to update state: %v", err)
@@ -203,10 +205,12 @@ func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
return m.router.RemoveNatRule(pair)
}
// AllowNetbird allows netbird interface traffic.
// This is called when USPFilter wraps the native firewall, adding blanket accept
// rules so that packet filtering is handled in userspace instead of by netfilter.
// AllowNetbird allows netbird interface traffic
func (m *Manager) AllowNetbird() error {
if !m.wgIface.IsUserspaceBind() {
return nil
}
m.mutex.Lock()
defer m.mutex.Unlock()

View File

@@ -52,6 +52,8 @@ func (i *iFaceMock) Address() wgaddr.Address {
panic("AddressFunc is not set")
}
func (i *iFaceMock) IsUserspaceBind() bool { return false }
func TestNftablesManager(t *testing.T) {
// just check on the local interface

View File

@@ -8,9 +8,10 @@ import (
)
type InterfaceState struct {
NameStr string `json:"name"`
WGAddress wgaddr.Address `json:"wg_address"`
MTU uint16 `json:"mtu"`
NameStr string `json:"name"`
WGAddress wgaddr.Address `json:"wg_address"`
UserspaceBind bool `json:"userspace_bind"`
MTU uint16 `json:"mtu"`
}
func (i *InterfaceState) Name() string {
@@ -21,6 +22,10 @@ func (i *InterfaceState) Address() wgaddr.Address {
return i.WGAddress
}
func (i *InterfaceState) IsUserspaceBind() bool {
return i.UserspaceBind
}
type ShutdownState struct {
InterfaceState *InterfaceState `json:"interface_state,omitempty"`
}

View File

@@ -1,37 +0,0 @@
package common
import (
"net/netip"
"sync/atomic"
)
// PacketHook stores a registered hook for a specific IP:port.
type PacketHook struct {
IP netip.Addr
Port uint16
Fn func([]byte) bool
}
// HookMatches checks if a packet's destination matches the hook and invokes it.
func HookMatches(h *PacketHook, dstIP netip.Addr, dport uint16, packetData []byte) bool {
if h == nil {
return false
}
if h.IP == dstIP && h.Port == dport {
return h.Fn(packetData)
}
return false
}
// SetHook atomically stores a hook, handling nil removal.
func SetHook(ptr *atomic.Pointer[PacketHook], ip netip.Addr, dPort uint16, hook func([]byte) bool) {
if hook == nil {
ptr.Store(nil)
return
}
ptr.Store(&PacketHook{
IP: ip,
Port: dPort,
Fn: hook,
})
}

View File

@@ -1,125 +0,0 @@
package conntrack
import (
"net/netip"
"testing"
"github.com/google/gopacket/layers"
"github.com/stretchr/testify/require"
)
func TestTCPCapEvicts(t *testing.T) {
t.Setenv(EnvTCPMaxEntries, "4")
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
require.Equal(t, 4, tracker.maxEntries)
src := netip.MustParseAddr("100.64.0.1")
dst := netip.MustParseAddr("100.64.0.2")
for i := 0; i < 10; i++ {
tracker.TrackOutbound(src, dst, uint16(10000+i), 80, TCPSyn, 0)
}
require.LessOrEqual(t, len(tracker.connections), 4,
"TCP table must not exceed the configured cap")
require.Greater(t, len(tracker.connections), 0,
"some entries must remain after eviction")
// The most recently admitted flow must be present: eviction must make
// room for new entries, not silently drop them.
require.Contains(t, tracker.connections,
ConnKey{SrcIP: src, DstIP: dst, SrcPort: uint16(10009), DstPort: 80},
"newest TCP flow must be admitted after eviction")
// A pre-cap flow must have been evicted to fit the last one.
require.NotContains(t, tracker.connections,
ConnKey{SrcIP: src, DstIP: dst, SrcPort: uint16(10000), DstPort: 80},
"oldest TCP flow should have been evicted")
}
func TestTCPCapPrefersTombstonedForEviction(t *testing.T) {
t.Setenv(EnvTCPMaxEntries, "3")
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
src := netip.MustParseAddr("100.64.0.1")
dst := netip.MustParseAddr("100.64.0.2")
// Fill to cap with 3 live connections.
for i := 0; i < 3; i++ {
tracker.TrackOutbound(src, dst, uint16(20000+i), 80, TCPSyn, 0)
}
require.Len(t, tracker.connections, 3)
// Tombstone one by sending RST through IsValidInbound.
tombstonedKey := ConnKey{SrcIP: src, DstIP: dst, SrcPort: 20001, DstPort: 80}
require.True(t, tracker.IsValidInbound(dst, src, 80, 20001, TCPRst|TCPAck, 0))
require.True(t, tracker.connections[tombstonedKey].IsTombstone())
// Another live connection forces eviction. The tombstone must go first.
tracker.TrackOutbound(src, dst, uint16(29999), 80, TCPSyn, 0)
_, tombstonedStillPresent := tracker.connections[tombstonedKey]
require.False(t, tombstonedStillPresent,
"tombstoned entry should be evicted before live entries")
require.LessOrEqual(t, len(tracker.connections), 3)
// Both live pre-cap entries must survive: eviction must prefer the
// tombstone, not just satisfy the size bound by dropping any entry.
require.Contains(t, tracker.connections,
ConnKey{SrcIP: src, DstIP: dst, SrcPort: uint16(20000), DstPort: 80},
"live entries must not be evicted while a tombstone exists")
require.Contains(t, tracker.connections,
ConnKey{SrcIP: src, DstIP: dst, SrcPort: uint16(20002), DstPort: 80},
"live entries must not be evicted while a tombstone exists")
}
func TestUDPCapEvicts(t *testing.T) {
t.Setenv(EnvUDPMaxEntries, "5")
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
defer tracker.Close()
require.Equal(t, 5, tracker.maxEntries)
src := netip.MustParseAddr("100.64.0.1")
dst := netip.MustParseAddr("100.64.0.2")
for i := 0; i < 12; i++ {
tracker.TrackOutbound(src, dst, uint16(30000+i), 53, 0)
}
require.LessOrEqual(t, len(tracker.connections), 5)
require.Greater(t, len(tracker.connections), 0)
require.Contains(t, tracker.connections,
ConnKey{SrcIP: src, DstIP: dst, SrcPort: uint16(30011), DstPort: 53},
"newest UDP flow must be admitted after eviction")
require.NotContains(t, tracker.connections,
ConnKey{SrcIP: src, DstIP: dst, SrcPort: uint16(30000), DstPort: 53},
"oldest UDP flow should have been evicted")
}
func TestICMPCapEvicts(t *testing.T) {
t.Setenv(EnvICMPMaxEntries, "3")
tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger)
defer tracker.Close()
require.Equal(t, 3, tracker.maxEntries)
src := netip.MustParseAddr("100.64.0.1")
dst := netip.MustParseAddr("100.64.0.2")
echoReq := layers.CreateICMPv4TypeCode(uint8(layers.ICMPv4TypeEchoRequest), 0)
for i := 0; i < 8; i++ {
tracker.TrackOutbound(src, dst, uint16(i), echoReq, nil, 64)
}
require.LessOrEqual(t, len(tracker.connections), 3)
require.Greater(t, len(tracker.connections), 0)
require.Contains(t, tracker.connections,
ICMPConnKey{SrcIP: src, DstIP: dst, ID: uint16(7)},
"newest ICMP flow must be admitted after eviction")
require.NotContains(t, tracker.connections,
ICMPConnKey{SrcIP: src, DstIP: dst, ID: uint16(0)},
"oldest ICMP flow should have been evicted")
}

View File

@@ -3,61 +3,14 @@ package conntrack
import (
"fmt"
"net/netip"
"os"
"strconv"
"sync/atomic"
"time"
"github.com/google/uuid"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
)
// evictSampleSize bounds how many map entries we scan per eviction call.
// Keeps eviction O(1) even at cap under sustained load; the sampled-LRU
// heuristic is good enough for a conntrack table that only overflows under
// abuse.
const evictSampleSize = 8
// envDuration parses an os.Getenv(name) as a time.Duration. Falls back to
// def on empty or invalid; logs a warning on invalid.
func envDuration(logger *nblog.Logger, name string, def time.Duration) time.Duration {
v := os.Getenv(name)
if v == "" {
return def
}
d, err := time.ParseDuration(v)
if err != nil {
logger.Warn3("invalid %s=%q: %v, using default", name, v, err)
return def
}
if d <= 0 {
logger.Warn2("invalid %s=%q: must be positive, using default", name, v)
return def
}
return d
}
// envInt parses an os.Getenv(name) as an int. Falls back to def on empty,
// invalid, or non-positive. Logs a warning on invalid input.
func envInt(logger *nblog.Logger, name string, def int) int {
v := os.Getenv(name)
if v == "" {
return def
}
n, err := strconv.Atoi(v)
switch {
case err != nil:
logger.Warn3("invalid %s=%q: %v, using default", name, v, err)
return def
case n <= 0:
logger.Warn2("invalid %s=%q: must be positive, using default", name, v)
return def
}
return n
}
// BaseConnTrack provides common fields and locking for all connection types
type BaseConnTrack struct {
FlowId uuid.UUID

View File

@@ -1,11 +0,0 @@
//go:build !ios && !android
package conntrack
// Default per-tracker entry caps on desktop/server platforms. These mirror
// typical Linux netfilter nf_conntrack_max territory with ample headroom.
const (
DefaultMaxTCPEntries = 65536
DefaultMaxUDPEntries = 16384
DefaultMaxICMPEntries = 2048
)

View File

@@ -1,13 +0,0 @@
//go:build ios || android
package conntrack
// Default per-tracker entry caps on mobile platforms. iOS network extensions
// are capped at ~50 MB; Android runs under aggressive memory pressure. These
// values keep conntrack footprint well under 5 MB worst case (TCPConnTrack
// is ~200 B plus map overhead).
const (
DefaultMaxTCPEntries = 4096
DefaultMaxUDPEntries = 2048
DefaultMaxICMPEntries = 512
)

View File

@@ -44,9 +44,6 @@ type ICMPConnTrack struct {
ICMPCode uint8
}
// EnvICMPMaxEntries caps the ICMP conntrack table size.
const EnvICMPMaxEntries = "NB_CONNTRACK_ICMP_MAX"
// ICMPTracker manages ICMP connection states
type ICMPTracker struct {
logger *nblog.Logger
@@ -55,7 +52,6 @@ type ICMPTracker struct {
cleanupTicker *time.Ticker
tickerCancel context.CancelFunc
mutex sync.RWMutex
maxEntries int
flowLogger nftypes.FlowLogger
}
@@ -139,7 +135,6 @@ func NewICMPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nfty
timeout: timeout,
cleanupTicker: time.NewTicker(ICMPCleanupInterval),
tickerCancel: cancel,
maxEntries: envInt(logger, EnvICMPMaxEntries, DefaultMaxICMPEntries),
flowLogger: flowLogger,
}
@@ -226,9 +221,7 @@ func (t *ICMPTracker) track(
// non echo requests don't need tracking
if typ != uint8(layers.ICMPv4TypeEchoRequest) {
if t.logger.Enabled(nblog.LevelTrace) {
t.logger.Trace3("New %s ICMP connection %s - %s", direction, key, icmpInfo)
}
t.logger.Trace3("New %s ICMP connection %s - %s", direction, key, icmpInfo)
t.sendStartEvent(direction, srcIP, dstIP, typ, code, ruleId, size)
return
}
@@ -247,15 +240,10 @@ func (t *ICMPTracker) track(
conn.UpdateCounters(direction, size)
t.mutex.Lock()
if t.maxEntries > 0 && len(t.connections) >= t.maxEntries {
t.evictOneLocked()
}
t.connections[key] = conn
t.mutex.Unlock()
if t.logger.Enabled(nblog.LevelTrace) {
t.logger.Trace3("New %s ICMP connection %s - %s", direction, key, icmpInfo)
}
t.logger.Trace3("New %s ICMP connection %s - %s", direction, key, icmpInfo)
t.sendEvent(nftypes.TypeStart, conn, ruleId)
}
@@ -298,34 +286,6 @@ func (t *ICMPTracker) cleanupRoutine(ctx context.Context) {
}
}
// evictOneLocked removes one entry to make room. Caller must hold t.mutex.
// Bounded sample scan: picks the oldest among up to evictSampleSize entries.
func (t *ICMPTracker) evictOneLocked() {
var candKey ICMPConnKey
var candSeen int64
haveCand := false
sampled := 0
for k, c := range t.connections {
seen := c.lastSeen.Load()
if !haveCand || seen < candSeen {
candKey = k
candSeen = seen
haveCand = true
}
sampled++
if sampled >= evictSampleSize {
break
}
}
if haveCand {
if evicted := t.connections[candKey]; evicted != nil {
t.sendEvent(nftypes.TypeEnd, evicted, nil)
}
delete(t.connections, candKey)
}
}
func (t *ICMPTracker) cleanup() {
t.mutex.Lock()
defer t.mutex.Unlock()
@@ -334,10 +294,8 @@ func (t *ICMPTracker) cleanup() {
if conn.timeoutExceeded(t.timeout) {
delete(t.connections, key)
if t.logger.Enabled(nblog.LevelTrace) {
t.logger.Trace5("Removed ICMP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
}
t.logger.Trace5("Removed ICMP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
t.sendEvent(nftypes.TypeEnd, conn, nil)
}
}

View File

@@ -38,27 +38,6 @@ const (
TCPHandshakeTimeout = 60 * time.Second
// TCPCleanupInterval is how often we check for stale connections
TCPCleanupInterval = 5 * time.Minute
// FinWaitTimeout bounds FIN_WAIT_1 / FIN_WAIT_2 / CLOSING states.
// Matches Linux netfilter nf_conntrack_tcp_timeout_fin_wait.
FinWaitTimeout = 60 * time.Second
// CloseWaitTimeout bounds CLOSE_WAIT. Matches Linux default; apps
// holding CloseWait longer than this should bump the env var.
CloseWaitTimeout = 60 * time.Second
// LastAckTimeout bounds LAST_ACK. Matches Linux default.
LastAckTimeout = 30 * time.Second
)
// Env vars to override per-state teardown timeouts. Values parsed by
// time.ParseDuration (e.g. "120s", "2m"). Invalid values fall back to the
// defaults above with a warning.
const (
EnvTCPFinWaitTimeout = "NB_CONNTRACK_TCP_FIN_WAIT_TIMEOUT"
EnvTCPCloseWaitTimeout = "NB_CONNTRACK_TCP_CLOSE_WAIT_TIMEOUT"
EnvTCPLastAckTimeout = "NB_CONNTRACK_TCP_LAST_ACK_TIMEOUT"
// EnvTCPMaxEntries caps the TCP conntrack table size. Oldest entries
// (tombstones first) are evicted when the cap is reached.
EnvTCPMaxEntries = "NB_CONNTRACK_TCP_MAX"
)
// TCPState represents the state of a TCP connection
@@ -154,18 +133,14 @@ func (t *TCPConnTrack) SetTombstone() {
// TCPTracker manages TCP connection states
type TCPTracker struct {
logger *nblog.Logger
connections map[ConnKey]*TCPConnTrack
mutex sync.RWMutex
cleanupTicker *time.Ticker
tickerCancel context.CancelFunc
timeout time.Duration
waitTimeout time.Duration
finWaitTimeout time.Duration
closeWaitTimeout time.Duration
lastAckTimeout time.Duration
maxEntries int
flowLogger nftypes.FlowLogger
logger *nblog.Logger
connections map[ConnKey]*TCPConnTrack
mutex sync.RWMutex
cleanupTicker *time.Ticker
tickerCancel context.CancelFunc
timeout time.Duration
waitTimeout time.Duration
flowLogger nftypes.FlowLogger
}
// NewTCPTracker creates a new TCP connection tracker
@@ -180,17 +155,13 @@ func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp
ctx, cancel := context.WithCancel(context.Background())
tracker := &TCPTracker{
logger: logger,
connections: make(map[ConnKey]*TCPConnTrack),
cleanupTicker: time.NewTicker(TCPCleanupInterval),
tickerCancel: cancel,
timeout: timeout,
waitTimeout: waitTimeout,
finWaitTimeout: envDuration(logger, EnvTCPFinWaitTimeout, FinWaitTimeout),
closeWaitTimeout: envDuration(logger, EnvTCPCloseWaitTimeout, CloseWaitTimeout),
lastAckTimeout: envDuration(logger, EnvTCPLastAckTimeout, LastAckTimeout),
maxEntries: envInt(logger, EnvTCPMaxEntries, DefaultMaxTCPEntries),
flowLogger: flowLogger,
logger: logger,
connections: make(map[ConnKey]*TCPConnTrack),
cleanupTicker: time.NewTicker(TCPCleanupInterval),
tickerCancel: cancel,
timeout: timeout,
waitTimeout: waitTimeout,
flowLogger: flowLogger,
}
go tracker.cleanupRoutine(ctx)
@@ -238,12 +209,6 @@ func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, fla
if exists || flags&TCPSyn == 0 {
return
}
// Reject illegal SYN combinations (SYN+FIN, SYN+RST, …) so they don't
// create spurious conntrack entries. Not mandated by RFC 9293 but a
// common hardening (Linux netfilter/nftables rejects these too).
if !isValidFlagCombination(flags) {
return
}
conn := &TCPConnTrack{
BaseConnTrack: BaseConnTrack{
@@ -260,65 +225,20 @@ func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, fla
conn.state.Store(int32(TCPStateNew))
conn.DNATOrigPort.Store(uint32(origPort))
if t.logger.Enabled(nblog.LevelTrace) {
if origPort != 0 {
t.logger.Trace4("New %s TCP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort)
} else {
t.logger.Trace2("New %s TCP connection: %s", direction, key)
}
if origPort != 0 {
t.logger.Trace4("New %s TCP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort)
} else {
t.logger.Trace2("New %s TCP connection: %s", direction, key)
}
t.updateState(key, conn, flags, direction, size)
t.mutex.Lock()
if t.maxEntries > 0 && len(t.connections) >= t.maxEntries {
t.evictOneLocked()
}
t.connections[key] = conn
t.mutex.Unlock()
t.sendEvent(nftypes.TypeStart, conn, ruleID)
}
// evictOneLocked removes one entry to make room. Caller must hold t.mutex.
// Bounded scan: samples up to evictSampleSize pseudo-random entries (Go map
// iteration order is randomized), preferring a tombstone. If no tombstone
// found in the sample, evicts the oldest among the sampled entries. O(1)
// worst case — cheap enough to run on every insert at cap during abuse.
func (t *TCPTracker) evictOneLocked() {
var candKey ConnKey
var candSeen int64
haveCand := false
sampled := 0
for k, c := range t.connections {
if c.IsTombstone() {
delete(t.connections, k)
return
}
seen := c.lastSeen.Load()
if !haveCand || seen < candSeen {
candKey = k
candSeen = seen
haveCand = true
}
sampled++
if sampled >= evictSampleSize {
break
}
}
if haveCand {
if evicted := t.connections[candKey]; evicted != nil {
// TypeEnd is already emitted at the state transition to
// TimeWait and when a connection is tombstoned. Only emit
// here when we're reaping a still-active flow.
if evicted.GetState() != TCPStateTimeWait && !evicted.IsTombstone() {
t.sendEvent(nftypes.TypeEnd, evicted, nil)
}
}
delete(t.connections, candKey)
}
}
// IsValidInbound checks if an inbound TCP packet matches a tracked connection
func (t *TCPTracker) IsValidInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) bool {
key := ConnKey{
@@ -336,19 +256,12 @@ func (t *TCPTracker) IsValidInbound(srcIP, dstIP netip.Addr, srcPort, dstPort ui
return false
}
// Reject illegal flag combinations regardless of state. These never belong
// to a legitimate flow and must not advance or tear down state.
if !isValidFlagCombination(flags) {
if t.logger.Enabled(nblog.LevelWarn) {
t.logger.Warn3("TCP illegal flag combination %x for connection %s (state %s)", flags, key, conn.GetState())
}
return false
}
currentState := conn.GetState()
if !t.isValidStateForFlags(currentState, flags) {
if t.logger.Enabled(nblog.LevelWarn) {
t.logger.Warn3("TCP state %s is not valid with flags %x for connection %s", currentState, flags, key)
t.logger.Warn3("TCP state %s is not valid with flags %x for connection %s", currentState, flags, key)
// allow all flags for established for now
if currentState == TCPStateEstablished {
return true
}
return false
}
@@ -357,208 +270,116 @@ func (t *TCPTracker) IsValidInbound(srcIP, dstIP netip.Addr, srcPort, dstPort ui
return true
}
// updateState updates the TCP connection state based on flags.
// updateState updates the TCP connection state based on flags
func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, packetDir nftypes.Direction, size int) {
conn.UpdateCounters(packetDir, size)
// Malformed flag combinations must not refresh lastSeen or drive state,
// otherwise spoofed packets keep a dead flow alive past its timeout.
if !isValidFlagCombination(flags) {
return
}
conn.UpdateLastSeen()
conn.UpdateCounters(packetDir, size)
currentState := conn.GetState()
if flags&TCPRst != 0 {
// Hardening beyond RFC 9293 §3.10.7.4: without sequence tracking we
// cannot apply the RFC 5961 in-window RST check, so we conservatively
// reject RSTs that the spec would accept (TIME-WAIT with in-window
// SEQ, SynSent from same direction as own SYN, etc.).
t.handleRst(key, conn, currentState, packetDir)
return
}
newState := nextState(currentState, conn.Direction, packetDir, flags)
if newState == 0 || !conn.CompareAndSwapState(currentState, newState) {
return
}
t.onTransition(key, conn, currentState, newState, packetDir)
}
// handleRst processes a RST segment. Late RSTs in TimeWait and spoofed RSTs
// from the SYN direction are ignored; otherwise the flow is tombstoned.
func (t *TCPTracker) handleRst(key ConnKey, conn *TCPConnTrack, currentState TCPState, packetDir nftypes.Direction) {
// TimeWait exists to absorb late segments; don't let a late RST
// tombstone the entry and break same-4-tuple reuse.
if currentState == TCPStateTimeWait {
return
}
// A RST from the same direction as the SYN cannot be a legitimate
// response and must not tear down a half-open connection.
if currentState == TCPStateSynSent && packetDir == conn.Direction {
return
}
if !conn.CompareAndSwapState(currentState, TCPStateClosed) {
return
}
conn.SetTombstone()
if t.logger.Enabled(nblog.LevelTrace) {
t.logger.Trace6("TCP connection reset: %s (dir: %s) [in: %d Pkts/%d B, out: %d Pkts/%d B]",
key, packetDir, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
}
t.sendEvent(nftypes.TypeEnd, conn, nil)
}
// stateTransition describes one state's transition logic. It receives the
// packet's flags plus whether the packet direction matches the connection's
// origin direction (same=true means same side as the SYN initiator). Return 0
// for no transition.
type stateTransition func(flags uint8, connDir nftypes.Direction, same bool) TCPState
// stateTable maps each state to its transition function. Centralized here so
// nextState stays trivial and each rule is easy to read in isolation.
var stateTable = map[TCPState]stateTransition{
TCPStateNew: transNew,
TCPStateSynSent: transSynSent,
TCPStateSynReceived: transSynReceived,
TCPStateEstablished: transEstablished,
TCPStateFinWait1: transFinWait1,
TCPStateFinWait2: transFinWait2,
TCPStateClosing: transClosing,
TCPStateCloseWait: transCloseWait,
TCPStateLastAck: transLastAck,
}
// nextState returns the target TCP state for the given current state and
// packet, or 0 if the packet does not trigger a transition.
func nextState(currentState TCPState, connDir, packetDir nftypes.Direction, flags uint8) TCPState {
fn, ok := stateTable[currentState]
if !ok {
return 0
}
return fn(flags, connDir, packetDir == connDir)
}
func transNew(flags uint8, connDir nftypes.Direction, _ bool) TCPState {
if flags&TCPSyn != 0 && flags&TCPAck == 0 {
if connDir == nftypes.Egress {
return TCPStateSynSent
if conn.CompareAndSwapState(currentState, TCPStateClosed) {
conn.SetTombstone()
t.logger.Trace6("TCP connection reset: %s (dir: %s) [in: %d Pkts/%d B, out: %d Pkts/%d B]",
key, packetDir, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
t.sendEvent(nftypes.TypeEnd, conn, nil)
}
return TCPStateSynReceived
return
}
return 0
}
func transSynSent(flags uint8, _ nftypes.Direction, same bool) TCPState {
if flags&TCPSyn != 0 && flags&TCPAck != 0 {
if same {
return TCPStateSynReceived // simultaneous open
var newState TCPState
switch currentState {
case TCPStateNew:
if flags&TCPSyn != 0 && flags&TCPAck == 0 {
if conn.Direction == nftypes.Egress {
newState = TCPStateSynSent
} else {
newState = TCPStateSynReceived
}
}
return TCPStateEstablished
}
return 0
}
func transSynReceived(flags uint8, _ nftypes.Direction, same bool) TCPState {
if flags&TCPAck != 0 && flags&TCPSyn == 0 && same {
return TCPStateEstablished
}
return 0
}
case TCPStateSynSent:
if flags&TCPSyn != 0 && flags&TCPAck != 0 {
if packetDir != conn.Direction {
newState = TCPStateEstablished
} else {
// Simultaneous open
newState = TCPStateSynReceived
}
}
func transEstablished(flags uint8, _ nftypes.Direction, same bool) TCPState {
if flags&TCPFin == 0 {
return 0
}
if same {
return TCPStateFinWait1
}
return TCPStateCloseWait
}
case TCPStateSynReceived:
if flags&TCPAck != 0 && flags&TCPSyn == 0 {
if packetDir == conn.Direction {
newState = TCPStateEstablished
}
}
// transFinWait1 handles the active-close peer response. A FIN carrying our
// ACK piggybacked goes straight to TIME-WAIT (RFC 9293 §3.10.7.4, FIN-WAIT-1:
// "if our FIN has been ACKed... enter the TIME-WAIT state"); a lone FIN moves
// to CLOSING; a pure ACK of our FIN moves to FIN-WAIT-2.
func transFinWait1(flags uint8, _ nftypes.Direction, same bool) TCPState {
if same {
return 0
}
if flags&TCPFin != 0 && flags&TCPAck != 0 {
return TCPStateTimeWait
}
switch {
case flags&TCPFin != 0:
return TCPStateClosing
case flags&TCPAck != 0:
return TCPStateFinWait2
}
return 0
}
case TCPStateEstablished:
if flags&TCPFin != 0 {
if packetDir == conn.Direction {
newState = TCPStateFinWait1
} else {
newState = TCPStateCloseWait
}
}
// transFinWait2 ignores own-side FIN retransmits; only the peer's FIN advances.
func transFinWait2(flags uint8, _ nftypes.Direction, same bool) TCPState {
if flags&TCPFin != 0 && !same {
return TCPStateTimeWait
}
return 0
}
case TCPStateFinWait1:
if packetDir != conn.Direction {
switch {
case flags&TCPFin != 0 && flags&TCPAck != 0:
newState = TCPStateClosing
case flags&TCPFin != 0:
newState = TCPStateClosing
case flags&TCPAck != 0:
newState = TCPStateFinWait2
}
}
// transClosing completes a simultaneous close on the peer's ACK.
func transClosing(flags uint8, _ nftypes.Direction, same bool) TCPState {
if flags&TCPAck != 0 && !same {
return TCPStateTimeWait
}
return 0
}
case TCPStateFinWait2:
if flags&TCPFin != 0 {
newState = TCPStateTimeWait
}
// transCloseWait only advances to LastAck when WE send FIN, ignoring peer retransmits.
func transCloseWait(flags uint8, _ nftypes.Direction, same bool) TCPState {
if flags&TCPFin != 0 && same {
return TCPStateLastAck
}
return 0
}
case TCPStateClosing:
if flags&TCPAck != 0 {
newState = TCPStateTimeWait
}
// transLastAck closes the flow only on the peer's ACK (not our own ACK retransmits).
func transLastAck(flags uint8, _ nftypes.Direction, same bool) TCPState {
if flags&TCPAck != 0 && !same {
return TCPStateClosed
}
return 0
}
case TCPStateCloseWait:
if flags&TCPFin != 0 {
newState = TCPStateLastAck
}
// onTransition handles logging and flow-event emission after a successful
// state transition. TimeWait and Closed are terminal for flow accounting.
func (t *TCPTracker) onTransition(key ConnKey, conn *TCPConnTrack, from, to TCPState, packetDir nftypes.Direction) {
traceOn := t.logger.Enabled(nblog.LevelTrace)
if traceOn {
t.logger.Trace4("TCP connection %s transitioned from %s to %s (dir: %s)", key, from, to, packetDir)
case TCPStateLastAck:
if flags&TCPAck != 0 {
newState = TCPStateClosed
}
}
switch to {
case TCPStateTimeWait:
if traceOn {
if newState != 0 && conn.CompareAndSwapState(currentState, newState) {
t.logger.Trace4("TCP connection %s transitioned from %s to %s (dir: %s)", key, currentState, newState, packetDir)
switch newState {
case TCPStateTimeWait:
t.logger.Trace5("TCP connection %s completed [in: %d Pkts/%d B, out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
}
t.sendEvent(nftypes.TypeEnd, conn, nil)
case TCPStateClosed:
conn.SetTombstone()
if traceOn {
t.sendEvent(nftypes.TypeEnd, conn, nil)
case TCPStateClosed:
conn.SetTombstone()
t.logger.Trace5("TCP connection %s closed gracefully [in: %d Pkts/%d, B out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
t.sendEvent(nftypes.TypeEnd, conn, nil)
}
t.sendEvent(nftypes.TypeEnd, conn, nil)
}
}
// isValidStateForFlags checks if the TCP flags are valid for the current
// connection state. Caller must have already verified the flag combination is
// legal via isValidFlagCombination.
// isValidStateForFlags checks if the TCP flags are valid for the current connection state
func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool {
if !isValidFlagCombination(flags) {
return false
}
if flags&TCPRst != 0 {
if state == TCPStateSynSent {
return flags&TCPAck != 0
@@ -628,24 +449,15 @@ func (t *TCPTracker) cleanup() {
timeout = t.waitTimeout
case TCPStateEstablished:
timeout = t.timeout
case TCPStateFinWait1, TCPStateFinWait2, TCPStateClosing:
timeout = t.finWaitTimeout
case TCPStateCloseWait:
timeout = t.closeWaitTimeout
case TCPStateLastAck:
timeout = t.lastAckTimeout
default:
// SynSent / SynReceived / New
timeout = TCPHandshakeTimeout
}
if conn.timeoutExceeded(timeout) {
delete(t.connections, key)
if t.logger.Enabled(nblog.LevelTrace) {
t.logger.Trace6("Cleaned up timed-out TCP connection %s (%s) [in: %d Pkts/%d, B out: %d Pkts/%d B]",
key, conn.GetState(), conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
}
t.logger.Trace6("Cleaned up timed-out TCP connection %s (%s) [in: %d Pkts/%d, B out: %d Pkts/%d B]",
key, conn.GetState(), conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
// event already handled by state change
if currentState != TCPStateTimeWait {

View File

@@ -1,100 +0,0 @@
package conntrack
import (
"net/netip"
"testing"
"github.com/stretchr/testify/require"
)
// RST hygiene tests: the tracker currently closes the flow on any RST that
// matches the 4-tuple, regardless of direction or state. These tests cover
// the minimum checks we want (no SEQ tracking).
func TestTCPRstInSynSentWrongDirection(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
conn := tracker.connections[key]
require.Equal(t, TCPStateSynSent, conn.GetState())
// A RST arriving in the same direction as the SYN (i.e. TrackOutbound)
// cannot be a legitimate response. It must not close the connection.
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPRst|TCPAck, 0)
require.Equal(t, TCPStateSynSent, conn.GetState(),
"RST in same direction as SYN must not close connection")
require.False(t, conn.IsTombstone())
}
func TestTCPRstInTimeWaitIgnored(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
// Drive to TIME-WAIT via active close.
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0))
require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0))
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
conn := tracker.connections[key]
require.Equal(t, TCPStateTimeWait, conn.GetState())
require.False(t, conn.IsTombstone(), "TIME-WAIT must not be tombstoned")
// Late RST during TIME-WAIT must not tombstone the entry (TIME-WAIT
// exists to absorb late segments).
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
require.Equal(t, TCPStateTimeWait, conn.GetState(),
"RST in TIME-WAIT must not transition state")
require.False(t, conn.IsTombstone(),
"RST in TIME-WAIT must not tombstone the entry")
}
func TestTCPIllegalFlagCombos(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
conn := tracker.connections[key]
// Illegal combos must be rejected and must not change state.
combos := []struct {
name string
flags uint8
}{
{"SYN+RST", TCPSyn | TCPRst},
{"FIN+RST", TCPFin | TCPRst},
{"SYN+FIN", TCPSyn | TCPFin},
{"SYN+FIN+RST", TCPSyn | TCPFin | TCPRst},
}
for _, c := range combos {
t.Run(c.name, func(t *testing.T) {
before := conn.GetState()
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, c.flags, 0)
require.False(t, valid, "illegal flag combo must be rejected: %s", c.name)
require.Equal(t, before, conn.GetState(),
"illegal flag combo must not change state")
require.False(t, conn.IsTombstone())
})
}
}

View File

@@ -1,235 +0,0 @@
package conntrack
import (
"net/netip"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// These tests exercise cases where the TCP state machine currently advances
// on retransmitted or wrong-direction segments and tears the flow down
// prematurely. They are expected to fail until the direction checks are added.
func TestTCPCloseWaitRetransmittedPeerFIN(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Peer sends FIN -> CloseWait (our app has not yet closed).
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
require.True(t, valid)
conn := tracker.connections[key]
require.Equal(t, TCPStateCloseWait, conn.GetState())
// Peer retransmits their FIN (ACK may have been delayed). We have NOT
// sent our FIN yet, so state must remain CloseWait.
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
require.True(t, valid, "retransmitted peer FIN must still be accepted")
require.Equal(t, TCPStateCloseWait, conn.GetState(),
"retransmitted peer FIN must not advance CloseWait to LastAck")
// Our app finally closes -> LastAck.
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
require.Equal(t, TCPStateLastAck, conn.GetState())
// Peer ACK closes.
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.True(t, valid)
require.Equal(t, TCPStateClosed, conn.GetState())
}
func TestTCPFinWait2RetransmittedOwnFIN(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// We initiate close.
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.True(t, valid)
conn := tracker.connections[key]
require.Equal(t, TCPStateFinWait2, conn.GetState())
// Stray retransmit of our own FIN (same direction as originator) must
// NOT advance FinWait2 to TimeWait; only the peer's FIN should.
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
require.Equal(t, TCPStateFinWait2, conn.GetState(),
"own FIN retransmit must not advance FinWait2 to TimeWait")
// Peer FIN -> TimeWait.
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
require.True(t, valid)
require.Equal(t, TCPStateTimeWait, conn.GetState())
}
func TestTCPLastAckDirectionCheck(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Drive to LastAck: peer FIN -> CloseWait, our FIN -> LastAck.
require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0))
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
conn := tracker.connections[key]
require.Equal(t, TCPStateLastAck, conn.GetState())
// Our own ACK retransmit (same direction as originator) must NOT close.
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
require.Equal(t, TCPStateLastAck, conn.GetState(),
"own ACK retransmit in LastAck must not transition to Closed")
// Peer's ACK -> Closed.
require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0))
require.Equal(t, TCPStateClosed, conn.GetState())
}
func TestTCPFinWait1OwnAckDoesNotAdvance(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
conn := tracker.connections[key]
require.Equal(t, TCPStateFinWait1, conn.GetState())
// Our own ACK retransmit (same direction as originator) must not advance.
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
require.Equal(t, TCPStateFinWait1, conn.GetState(),
"own ACK in FinWait1 must not advance to FinWait2")
}
func TestTCPPerStateTeardownTimeouts(t *testing.T) {
// Verify cleanup reaps entries in each teardown state at the configured
// per-state timeout, not at the single handshake timeout.
t.Setenv(EnvTCPFinWaitTimeout, "50ms")
t.Setenv(EnvTCPCloseWaitTimeout, "80ms")
t.Setenv(EnvTCPLastAckTimeout, "30ms")
dstIP := netip.MustParseAddr("100.64.0.2")
dstPort := uint16(80)
// Drives a connection to the target state, forces its lastSeen well
// beyond the configured timeout, runs cleanup, and asserts reaping.
cases := []struct {
name string
// drive takes a fresh tracker and returns the conn key after
// transitioning the flow into the intended teardown state.
drive func(t *testing.T, tr *TCPTracker, srcIP netip.Addr, srcPort uint16) (ConnKey, TCPState)
}{
{
name: "FinWait1",
drive: func(t *testing.T, tr *TCPTracker, srcIP netip.Addr, srcPort uint16) (ConnKey, TCPState) {
establishConnection(t, tr, srcIP, dstIP, srcPort, dstPort)
tr.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) // → FinWait1
return ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}, TCPStateFinWait1
},
},
{
name: "FinWait2",
drive: func(t *testing.T, tr *TCPTracker, srcIP netip.Addr, srcPort uint16) (ConnKey, TCPState) {
establishConnection(t, tr, srcIP, dstIP, srcPort, dstPort)
tr.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) // FinWait1
require.True(t, tr.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)) // → FinWait2
return ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}, TCPStateFinWait2
},
},
{
name: "CloseWait",
drive: func(t *testing.T, tr *TCPTracker, srcIP netip.Addr, srcPort uint16) (ConnKey, TCPState) {
establishConnection(t, tr, srcIP, dstIP, srcPort, dstPort)
require.True(t, tr.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)) // → CloseWait
return ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}, TCPStateCloseWait
},
},
{
name: "LastAck",
drive: func(t *testing.T, tr *TCPTracker, srcIP netip.Addr, srcPort uint16) (ConnKey, TCPState) {
establishConnection(t, tr, srcIP, dstIP, srcPort, dstPort)
require.True(t, tr.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)) // CloseWait
tr.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) // → LastAck
return ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}, TCPStateLastAck
},
},
}
// Use a unique source port per subtest so nothing aliases.
port := uint16(12345)
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
require.Equal(t, 50*time.Millisecond, tracker.finWaitTimeout)
require.Equal(t, 80*time.Millisecond, tracker.closeWaitTimeout)
require.Equal(t, 30*time.Millisecond, tracker.lastAckTimeout)
srcIP := netip.MustParseAddr("100.64.0.1")
port++
key, wantState := c.drive(t, tracker, srcIP, port)
conn := tracker.connections[key]
require.NotNil(t, conn)
require.Equal(t, wantState, conn.GetState())
// Age the entry past the largest per-state timeout.
conn.lastSeen.Store(time.Now().Add(-500 * time.Millisecond).UnixNano())
tracker.cleanup()
_, exists := tracker.connections[key]
require.False(t, exists, "%s entry should be reaped", c.name)
})
}
}
func TestTCPEstablishedPSHACKInFinStates(t *testing.T) {
// Verifies FIN|PSH|ACK and bare ACK keepalives are not dropped in FIN
// teardown states, which some stacks emit during close.
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Peer FIN -> CloseWait.
require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0))
// Peer pushes trailing data + FIN|PSH|ACK (legal).
require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPPush|TCPAck, 100),
"FIN|PSH|ACK in CloseWait must be accepted")
// Bare ACK keepalive from peer in CloseWait must be accepted.
require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0),
"bare ACK in CloseWait must be accepted")
}

View File

@@ -17,9 +17,6 @@ const (
DefaultUDPTimeout = 30 * time.Second
// UDPCleanupInterval is how often we check for stale connections
UDPCleanupInterval = 15 * time.Second
// EnvUDPMaxEntries caps the UDP conntrack table size.
EnvUDPMaxEntries = "NB_CONNTRACK_UDP_MAX"
)
// UDPConnTrack represents a UDP connection state
@@ -37,7 +34,6 @@ type UDPTracker struct {
cleanupTicker *time.Ticker
tickerCancel context.CancelFunc
mutex sync.RWMutex
maxEntries int
flowLogger nftypes.FlowLogger
}
@@ -55,7 +51,6 @@ func NewUDPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp
timeout: timeout,
cleanupTicker: time.NewTicker(UDPCleanupInterval),
tickerCancel: cancel,
maxEntries: envInt(logger, EnvUDPMaxEntries, DefaultMaxUDPEntries),
flowLogger: flowLogger,
}
@@ -122,18 +117,13 @@ func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d
conn.UpdateCounters(direction, size)
t.mutex.Lock()
if t.maxEntries > 0 && len(t.connections) >= t.maxEntries {
t.evictOneLocked()
}
t.connections[key] = conn
t.mutex.Unlock()
if t.logger.Enabled(nblog.LevelTrace) {
if origPort != 0 {
t.logger.Trace4("New %s UDP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort)
} else {
t.logger.Trace2("New %s UDP connection: %s", direction, key)
}
if origPort != 0 {
t.logger.Trace4("New %s UDP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort)
} else {
t.logger.Trace2("New %s UDP connection: %s", direction, key)
}
t.sendEvent(nftypes.TypeStart, conn, ruleID)
}
@@ -161,34 +151,6 @@ func (t *UDPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort
return true
}
// evictOneLocked removes one entry to make room. Caller must hold t.mutex.
// Bounded sample: picks the oldest among up to evictSampleSize entries.
func (t *UDPTracker) evictOneLocked() {
var candKey ConnKey
var candSeen int64
haveCand := false
sampled := 0
for k, c := range t.connections {
seen := c.lastSeen.Load()
if !haveCand || seen < candSeen {
candKey = k
candSeen = seen
haveCand = true
}
sampled++
if sampled >= evictSampleSize {
break
}
}
if haveCand {
if evicted := t.connections[candKey]; evicted != nil {
t.sendEvent(nftypes.TypeEnd, evicted, nil)
}
delete(t.connections, candKey)
}
}
// cleanupRoutine periodically removes stale connections
func (t *UDPTracker) cleanupRoutine(ctx context.Context) {
defer t.cleanupTicker.Stop()
@@ -211,10 +173,8 @@ func (t *UDPTracker) cleanup() {
if conn.timeoutExceeded(t.timeout) {
delete(t.connections, key)
if t.logger.Enabled(nblog.LevelTrace) {
t.logger.Trace5("Removed UDP connection %s (timeout) [in: %d Pkts/%d B, out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
}
t.logger.Trace5("Removed UDP connection %s (timeout) [in: %d Pkts/%d B, out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
t.sendEvent(nftypes.TypeEnd, conn, nil)
}
}

View File

@@ -142,8 +142,15 @@ type Manager struct {
mssClampEnabled bool
// Only one hook per protocol is supported. Outbound direction only.
udpHookOut atomic.Pointer[common.PacketHook]
tcpHookOut atomic.Pointer[common.PacketHook]
udpHookOut atomic.Pointer[packetHook]
tcpHookOut atomic.Pointer[packetHook]
}
// packetHook stores a registered hook for a specific IP:port.
type packetHook struct {
ip netip.Addr
port uint16
fn func([]byte) bool
}
// decoder for packages
@@ -709,9 +716,7 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool {
srcIP, dstIP := m.extractIPs(d)
if !srcIP.IsValid() {
if m.logger.Enabled(nblog.LevelError) {
m.logger.Error1("Unknown network layer: %v", d.decoded[0])
}
m.logger.Error1("Unknown network layer: %v", d.decoded[0])
return false
}
@@ -810,9 +815,7 @@ func (m *Manager) clampTCPMSS(packetData []byte, d *decoder) bool {
return false
}
if m.logger.Enabled(nblog.LevelTrace) {
m.logger.Trace2("Clamped TCP MSS from %d to %d", currentMSS, m.mssClampValue)
}
m.logger.Trace2("Clamped TCP MSS from %d to %d", currentMSS, m.mssClampValue)
return true
}
@@ -909,11 +912,21 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byt
}
func (m *Manager) udpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte) bool {
return common.HookMatches(m.udpHookOut.Load(), dstIP, dport, packetData)
return hookMatches(m.udpHookOut.Load(), dstIP, dport, packetData)
}
func (m *Manager) tcpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte) bool {
return common.HookMatches(m.tcpHookOut.Load(), dstIP, dport, packetData)
return hookMatches(m.tcpHookOut.Load(), dstIP, dport, packetData)
}
func hookMatches(h *packetHook, dstIP netip.Addr, dport uint16, packetData []byte) bool {
if h == nil {
return false
}
if h.ip == dstIP && h.port == dport {
return h.fn(packetData)
}
return false
}
// filterInbound implements filtering logic for incoming packets.
@@ -935,10 +948,8 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool {
// TODO: pass fragments of routed packets to forwarder
if fragment {
if m.logger.Enabled(nblog.LevelTrace) {
m.logger.Trace4("packet is a fragment: src=%v dst=%v id=%v flags=%v",
srcIP, dstIP, d.ip4.Id, d.ip4.Flags)
}
m.logger.Trace4("packet is a fragment: src=%v dst=%v id=%v flags=%v",
srcIP, dstIP, d.ip4.Id, d.ip4.Flags)
return false
}
@@ -980,10 +991,8 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packet
pnum := getProtocolFromPacket(d)
srcPort, dstPort := getPortsFromPacket(d)
if m.logger.Enabled(nblog.LevelTrace) {
m.logger.Trace6("Dropping local packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
ruleID, pnum, srcIP, srcPort, dstIP, dstPort)
}
m.logger.Trace6("Dropping local packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
ruleID, pnum, srcIP, srcPort, dstIP, dstPort)
m.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: uuid.New(),
@@ -1033,10 +1042,8 @@ func (m *Manager) handleForwardedLocalTraffic(packetData []byte) bool {
func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) bool {
// Drop if routing is disabled
if !m.routingEnabled.Load() {
if m.logger.Enabled(nblog.LevelTrace) {
m.logger.Trace2("Dropping routed packet (routing disabled): src=%s dst=%s",
srcIP, dstIP)
}
m.logger.Trace2("Dropping routed packet (routing disabled): src=%s dst=%s",
srcIP, dstIP)
return true
}
@@ -1053,10 +1060,8 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe
if !pass {
proto := getProtocolFromPacket(d)
if m.logger.Enabled(nblog.LevelTrace) {
m.logger.Trace6("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
ruleID, proto, srcIP, srcPort, dstIP, dstPort)
}
m.logger.Trace6("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
ruleID, proto, srcIP, srcPort, dstIP, dstPort)
m.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: uuid.New(),
@@ -1138,9 +1143,7 @@ func getPortsFromPacket(d *decoder) (srcPort, dstPort uint16) {
// It returns true, true if the packet is a fragment and valid.
func (m *Manager) isValidPacket(d *decoder, packetData []byte) (bool, bool) {
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
if m.logger.Enabled(nblog.LevelTrace) {
m.logger.Trace1("couldn't decode packet, err: %s", err)
}
m.logger.Trace1("couldn't decode packet, err: %s", err)
return false, false
}
@@ -1334,12 +1337,28 @@ func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, prot
// SetUDPPacketHook sets the outbound UDP packet hook. Pass nil hook to remove.
func (m *Manager) SetUDPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool) {
common.SetHook(&m.udpHookOut, ip, dPort, hook)
if hook == nil {
m.udpHookOut.Store(nil)
return
}
m.udpHookOut.Store(&packetHook{
ip: ip,
port: dPort,
fn: hook,
})
}
// SetTCPPacketHook sets the outbound TCP packet hook. Pass nil hook to remove.
func (m *Manager) SetTCPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool) {
common.SetHook(&m.tcpHookOut, ip, dPort, hook)
if hook == nil {
m.tcpHookOut.Store(nil)
return
}
m.tcpHookOut.Store(&packetHook{
ip: ip,
port: dPort,
fn: hook,
})
}
// SetLogLevel sets the log level for the firewall manager

View File

@@ -202,9 +202,9 @@ func TestSetUDPPacketHook(t *testing.T) {
h := manager.udpHookOut.Load()
require.NotNil(t, h)
assert.Equal(t, netip.MustParseAddr("10.168.0.1"), h.IP)
assert.Equal(t, uint16(8000), h.Port)
assert.True(t, h.Fn(nil))
assert.Equal(t, netip.MustParseAddr("10.168.0.1"), h.ip)
assert.Equal(t, uint16(8000), h.port)
assert.True(t, h.fn(nil))
assert.True(t, called)
manager.SetUDPPacketHook(netip.MustParseAddr("10.168.0.1"), 8000, nil)
@@ -226,9 +226,9 @@ func TestSetTCPPacketHook(t *testing.T) {
h := manager.tcpHookOut.Load()
require.NotNil(t, h)
assert.Equal(t, netip.MustParseAddr("10.168.0.1"), h.IP)
assert.Equal(t, uint16(53), h.Port)
assert.True(t, h.Fn(nil))
assert.Equal(t, netip.MustParseAddr("10.168.0.1"), h.ip)
assert.Equal(t, uint16(53), h.port)
assert.True(t, h.fn(nil))
assert.True(t, called)
manager.SetTCPPacketHook(netip.MustParseAddr("10.168.0.1"), 53, nil)

View File

@@ -13,7 +13,6 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
)
@@ -93,10 +92,8 @@ func (f *Forwarder) forwardICMPPacket(id stack.TransportEndpointID, payload []by
return nil, fmt.Errorf("write ICMP packet: %w", err)
}
if f.logger.Enabled(nblog.LevelTrace) {
f.logger.Trace3("forwarder: Forwarded ICMP packet %v type %v code %v",
epID(id), icmpType, icmpCode)
}
f.logger.Trace3("forwarder: Forwarded ICMP packet %v type %v code %v",
epID(id), icmpType, icmpCode)
return conn, nil
}
@@ -119,10 +116,8 @@ func (f *Forwarder) handleICMPViaSocket(flowID uuid.UUID, id stack.TransportEndp
txBytes := f.handleEchoResponse(conn, id)
rtt := time.Since(sendTime).Round(10 * time.Microsecond)
if f.logger.Enabled(nblog.LevelTrace) {
f.logger.Trace4("forwarder: Forwarded ICMP echo reply %v type %v code %v (rtt=%v, raw socket)",
epID(id), icmpType, icmpCode, rtt)
}
f.logger.Trace4("forwarder: Forwarded ICMP echo reply %v type %v code %v (rtt=%v, raw socket)",
epID(id), icmpType, icmpCode, rtt)
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
}
@@ -203,17 +198,13 @@ func (f *Forwarder) handleICMPViaPing(flowID uuid.UUID, id stack.TransportEndpoi
}
rtt := time.Since(pingStart).Round(10 * time.Microsecond)
if f.logger.Enabled(nblog.LevelTrace) {
f.logger.Trace3("forwarder: Forwarded ICMP echo request %v type %v code %v",
epID(id), icmpType, icmpCode)
}
f.logger.Trace3("forwarder: Forwarded ICMP echo request %v type %v code %v",
epID(id), icmpType, icmpCode)
txBytes := f.synthesizeEchoReply(id, icmpData)
if f.logger.Enabled(nblog.LevelTrace) {
f.logger.Trace4("forwarder: Forwarded ICMP echo reply %v type %v code %v (rtt=%v, ping binary)",
epID(id), icmpType, icmpCode, rtt)
}
f.logger.Trace4("forwarder: Forwarded ICMP echo reply %v type %v code %v (rtt=%v, ping binary)",
epID(id), icmpType, icmpCode, rtt)
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
}

View File

@@ -1,9 +1,12 @@
package forwarder
import (
"context"
"fmt"
"io"
"net"
"net/netip"
"sync"
"github.com/google/uuid"
@@ -13,9 +16,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/waiter"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/util/netrelay"
)
// handleTCP is called by the TCP forwarder for new connections.
@@ -37,9 +38,7 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr)
if err != nil {
r.Complete(true)
if f.logger.Enabled(nblog.LevelTrace) {
f.logger.Trace2("forwarder: dial error for %v: %v", epID(id), err)
}
f.logger.Trace2("forwarder: dial error for %v: %v", epID(id), err)
return
}
@@ -62,22 +61,64 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
inConn := gonet.NewTCPConn(&wq, ep)
success = true
if f.logger.Enabled(nblog.LevelTrace) {
f.logger.Trace1("forwarder: established TCP connection %v", epID(id))
}
f.logger.Trace1("forwarder: established TCP connection %v", epID(id))
go f.proxyTCP(id, inConn, outConn, ep, flowID)
}
func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn, outConn net.Conn, ep tcpip.Endpoint, flowID uuid.UUID) {
// netrelay.Relay copies bidirectionally with proper half-close propagation
// and fully closes both conns before returning.
bytesFromInToOut, bytesFromOutToIn := netrelay.Relay(f.ctx, inConn, outConn, netrelay.Options{
Logger: f.logger,
})
// Close the netstack endpoint after both conns are drained.
ep.Close()
ctx, cancel := context.WithCancel(f.ctx)
defer cancel()
go func() {
<-ctx.Done()
// Close connections and endpoint.
if err := inConn.Close(); err != nil && !isClosedError(err) {
f.logger.Debug1("forwarder: inConn close error: %v", err)
}
if err := outConn.Close(); err != nil && !isClosedError(err) {
f.logger.Debug1("forwarder: outConn close error: %v", err)
}
ep.Close()
}()
var wg sync.WaitGroup
wg.Add(2)
var (
bytesFromInToOut int64 // bytes from client to server (tx for client)
bytesFromOutToIn int64 // bytes from server to client (rx for client)
errInToOut error
errOutToIn error
)
go func() {
bytesFromInToOut, errInToOut = io.Copy(outConn, inConn)
cancel()
wg.Done()
}()
go func() {
bytesFromOutToIn, errOutToIn = io.Copy(inConn, outConn)
cancel()
wg.Done()
}()
wg.Wait()
if errInToOut != nil {
if !isClosedError(errInToOut) {
f.logger.Error2("proxyTCP: copy error (in → out) for %s: %v", epID(id), errInToOut)
}
}
if errOutToIn != nil {
if !isClosedError(errOutToIn) {
f.logger.Error2("proxyTCP: copy error (out → in) for %s: %v", epID(id), errOutToIn)
}
}
var rxPackets, txPackets uint64
if tcpStats, ok := ep.Stats().(*tcp.Stats); ok {
@@ -86,9 +127,7 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
txPackets = tcpStats.SegmentsReceived.Value()
}
if f.logger.Enabled(nblog.LevelTrace) {
f.logger.Trace5("forwarder: Removed TCP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, bytesFromOutToIn, txPackets, bytesFromInToOut)
}
f.logger.Trace5("forwarder: Removed TCP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, bytesFromOutToIn, txPackets, bytesFromInToOut)
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, uint64(bytesFromOutToIn), uint64(bytesFromInToOut), rxPackets, txPackets)
}

View File

@@ -125,9 +125,7 @@ func (f *udpForwarder) cleanup() {
delete(f.conns, idle.id)
f.Unlock()
if f.logger.Enabled(nblog.LevelTrace) {
f.logger.Trace1("forwarder: cleaned up idle UDP connection %v", epID(idle.id))
}
f.logger.Trace1("forwarder: cleaned up idle UDP connection %v", epID(idle.id))
}
}
}
@@ -146,9 +144,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool {
_, exists := f.udpForwarder.conns[id]
f.udpForwarder.RUnlock()
if exists {
if f.logger.Enabled(nblog.LevelTrace) {
f.logger.Trace1("forwarder: existing UDP connection for %v", epID(id))
}
f.logger.Trace1("forwarder: existing UDP connection for %v", epID(id))
return true
}
@@ -210,9 +206,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool {
f.udpForwarder.Unlock()
success = true
if f.logger.Enabled(nblog.LevelTrace) {
f.logger.Trace1("forwarder: established UDP connection %v", epID(id))
}
f.logger.Trace1("forwarder: established UDP connection %v", epID(id))
go f.proxyUDP(connCtx, pConn, id, ep)
return true
@@ -271,9 +265,7 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
txPackets = udpStats.PacketsReceived.Value()
}
if f.logger.Enabled(nblog.LevelTrace) {
f.logger.Trace5("forwarder: Removed UDP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, rxBytes, txPackets, txBytes)
}
f.logger.Trace5("forwarder: Removed UDP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, rxBytes, txPackets, txBytes)
f.udpForwarder.Lock()
delete(f.udpForwarder.conns, id)

View File

@@ -1,90 +0,0 @@
package uspfilter
import (
"encoding/binary"
"net/netip"
"sync/atomic"
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
"github.com/netbirdio/netbird/client/iface/device"
)
const (
ipv4HeaderMinLen = 20
ipv4ProtoOffset = 9
ipv4FlagsOffset = 6
ipv4DstOffset = 16
ipProtoUDP = 17
ipProtoTCP = 6
ipv4FragOffMask = 0x1fff
// dstPortOffset is the offset of the destination port within a UDP or TCP header.
dstPortOffset = 2
)
// HooksFilter is a minimal packet filter that only handles outbound DNS hooks.
// It is installed on the WireGuard interface when the userspace bind is active
// but a full firewall filter (Manager) is not needed because a native kernel
// firewall (nftables/iptables) handles packet filtering.
type HooksFilter struct {
udpHook atomic.Pointer[common.PacketHook]
tcpHook atomic.Pointer[common.PacketHook]
}
var _ device.PacketFilter = (*HooksFilter)(nil)
// FilterOutbound checks outbound packets for DNS hook matches.
// Only IPv4 packets matching the registered hook IP:port are intercepted.
// IPv6 and non-IP packets pass through unconditionally.
func (f *HooksFilter) FilterOutbound(packetData []byte, _ int) bool {
if len(packetData) < ipv4HeaderMinLen {
return false
}
// Only process IPv4 packets, let everything else pass through.
if packetData[0]>>4 != 4 {
return false
}
ihl := int(packetData[0]&0x0f) * 4
if ihl < ipv4HeaderMinLen || len(packetData) < ihl+4 {
return false
}
// Skip non-first fragments: they don't carry L4 headers.
flagsAndOffset := binary.BigEndian.Uint16(packetData[ipv4FlagsOffset : ipv4FlagsOffset+2])
if flagsAndOffset&ipv4FragOffMask != 0 {
return false
}
dstIP, ok := netip.AddrFromSlice(packetData[ipv4DstOffset : ipv4DstOffset+4])
if !ok {
return false
}
proto := packetData[ipv4ProtoOffset]
dstPort := binary.BigEndian.Uint16(packetData[ihl+dstPortOffset : ihl+dstPortOffset+2])
switch proto {
case ipProtoUDP:
return common.HookMatches(f.udpHook.Load(), dstIP, dstPort, packetData)
case ipProtoTCP:
return common.HookMatches(f.tcpHook.Load(), dstIP, dstPort, packetData)
default:
return false
}
}
// FilterInbound allows all inbound packets (native firewall handles filtering).
func (f *HooksFilter) FilterInbound([]byte, int) bool {
return false
}
// SetUDPPacketHook registers the UDP packet hook.
func (f *HooksFilter) SetUDPPacketHook(ip netip.Addr, dPort uint16, hook func([]byte) bool) {
common.SetHook(&f.udpHook, ip, dPort, hook)
}
// SetTCPPacketHook registers the TCP packet hook.
func (f *HooksFilter) SetTCPPacketHook(ip netip.Addr, dPort uint16, hook func([]byte) bool) {
common.SetHook(&f.tcpHook, ip, dPort, hook)
}

View File

@@ -53,17 +53,16 @@ var levelStrings = map[Level]string{
}
type logMessage struct {
level Level
argCount uint8
format string
arg1 any
arg2 any
arg3 any
arg4 any
arg5 any
arg6 any
arg7 any
arg8 any
level Level
format string
arg1 any
arg2 any
arg3 any
arg4 any
arg5 any
arg6 any
arg7 any
arg8 any
}
// Logger is a high-performance, non-blocking logger
@@ -108,13 +107,6 @@ func (l *Logger) SetLevel(level Level) {
log.Debugf("Set uspfilter logger loglevel to %v", levelStrings[level])
}
// Enabled reports whether the given level is currently logged. Callers on the
// hot path should guard log sites with this to avoid boxing arguments into
// any when the level is off.
func (l *Logger) Enabled(level Level) bool {
return l.level.Load() >= uint32(level)
}
func (l *Logger) Error(format string) {
if l.level.Load() >= uint32(LevelError) {
select {
@@ -163,7 +155,7 @@ func (l *Logger) Trace(format string) {
func (l *Logger) Error1(format string, arg1 any) {
if l.level.Load() >= uint32(LevelError) {
select {
case l.msgChannel <- logMessage{level: LevelError, argCount: 1, format: format, arg1: arg1}:
case l.msgChannel <- logMessage{level: LevelError, format: format, arg1: arg1}:
default:
}
}
@@ -172,16 +164,7 @@ func (l *Logger) Error1(format string, arg1 any) {
func (l *Logger) Error2(format string, arg1, arg2 any) {
if l.level.Load() >= uint32(LevelError) {
select {
case l.msgChannel <- logMessage{level: LevelError, argCount: 2, format: format, arg1: arg1, arg2: arg2}:
default:
}
}
}
func (l *Logger) Warn2(format string, arg1, arg2 any) {
if l.level.Load() >= uint32(LevelWarn) {
select {
case l.msgChannel <- logMessage{level: LevelWarn, argCount: 2, format: format, arg1: arg1, arg2: arg2}:
case l.msgChannel <- logMessage{level: LevelError, format: format, arg1: arg1, arg2: arg2}:
default:
}
}
@@ -190,7 +173,7 @@ func (l *Logger) Warn2(format string, arg1, arg2 any) {
func (l *Logger) Warn3(format string, arg1, arg2, arg3 any) {
if l.level.Load() >= uint32(LevelWarn) {
select {
case l.msgChannel <- logMessage{level: LevelWarn, argCount: 3, format: format, arg1: arg1, arg2: arg2, arg3: arg3}:
case l.msgChannel <- logMessage{level: LevelWarn, format: format, arg1: arg1, arg2: arg2, arg3: arg3}:
default:
}
}
@@ -199,7 +182,7 @@ func (l *Logger) Warn3(format string, arg1, arg2, arg3 any) {
func (l *Logger) Warn4(format string, arg1, arg2, arg3, arg4 any) {
if l.level.Load() >= uint32(LevelWarn) {
select {
case l.msgChannel <- logMessage{level: LevelWarn, argCount: 4, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4}:
case l.msgChannel <- logMessage{level: LevelWarn, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4}:
default:
}
}
@@ -208,7 +191,7 @@ func (l *Logger) Warn4(format string, arg1, arg2, arg3, arg4 any) {
func (l *Logger) Debug1(format string, arg1 any) {
if l.level.Load() >= uint32(LevelDebug) {
select {
case l.msgChannel <- logMessage{level: LevelDebug, argCount: 1, format: format, arg1: arg1}:
case l.msgChannel <- logMessage{level: LevelDebug, format: format, arg1: arg1}:
default:
}
}
@@ -217,7 +200,7 @@ func (l *Logger) Debug1(format string, arg1 any) {
func (l *Logger) Debug2(format string, arg1, arg2 any) {
if l.level.Load() >= uint32(LevelDebug) {
select {
case l.msgChannel <- logMessage{level: LevelDebug, argCount: 2, format: format, arg1: arg1, arg2: arg2}:
case l.msgChannel <- logMessage{level: LevelDebug, format: format, arg1: arg1, arg2: arg2}:
default:
}
}
@@ -226,59 +209,16 @@ func (l *Logger) Debug2(format string, arg1, arg2 any) {
func (l *Logger) Debug3(format string, arg1, arg2, arg3 any) {
if l.level.Load() >= uint32(LevelDebug) {
select {
case l.msgChannel <- logMessage{level: LevelDebug, argCount: 3, format: format, arg1: arg1, arg2: arg2, arg3: arg3}:
case l.msgChannel <- logMessage{level: LevelDebug, format: format, arg1: arg1, arg2: arg2, arg3: arg3}:
default:
}
}
}
// Debugf is the variadic shape. Dispatches to Debug/Debug1/Debug2/Debug3
// to avoid allocating an args slice on the fast path when the arg count is
// known (0-3). Args beyond 3 land on the general variadic path; callers on
// the hot path should prefer DebugN for known counts.
func (l *Logger) Debugf(format string, args ...any) {
if l.level.Load() < uint32(LevelDebug) {
return
}
switch len(args) {
case 0:
l.Debug(format)
case 1:
l.Debug1(format, args[0])
case 2:
l.Debug2(format, args[0], args[1])
case 3:
l.Debug3(format, args[0], args[1], args[2])
default:
l.sendVariadic(LevelDebug, format, args)
}
}
// sendVariadic packs a slice of arguments into a logMessage and non-blocking
// enqueues it. Used for arg counts beyond the fixed-arity fast paths. Args
// beyond the 8-arg slot limit are dropped so callers don't produce silently
// empty log lines via uint8 wraparound in argCount.
func (l *Logger) sendVariadic(level Level, format string, args []any) {
const maxArgs = 8
n := len(args)
if n > maxArgs {
n = maxArgs
}
msg := logMessage{level: level, argCount: uint8(n), format: format}
slots := [maxArgs]*any{&msg.arg1, &msg.arg2, &msg.arg3, &msg.arg4, &msg.arg5, &msg.arg6, &msg.arg7, &msg.arg8}
for i := 0; i < n; i++ {
*slots[i] = args[i]
}
select {
case l.msgChannel <- msg:
default:
}
}
func (l *Logger) Trace1(format string, arg1 any) {
if l.level.Load() >= uint32(LevelTrace) {
select {
case l.msgChannel <- logMessage{level: LevelTrace, argCount: 1, format: format, arg1: arg1}:
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1}:
default:
}
}
@@ -287,7 +227,7 @@ func (l *Logger) Trace1(format string, arg1 any) {
func (l *Logger) Trace2(format string, arg1, arg2 any) {
if l.level.Load() >= uint32(LevelTrace) {
select {
case l.msgChannel <- logMessage{level: LevelTrace, argCount: 2, format: format, arg1: arg1, arg2: arg2}:
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2}:
default:
}
}
@@ -296,7 +236,7 @@ func (l *Logger) Trace2(format string, arg1, arg2 any) {
func (l *Logger) Trace3(format string, arg1, arg2, arg3 any) {
if l.level.Load() >= uint32(LevelTrace) {
select {
case l.msgChannel <- logMessage{level: LevelTrace, argCount: 3, format: format, arg1: arg1, arg2: arg2, arg3: arg3}:
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3}:
default:
}
}
@@ -305,7 +245,7 @@ func (l *Logger) Trace3(format string, arg1, arg2, arg3 any) {
func (l *Logger) Trace4(format string, arg1, arg2, arg3, arg4 any) {
if l.level.Load() >= uint32(LevelTrace) {
select {
case l.msgChannel <- logMessage{level: LevelTrace, argCount: 4, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4}:
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4}:
default:
}
}
@@ -314,7 +254,7 @@ func (l *Logger) Trace4(format string, arg1, arg2, arg3, arg4 any) {
func (l *Logger) Trace5(format string, arg1, arg2, arg3, arg4, arg5 any) {
if l.level.Load() >= uint32(LevelTrace) {
select {
case l.msgChannel <- logMessage{level: LevelTrace, argCount: 5, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5}:
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5}:
default:
}
}
@@ -323,7 +263,7 @@ func (l *Logger) Trace5(format string, arg1, arg2, arg3, arg4, arg5 any) {
func (l *Logger) Trace6(format string, arg1, arg2, arg3, arg4, arg5, arg6 any) {
if l.level.Load() >= uint32(LevelTrace) {
select {
case l.msgChannel <- logMessage{level: LevelTrace, argCount: 6, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5, arg6: arg6}:
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5, arg6: arg6}:
default:
}
}
@@ -333,7 +273,7 @@ func (l *Logger) Trace6(format string, arg1, arg2, arg3, arg4, arg5, arg6 any) {
func (l *Logger) Trace8(format string, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8 any) {
if l.level.Load() >= uint32(LevelTrace) {
select {
case l.msgChannel <- logMessage{level: LevelTrace, argCount: 8, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5, arg6: arg6, arg7: arg7, arg8: arg8}:
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5, arg6: arg6, arg7: arg7, arg8: arg8}:
default:
}
}
@@ -346,8 +286,35 @@ func (l *Logger) formatMessage(buf *[]byte, msg logMessage) {
*buf = append(*buf, levelStrings[msg.level]...)
*buf = append(*buf, ' ')
// Count non-nil arguments for switch
argCount := 0
if msg.arg1 != nil {
argCount++
if msg.arg2 != nil {
argCount++
if msg.arg3 != nil {
argCount++
if msg.arg4 != nil {
argCount++
if msg.arg5 != nil {
argCount++
if msg.arg6 != nil {
argCount++
if msg.arg7 != nil {
argCount++
if msg.arg8 != nil {
argCount++
}
}
}
}
}
}
}
}
var formatted string
switch msg.argCount {
switch argCount {
case 0:
formatted = msg.format
case 1:

View File

@@ -11,7 +11,6 @@ import (
"github.com/google/gopacket/layers"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
)
var ErrIPv4Only = errors.New("only IPv4 is supported for DNAT")
@@ -243,15 +242,11 @@ func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool {
}
if err := m.rewritePacketIP(packetData, d, translatedIP, destinationIPOffset); err != nil {
if m.logger.Enabled(nblog.LevelError) {
m.logger.Error1("failed to rewrite packet destination: %v", err)
}
m.logger.Error1("failed to rewrite packet destination: %v", err)
return false
}
if m.logger.Enabled(nblog.LevelTrace) {
m.logger.Trace2("DNAT: %s -> %s", dstIP, translatedIP)
}
m.logger.Trace2("DNAT: %s -> %s", dstIP, translatedIP)
return true
}
@@ -269,15 +264,11 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
}
if err := m.rewritePacketIP(packetData, d, originalIP, sourceIPOffset); err != nil {
if m.logger.Enabled(nblog.LevelError) {
m.logger.Error1("failed to rewrite packet source: %v", err)
}
m.logger.Error1("failed to rewrite packet source: %v", err)
return false
}
if m.logger.Enabled(nblog.LevelTrace) {
m.logger.Trace2("Reverse DNAT: %s -> %s", srcIP, originalIP)
}
m.logger.Trace2("Reverse DNAT: %s -> %s", srcIP, originalIP)
return true
}
@@ -530,9 +521,7 @@ func (m *Manager) applyPortRule(packetData []byte, d *decoder, srcIP, dstIP neti
}
if err := rewriteFn(packetData, d, rule.targetPort, destinationPortOffset); err != nil {
if m.logger.Enabled(nblog.LevelError) {
m.logger.Error1("failed to rewrite port: %v", err)
}
m.logger.Error1("failed to rewrite port: %v", err)
return false
}
d.dnatOrigPort = rule.origPort

View File

@@ -217,6 +217,7 @@ func (w *WGIface) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error
// Close closes the tunnel interface
func (w *WGIface) Close() error {
w.mu.Lock()
defer w.mu.Unlock()
var result *multierror.Error
@@ -224,15 +225,7 @@ func (w *WGIface) Close() error {
result = multierror.Append(result, fmt.Errorf("failed to free WireGuard proxy: %w", err))
}
// Release w.mu before calling w.tun.Close(): the underlying
// wireguard-go device.Close() waits for its send/receive goroutines
// to drain. Some of those goroutines re-enter WGIface methods that
// take w.mu (e.g. the packet filter DNS hook calls GetDevice()), so
// holding the mutex here would deadlock the shutdown path.
tun := w.tun
w.mu.Unlock()
if err := tun.Close(); err != nil {
if err := w.tun.Close(); err != nil {
result = multierror.Append(result, fmt.Errorf("failed to close wireguard interface %s: %w", w.Name(), err))
}

View File

@@ -1,113 +0,0 @@
//go:build !android
package iface
import (
"errors"
"sync"
"testing"
"time"
wgdevice "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
// fakeTunDevice implements WGTunDevice and lets the test control when
// Close() returns. It mimics the wireguard-go shutdown path, which blocks
// until its goroutines drain. Some of those goroutines (e.g. the packet
// filter DNS hook in client/internal/dns) call back into WGIface, so if
// WGIface.Close() held w.mu across tun.Close() the shutdown would
// deadlock.
type fakeTunDevice struct {
closeStarted chan struct{}
unblockClose chan struct{}
}
func (f *fakeTunDevice) Create() (device.WGConfigurer, error) {
return nil, errors.New("not implemented")
}
func (f *fakeTunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
return nil, errors.New("not implemented")
}
func (f *fakeTunDevice) UpdateAddr(wgaddr.Address) error { return nil }
func (f *fakeTunDevice) WgAddress() wgaddr.Address { return wgaddr.Address{} }
func (f *fakeTunDevice) MTU() uint16 { return DefaultMTU }
func (f *fakeTunDevice) DeviceName() string { return "nb-close-test" }
func (f *fakeTunDevice) FilteredDevice() *device.FilteredDevice { return nil }
func (f *fakeTunDevice) Device() *wgdevice.Device { return nil }
func (f *fakeTunDevice) GetNet() *netstack.Net { return nil }
func (f *fakeTunDevice) GetICEBind() device.EndpointManager { return nil }
func (f *fakeTunDevice) Close() error {
close(f.closeStarted)
<-f.unblockClose
return nil
}
type fakeProxyFactory struct{}
func (fakeProxyFactory) GetProxy() wgproxy.Proxy { return nil }
func (fakeProxyFactory) GetProxyPort() uint16 { return 0 }
func (fakeProxyFactory) Free() error { return nil }
// TestWGIface_CloseReleasesMutexBeforeTunClose guards against a deadlock
// that surfaces as a macOS test-timeout in
// TestDNSPermanent_updateUpstream: WGIface.Close() used to hold w.mu
// while waiting for the wireguard-go device goroutines to finish, and
// one of those goroutines (the DNS filter hook) calls back into
// WGIface.GetDevice() which needs the same mutex. The fix is to drop
// the lock before tun.Close() returns control.
func TestWGIface_CloseReleasesMutexBeforeTunClose(t *testing.T) {
tun := &fakeTunDevice{
closeStarted: make(chan struct{}),
unblockClose: make(chan struct{}),
}
w := &WGIface{
tun: tun,
wgProxyFactory: fakeProxyFactory{},
}
closeDone := make(chan error, 1)
go func() {
closeDone <- w.Close()
}()
select {
case <-tun.closeStarted:
case <-time.After(2 * time.Second):
close(tun.unblockClose)
t.Fatal("tun.Close() was never invoked")
}
// Simulate the WireGuard read goroutine calling back into WGIface
// via the packet filter's DNS hook. If Close() still held w.mu
// during tun.Close(), this would block until the test timeout.
getDeviceDone := make(chan struct{})
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
_ = w.GetDevice()
close(getDeviceDone)
}()
select {
case <-getDeviceDone:
case <-time.After(2 * time.Second):
close(tun.unblockClose)
wg.Wait()
t.Fatal("GetDevice() deadlocked while WGIface.Close was closing the tun")
}
close(tun.unblockClose)
select {
case <-closeDone:
case <-time.After(2 * time.Second):
t.Fatal("WGIface.Close() never returned after the tun was unblocked")
}
}

View File

@@ -171,7 +171,7 @@ func (u *UDPConn) performFilterCheck(addr net.Addr) error {
}
if u.address.Network.Contains(a) {
log.Warnf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
log.Warnf("Address %s is part of the NetBird network %s, refusing to write", addr, u.address)
return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
}
@@ -181,7 +181,7 @@ func (u *UDPConn) performFilterCheck(addr net.Addr) error {
u.addrCache.Store(addr.String(), isRouted)
if isRouted {
// Extra log, as the error only shows up with ICE logging enabled
log.Infof("address %s is part of routed network %s, refusing to write", addr, prefix)
log.Infof("Address %s is part of routed network %s, refusing to write", addr, prefix)
return fmt.Errorf("address %s is part of routed network %s, refusing to write", addr, prefix)
}
}

View File

@@ -19,9 +19,6 @@ import (
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
func TestDefaultManager(t *testing.T) {
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
t.Setenv(firewall.EnvForceUserspaceFirewall, "true")
networkMap := &mgmProto.NetworkMap{
FirewallRules: []*mgmProto.FirewallRule{
{
@@ -138,7 +135,6 @@ func TestDefaultManager(t *testing.T) {
func TestDefaultManagerStateless(t *testing.T) {
// stateless currently only in userspace, so we have to disable kernel
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
t.Setenv(firewall.EnvForceUserspaceFirewall, "true")
t.Setenv("NB_DISABLE_CONNTRACK", "true")
networkMap := &mgmProto.NetworkMap{
@@ -198,7 +194,6 @@ func TestDefaultManagerStateless(t *testing.T) {
// This tests the full ACL manager -> uspfilter integration.
func TestDenyRulesNotAccumulatedOnRepeatedApply(t *testing.T) {
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
t.Setenv(firewall.EnvForceUserspaceFirewall, "true")
networkMap := &mgmProto.NetworkMap{
FirewallRules: []*mgmProto.FirewallRule{
@@ -263,7 +258,6 @@ func TestDenyRulesNotAccumulatedOnRepeatedApply(t *testing.T) {
// up when they're removed from the network map in a subsequent update.
func TestDenyRulesCleanedUpOnRemoval(t *testing.T) {
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
t.Setenv(firewall.EnvForceUserspaceFirewall, "true")
ctrl := gomock.NewController(t)
defer ctrl.Finish()
@@ -345,7 +339,6 @@ func TestDenyRulesCleanedUpOnRemoval(t *testing.T) {
// one added without leaking.
func TestRuleUpdateChangingAction(t *testing.T) {
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
t.Setenv(firewall.EnvForceUserspaceFirewall, "true")
ctrl := gomock.NewController(t)
defer ctrl.Finish()

View File

@@ -94,7 +94,6 @@ func (c *ConnectClient) RunOnAndroid(
dnsAddresses []netip.AddrPort,
dnsReadyListener dns.ReadyListener,
stateFilePath string,
cacheDir string,
) error {
// in case of non Android os these variables will be nil
mobileDependency := MobileDependency{
@@ -104,7 +103,6 @@ func (c *ConnectClient) RunOnAndroid(
HostDNSAddresses: dnsAddresses,
DnsReadyListener: dnsReadyListener,
StateFilePath: stateFilePath,
TempDir: cacheDir,
}
return c.run(mobileDependency, nil, "")
}
@@ -340,7 +338,6 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
log.Error(err)
return wrapErr(err)
}
engineConfig.TempDir = mobileDependency.TempDir
relayManager := relayClient.NewManager(engineCtx, relayURLs, myPrivateKey.PublicKey().String(), engineConfig.MTU)
c.statusRecorder.SetRelayMgr(relayManager)

View File

@@ -16,6 +16,7 @@ import (
"path/filepath"
"runtime"
"runtime/pprof"
"slices"
"sort"
"strings"
"time"
@@ -30,6 +31,7 @@ import (
"github.com/netbirdio/netbird/client/internal/updater/installer"
nbstatus "github.com/netbirdio/netbird/client/status"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/util"
)
const readmeContent = `Netbird debug bundle
@@ -232,7 +234,6 @@ type BundleGenerator struct {
statusRecorder *peer.Status
syncResponse *mgmProto.SyncResponse
logPath string
tempDir string
cpuProfile []byte
refreshStatus func() // Optional callback to refresh status before bundle generation
clientMetrics MetricsExporter
@@ -255,7 +256,6 @@ type GeneratorDependencies struct {
StatusRecorder *peer.Status
SyncResponse *mgmProto.SyncResponse
LogPath string
TempDir string // Directory for temporary bundle zip files. If empty, os.TempDir() is used.
CPUProfile []byte
RefreshStatus func() // Optional callback to refresh status before bundle generation
ClientMetrics MetricsExporter
@@ -275,7 +275,6 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen
statusRecorder: deps.StatusRecorder,
syncResponse: deps.SyncResponse,
logPath: deps.LogPath,
tempDir: deps.TempDir,
cpuProfile: deps.CPUProfile,
refreshStatus: deps.RefreshStatus,
clientMetrics: deps.ClientMetrics,
@@ -288,7 +287,7 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen
// Generate creates a debug bundle and returns the location.
func (g *BundleGenerator) Generate() (resp string, err error) {
bundlePath, err := os.CreateTemp(g.tempDir, "netbird.debug.*.zip")
bundlePath, err := os.CreateTemp("", "netbird.debug.*.zip")
if err != nil {
return "", fmt.Errorf("create zip file: %w", err)
}
@@ -374,8 +373,15 @@ func (g *BundleGenerator) createArchive() error {
log.Errorf("failed to add wg show output: %v", err)
}
if err := g.addPlatformLog(); err != nil {
log.Errorf("failed to add logs to debug bundle: %v", err)
if g.logPath != "" && !slices.Contains(util.SpecialLogs, g.logPath) {
if err := g.addLogfile(); err != nil {
log.Errorf("failed to add log file to debug bundle: %v", err)
if err := g.trySystemdLogFallback(); err != nil {
log.Errorf("failed to add systemd logs as fallback: %v", err)
}
}
} else if err := g.trySystemdLogFallback(); err != nil {
log.Errorf("failed to add systemd logs: %v", err)
}
if err := g.addUpdateLogs(); err != nil {

View File

@@ -1,41 +0,0 @@
//go:build android
package debug
import (
"fmt"
"io"
"os/exec"
log "github.com/sirupsen/logrus"
)
func (g *BundleGenerator) addPlatformLog() error {
cmd := exec.Command("/system/bin/logcat", "-d")
stdout, err := cmd.StdoutPipe()
if err != nil {
return fmt.Errorf("logcat stdout pipe: %w", err)
}
if err := cmd.Start(); err != nil {
return fmt.Errorf("start logcat: %w", err)
}
var logReader io.Reader = stdout
if g.anonymize {
var pw *io.PipeWriter
logReader, pw = io.Pipe()
go anonymizeLog(stdout, pw, g.anonymizer)
}
if err := g.addFileToZip(logReader, "logcat.txt"); err != nil {
return fmt.Errorf("add logcat to zip: %w", err)
}
if err := cmd.Wait(); err != nil {
return fmt.Errorf("wait logcat: %w", err)
}
log.Debug("added logcat output to debug bundle")
return nil
}

View File

@@ -1,25 +0,0 @@
//go:build !android
package debug
import (
"slices"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/util"
)
func (g *BundleGenerator) addPlatformLog() error {
if g.logPath != "" && !slices.Contains(util.SpecialLogs, g.logPath) {
if err := g.addLogfile(); err != nil {
log.Errorf("failed to add log file to debug bundle: %v", err)
if err := g.trySystemdLogFallback(); err != nil {
return err
}
}
} else if err := g.trySystemdLogFallback(); err != nil {
return err
}
return nil
}

View File

@@ -140,7 +140,6 @@ type EngineConfig struct {
ProfileConfig *profilemanager.Config
LogPath string
TempDir string
}
// EngineServices holds the external service dependencies required by the Engine.
@@ -1096,7 +1095,6 @@ func (e *Engine) handleBundle(params *mgmProto.BundleParameters) (*mgmProto.JobR
StatusRecorder: e.statusRecorder,
SyncResponse: syncResponse,
LogPath: e.config.LogPath,
TempDir: e.config.TempDir,
ClientMetrics: e.clientMetrics,
RefreshStatus: func() {
e.RunHealthProbes(true)

View File

@@ -55,7 +55,6 @@ import (
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
nbcache "github.com/netbirdio/netbird/management/server/cache"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
@@ -1635,12 +1634,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
peersManager := peers.NewManager(store, permissionsManager)
jobManager := job.NewJobManager(nil, store, peersManager)
cacheStore, err := nbcache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)
if err != nil {
return nil, "", err
}
ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore, cacheStore)
ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
@@ -1662,7 +1656,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore)
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
if err != nil {
return nil, "", err
}

View File

@@ -22,8 +22,4 @@ type MobileDependency struct {
DnsManager dns.IosDnsManager
FileDescriptor int32
StateFilePath string
// TempDir is a writable directory for temporary files (e.g., debug bundle zip).
// On Android, this should be set to the app's cache directory.
TempDir string
}

View File

@@ -7,9 +7,7 @@ import (
"fmt"
"net/netip"
"sync"
"time"
"github.com/cenkalti/backoff/v4"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
nfct "github.com/ti-mo/conntrack"
@@ -19,64 +17,31 @@ import (
nbnet "github.com/netbirdio/netbird/client/net"
)
const (
defaultChannelSize = 100
reconnectInitInterval = 5 * time.Second
reconnectMaxInterval = 5 * time.Minute
reconnectRandomization = 0.5
)
// listener abstracts a netlink conntrack connection for testability.
type listener interface {
Listen(evChan chan<- nfct.Event, numWorkers uint8, groups []netfilter.NetlinkGroup) (chan error, error)
Close() error
}
const defaultChannelSize = 100
// ConnTrack manages kernel-based conntrack events
type ConnTrack struct {
flowLogger nftypes.FlowLogger
iface nftypes.IFaceMapper
conn listener
conn *nfct.Conn
mux sync.Mutex
dial func() (listener, error)
instanceID uuid.UUID
started bool
done chan struct{}
sysctlModified bool
}
// DialFunc is a constructor for netlink conntrack connections.
type DialFunc func() (listener, error)
// Option configures a ConnTrack instance.
type Option func(*ConnTrack)
// WithDialer overrides the default netlink dialer, primarily for testing.
func WithDialer(dial DialFunc) Option {
return func(c *ConnTrack) {
c.dial = dial
}
}
func defaultDial() (listener, error) {
return nfct.Dial(nil)
}
// New creates a new connection tracker that interfaces with the kernel's conntrack system
func New(flowLogger nftypes.FlowLogger, iface nftypes.IFaceMapper, opts ...Option) *ConnTrack {
ct := &ConnTrack{
func New(flowLogger nftypes.FlowLogger, iface nftypes.IFaceMapper) *ConnTrack {
return &ConnTrack{
flowLogger: flowLogger,
iface: iface,
instanceID: uuid.New(),
dial: defaultDial,
started: false,
done: make(chan struct{}, 1),
}
for _, opt := range opts {
opt(ct)
}
return ct
}
// Start begins tracking connections by listening for conntrack events. This method is idempotent.
@@ -94,9 +59,8 @@ func (c *ConnTrack) Start(enableCounters bool) error {
c.EnableAccounting()
}
conn, err := c.dial()
conn, err := nfct.Dial(nil)
if err != nil {
c.RestoreAccounting()
return fmt.Errorf("dial conntrack: %w", err)
}
c.conn = conn
@@ -112,16 +76,9 @@ func (c *ConnTrack) Start(enableCounters bool) error {
log.Errorf("Error closing conntrack connection: %v", err)
}
c.conn = nil
c.RestoreAccounting()
return fmt.Errorf("start conntrack listener: %w", err)
}
// Drain any stale stop signal from a previous cycle.
select {
case <-c.done:
default:
}
c.started = true
go c.receiverRoutine(events, errChan)
@@ -135,98 +92,17 @@ func (c *ConnTrack) receiverRoutine(events chan nfct.Event, errChan chan error)
case event := <-events:
c.handleEvent(event)
case err := <-errChan:
if events, errChan = c.handleListenerError(err); events == nil {
return
log.Errorf("Error from conntrack event listener: %v", err)
if err := c.conn.Close(); err != nil {
log.Errorf("Error closing conntrack connection: %v", err)
}
return
case <-c.done:
return
}
}
}
// handleListenerError closes the failed connection and attempts to reconnect.
// Returns new channels on success, or nil if shutdown was requested.
func (c *ConnTrack) handleListenerError(err error) (chan nfct.Event, chan error) {
log.Warnf("conntrack event listener failed: %v", err)
c.closeConn()
return c.reconnect()
}
func (c *ConnTrack) closeConn() {
c.mux.Lock()
defer c.mux.Unlock()
if c.conn != nil {
if err := c.conn.Close(); err != nil {
log.Debugf("close conntrack connection: %v", err)
}
c.conn = nil
}
}
// reconnect attempts to re-establish the conntrack netlink listener with exponential backoff.
// Returns new channels on success, or nil if shutdown was requested.
func (c *ConnTrack) reconnect() (chan nfct.Event, chan error) {
bo := &backoff.ExponentialBackOff{
InitialInterval: reconnectInitInterval,
RandomizationFactor: reconnectRandomization,
Multiplier: backoff.DefaultMultiplier,
MaxInterval: reconnectMaxInterval,
MaxElapsedTime: 0, // retry indefinitely
Clock: backoff.SystemClock,
}
bo.Reset()
for {
delay := bo.NextBackOff()
log.Infof("reconnecting conntrack listener in %s", delay)
select {
case <-c.done:
c.mux.Lock()
c.started = false
c.mux.Unlock()
return nil, nil
case <-time.After(delay):
}
conn, err := c.dial()
if err != nil {
log.Warnf("reconnect conntrack dial: %v", err)
continue
}
events := make(chan nfct.Event, defaultChannelSize)
errChan, err := conn.Listen(events, 1, []netfilter.NetlinkGroup{
netfilter.GroupCTNew,
netfilter.GroupCTDestroy,
})
if err != nil {
log.Warnf("reconnect conntrack listen: %v", err)
if closeErr := conn.Close(); closeErr != nil {
log.Debugf("close conntrack connection: %v", closeErr)
}
continue
}
c.mux.Lock()
if !c.started {
// Stop() ran while we were reconnecting.
c.mux.Unlock()
if closeErr := conn.Close(); closeErr != nil {
log.Debugf("close conntrack connection: %v", closeErr)
}
return nil, nil
}
c.conn = conn
c.mux.Unlock()
log.Infof("conntrack listener reconnected successfully")
return events, errChan
}
}
// Stop stops the connection tracking. This method is idempotent.
func (c *ConnTrack) Stop() {
c.mux.Lock()
@@ -260,27 +136,23 @@ func (c *ConnTrack) Close() error {
c.mux.Lock()
defer c.mux.Unlock()
if !c.started {
return nil
if c.started {
select {
case c.done <- struct{}{}:
default:
}
}
select {
case c.done <- struct{}{}:
default:
}
c.started = false
var closeErr error
if c.conn != nil {
closeErr = c.conn.Close()
err := c.conn.Close()
c.conn = nil
}
c.started = false
c.RestoreAccounting()
c.RestoreAccounting()
if closeErr != nil {
return fmt.Errorf("close conntrack: %w", closeErr)
if err != nil {
return fmt.Errorf("close conntrack: %w", err)
}
}
return nil

View File

@@ -1,224 +0,0 @@
//go:build linux && !android
package conntrack
import (
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
nfct "github.com/ti-mo/conntrack"
"github.com/ti-mo/netfilter"
)
type mockListener struct {
errChan chan error
closed atomic.Bool
closedCh chan struct{}
}
func newMockListener() *mockListener {
return &mockListener{
errChan: make(chan error, 1),
closedCh: make(chan struct{}),
}
}
func (m *mockListener) Listen(evChan chan<- nfct.Event, _ uint8, _ []netfilter.NetlinkGroup) (chan error, error) {
return m.errChan, nil
}
func (m *mockListener) Close() error {
if m.closed.CompareAndSwap(false, true) {
close(m.closedCh)
}
return nil
}
func TestReconnectAfterError(t *testing.T) {
first := newMockListener()
second := newMockListener()
third := newMockListener()
listeners := []*mockListener{first, second, third}
callCount := atomic.Int32{}
ct := New(nil, nil, WithDialer(func() (listener, error) {
n := int(callCount.Add(1)) - 1
return listeners[n], nil
}))
err := ct.Start(false)
require.NoError(t, err)
// Inject an error on the first listener.
first.errChan <- assert.AnError
// Wait for reconnect to complete.
require.Eventually(t, func() bool {
return callCount.Load() >= 2
}, 15*time.Second, 100*time.Millisecond, "reconnect should dial a new connection")
// The first connection must have been closed.
select {
case <-first.closedCh:
case <-time.After(2 * time.Second):
t.Fatal("first connection was not closed")
}
// Verify the receiver is still running by injecting and handling a second error.
second.errChan <- assert.AnError
require.Eventually(t, func() bool {
return callCount.Load() >= 3
}, 15*time.Second, 100*time.Millisecond, "second reconnect should succeed")
ct.Stop()
}
func TestStopDuringReconnectBackoff(t *testing.T) {
mock := newMockListener()
ct := New(nil, nil, WithDialer(func() (listener, error) {
return mock, nil
}))
err := ct.Start(false)
require.NoError(t, err)
// Trigger an error so the receiver enters reconnect.
mock.errChan <- assert.AnError
// Wait for the error handler to close the old listener before calling Stop.
select {
case <-mock.closedCh:
case <-time.After(5 * time.Second):
t.Fatal("timed out waiting for reconnect to start")
}
// Stop while reconnecting.
ct.Stop()
ct.mux.Lock()
assert.False(t, ct.started, "started should be false after Stop")
assert.Nil(t, ct.conn, "conn should be nil after Stop")
ct.mux.Unlock()
}
func TestStopRaceWithReconnectDial(t *testing.T) {
first := newMockListener()
dialStarted := make(chan struct{})
dialProceed := make(chan struct{})
second := newMockListener()
callCount := atomic.Int32{}
ct := New(nil, nil, WithDialer(func() (listener, error) {
n := callCount.Add(1)
if n == 1 {
return first, nil
}
// Second dial: signal that we're in progress, wait for test to call Stop.
close(dialStarted)
<-dialProceed
return second, nil
}))
err := ct.Start(false)
require.NoError(t, err)
// Trigger error to enter reconnect.
first.errChan <- assert.AnError
// Wait for reconnect's second dial to begin.
select {
case <-dialStarted:
case <-time.After(15 * time.Second):
t.Fatal("timed out waiting for reconnect dial")
}
// Stop while dial is in progress (conn is nil at this point).
ct.Stop()
// Let the dial complete. reconnect should detect started==false and close the new conn.
close(dialProceed)
// The second connection should be closed (not leaked).
select {
case <-second.closedCh:
case <-time.After(2 * time.Second):
t.Fatal("second connection was leaked after Stop")
}
ct.mux.Lock()
assert.False(t, ct.started)
assert.Nil(t, ct.conn)
ct.mux.Unlock()
}
func TestCloseRaceWithReconnectDial(t *testing.T) {
first := newMockListener()
dialStarted := make(chan struct{})
dialProceed := make(chan struct{})
second := newMockListener()
callCount := atomic.Int32{}
ct := New(nil, nil, WithDialer(func() (listener, error) {
n := callCount.Add(1)
if n == 1 {
return first, nil
}
close(dialStarted)
<-dialProceed
return second, nil
}))
err := ct.Start(false)
require.NoError(t, err)
first.errChan <- assert.AnError
select {
case <-dialStarted:
case <-time.After(15 * time.Second):
t.Fatal("timed out waiting for reconnect dial")
}
// Close while dial is in progress (conn is nil).
require.NoError(t, ct.Close())
close(dialProceed)
// The second connection should be closed (not leaked).
select {
case <-second.closedCh:
case <-time.After(2 * time.Second):
t.Fatal("second connection was leaked after Close")
}
ct.mux.Lock()
assert.False(t, ct.started)
assert.Nil(t, ct.conn)
ct.mux.Unlock()
}
func TestStartIsIdempotent(t *testing.T) {
mock := newMockListener()
callCount := atomic.Int32{}
ct := New(nil, nil, WithDialer(func() (listener, error) {
callCount.Add(1)
return mock, nil
}))
err := ct.Start(false)
require.NoError(t, err)
// Second Start should be a no-op.
err = ct.Start(false)
require.NoError(t, err)
assert.Equal(t, int32(1), callCount.Load(), "dial should only be called once")
ct.Stop()
}

View File

@@ -8,27 +8,18 @@ import (
)
const (
envDisableNATMapper = "NB_DISABLE_NAT_MAPPER"
envDisablePCPHealthCheck = "NB_DISABLE_PCP_HEALTH_CHECK"
envDisableNATMapper = "NB_DISABLE_NAT_MAPPER"
)
func isDisabledByEnv() bool {
return parseBoolEnv(envDisableNATMapper)
}
func isHealthCheckDisabled() bool {
return parseBoolEnv(envDisablePCPHealthCheck)
}
func parseBoolEnv(key string) bool {
val := os.Getenv(key)
val := os.Getenv(envDisableNATMapper)
if val == "" {
return false
}
disabled, err := strconv.ParseBool(val)
if err != nil {
log.Warnf("failed to parse %s: %v", key, err)
log.Warnf("failed to parse %s: %v", envDisableNATMapper, err)
return false
}
return disabled

View File

@@ -12,15 +12,12 @@ import (
"github.com/libp2p/go-nat"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/portforward/pcp"
)
const (
defaultMappingTTL = 2 * time.Hour
healthCheckInterval = 1 * time.Minute
discoveryTimeout = 10 * time.Second
mappingDescription = "NetBird"
defaultMappingTTL = 2 * time.Hour
discoveryTimeout = 10 * time.Second
mappingDescription = "NetBird"
)
// upnpErrPermanentLeaseOnly matches UPnP error 725 in SOAP fault XML,
@@ -157,7 +154,7 @@ func (m *Manager) setup(ctx context.Context) (nat.NAT, *Mapping, error) {
discoverCtx, discoverCancel := context.WithTimeout(ctx, discoveryTimeout)
defer discoverCancel()
gateway, err := discoverGateway(discoverCtx)
gateway, err := nat.DiscoverGateway(discoverCtx)
if err != nil {
return nil, nil, fmt.Errorf("discover gateway: %w", err)
}
@@ -192,6 +189,7 @@ func (m *Manager) createMapping(ctx context.Context, gateway nat.NAT) (*Mapping,
externalIP, err := gateway.GetExternalAddress()
if err != nil {
log.Debugf("failed to get external address: %v", err)
// todo return with err?
}
mapping := &Mapping{
@@ -210,87 +208,27 @@ func (m *Manager) createMapping(ctx context.Context, gateway nat.NAT) (*Mapping,
func (m *Manager) renewLoop(ctx context.Context, gateway nat.NAT, ttl time.Duration) {
if ttl == 0 {
// Permanent mappings don't expire, just wait for cancellation
// but still run health checks for PCP gateways.
m.permanentLeaseLoop(ctx, gateway)
// Permanent mappings don't expire, just wait for cancellation.
<-ctx.Done()
return
}
renewTicker := time.NewTicker(ttl / 2)
healthTicker := time.NewTicker(healthCheckInterval)
defer renewTicker.Stop()
defer healthTicker.Stop()
ticker := time.NewTicker(ttl / 2)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-renewTicker.C:
case <-ticker.C:
if err := m.renewMapping(ctx, gateway); err != nil {
log.Warnf("failed to renew port mapping: %v", err)
continue
}
case <-healthTicker.C:
if m.checkHealthAndRecreate(ctx, gateway) {
renewTicker.Reset(ttl / 2)
}
}
}
}
func (m *Manager) permanentLeaseLoop(ctx context.Context, gateway nat.NAT) {
healthTicker := time.NewTicker(healthCheckInterval)
defer healthTicker.Stop()
for {
select {
case <-ctx.Done():
return
case <-healthTicker.C:
m.checkHealthAndRecreate(ctx, gateway)
}
}
}
func (m *Manager) checkHealthAndRecreate(ctx context.Context, gateway nat.NAT) bool {
if isHealthCheckDisabled() {
return false
}
m.mappingLock.Lock()
hasMapping := m.mapping != nil
m.mappingLock.Unlock()
if !hasMapping {
return false
}
pcpNAT, ok := gateway.(*pcp.NAT)
if !ok {
return false
}
ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
epoch, serverRestarted, err := pcpNAT.CheckServerHealth(ctx)
if err != nil {
log.Debugf("PCP health check failed: %v", err)
return false
}
if serverRestarted {
log.Warnf("PCP server restart detected (epoch=%d), recreating port mapping", epoch)
if err := m.renewMapping(ctx, gateway); err != nil {
log.Errorf("failed to recreate port mapping after server restart: %v", err)
return false
}
return true
}
return false
}
func (m *Manager) renewMapping(ctx context.Context, gateway nat.NAT) error {
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()

View File

@@ -1,408 +0,0 @@
package pcp
import (
"context"
"crypto/rand"
"errors"
"fmt"
"net"
"net/netip"
"sync"
"time"
log "github.com/sirupsen/logrus"
)
const (
defaultTimeout = 3 * time.Second
responseBufferSize = 128
// RFC 6887 Section 8.1.1 retry timing
initialRetryDelay = 3 * time.Second
maxRetryDelay = 1024 * time.Second
maxRetries = 4 // 3s + 6s + 12s + 24s = 45s total worst case
)
// Client is a PCP protocol client.
// All methods are safe for concurrent use.
type Client struct {
gateway netip.Addr
timeout time.Duration
mu sync.Mutex
// localIP caches the resolved local IP address.
localIP netip.Addr
// lastEpoch is the last observed server epoch value.
lastEpoch uint32
// epochTime tracks when lastEpoch was received for state loss detection.
epochTime time.Time
// externalIP caches the external IP from the last successful MAP response.
externalIP netip.Addr
// epochStateLost is set when epoch indicates server restart.
epochStateLost bool
}
// NewClient creates a new PCP client for the gateway at the given IP.
func NewClient(gateway net.IP) *Client {
addr, ok := netip.AddrFromSlice(gateway)
if !ok {
log.Debugf("invalid gateway IP: %v", gateway)
}
return &Client{
gateway: addr.Unmap(),
timeout: defaultTimeout,
}
}
// NewClientWithTimeout creates a new PCP client with a custom timeout.
func NewClientWithTimeout(gateway net.IP, timeout time.Duration) *Client {
addr, ok := netip.AddrFromSlice(gateway)
if !ok {
log.Debugf("invalid gateway IP: %v", gateway)
}
return &Client{
gateway: addr.Unmap(),
timeout: timeout,
}
}
// SetLocalIP sets the local IP address to use in PCP requests.
func (c *Client) SetLocalIP(ip net.IP) {
addr, ok := netip.AddrFromSlice(ip)
if !ok {
log.Debugf("invalid local IP: %v", ip)
}
c.mu.Lock()
c.localIP = addr.Unmap()
c.mu.Unlock()
}
// Gateway returns the gateway IP address.
func (c *Client) Gateway() net.IP {
return c.gateway.AsSlice()
}
// Announce sends a PCP ANNOUNCE request to discover PCP support.
// Returns the server's epoch time on success.
func (c *Client) Announce(ctx context.Context) (epoch uint32, err error) {
localIP, err := c.getLocalIP()
if err != nil {
return 0, fmt.Errorf("get local IP: %w", err)
}
req := buildAnnounceRequest(localIP)
resp, err := c.sendRequest(ctx, req)
if err != nil {
return 0, fmt.Errorf("send announce: %w", err)
}
parsed, err := parseResponse(resp)
if err != nil {
return 0, fmt.Errorf("parse announce response: %w", err)
}
if parsed.ResultCode != ResultSuccess {
return 0, fmt.Errorf("PCP ANNOUNCE failed: %s", ResultCodeString(parsed.ResultCode))
}
c.mu.Lock()
if c.updateEpochLocked(parsed.Epoch) {
log.Warnf("PCP server epoch indicates state loss - mappings may need refresh")
}
c.mu.Unlock()
return parsed.Epoch, nil
}
// AddPortMapping requests a port mapping from the PCP server.
func (c *Client) AddPortMapping(ctx context.Context, protocol string, internalPort int, lifetime time.Duration) (*MapResponse, error) {
return c.addPortMappingWithHint(ctx, protocol, internalPort, internalPort, netip.Addr{}, lifetime)
}
// AddPortMappingWithHint requests a port mapping with suggested external port and IP.
// Use lifetime <= 0 to delete a mapping.
func (c *Client) AddPortMappingWithHint(ctx context.Context, protocol string, internalPort, suggestedExtPort int, suggestedExtIP net.IP, lifetime time.Duration) (*MapResponse, error) {
var extIP netip.Addr
if suggestedExtIP != nil {
var ok bool
extIP, ok = netip.AddrFromSlice(suggestedExtIP)
if !ok {
log.Debugf("invalid suggested external IP: %v", suggestedExtIP)
}
extIP = extIP.Unmap()
}
return c.addPortMappingWithHint(ctx, protocol, internalPort, suggestedExtPort, extIP, lifetime)
}
func (c *Client) addPortMappingWithHint(ctx context.Context, protocol string, internalPort, suggestedExtPort int, suggestedExtIP netip.Addr, lifetime time.Duration) (*MapResponse, error) {
localIP, err := c.getLocalIP()
if err != nil {
return nil, fmt.Errorf("get local IP: %w", err)
}
proto, err := protocolNumber(protocol)
if err != nil {
return nil, fmt.Errorf("parse protocol: %w", err)
}
var nonce [12]byte
if _, err := rand.Read(nonce[:]); err != nil {
return nil, fmt.Errorf("generate nonce: %w", err)
}
// Convert lifetime to seconds. Lifetime 0 means delete, so only apply
// default for positive durations that round to 0 seconds.
var lifetimeSec uint32
if lifetime > 0 {
lifetimeSec = uint32(lifetime.Seconds())
if lifetimeSec == 0 {
lifetimeSec = DefaultLifetime
}
}
req := buildMapRequest(localIP, nonce, proto, uint16(internalPort), uint16(suggestedExtPort), suggestedExtIP, lifetimeSec)
resp, err := c.sendRequest(ctx, req)
if err != nil {
return nil, fmt.Errorf("send map request: %w", err)
}
mapResp, err := parseMapResponse(resp)
if err != nil {
return nil, fmt.Errorf("parse map response: %w", err)
}
if mapResp.Nonce != nonce {
return nil, fmt.Errorf("nonce mismatch in response")
}
if mapResp.Protocol != proto {
return nil, fmt.Errorf("protocol mismatch: requested %d, got %d", proto, mapResp.Protocol)
}
if mapResp.InternalPort != uint16(internalPort) {
return nil, fmt.Errorf("internal port mismatch: requested %d, got %d", internalPort, mapResp.InternalPort)
}
if mapResp.ResultCode != ResultSuccess {
return nil, &Error{
Code: mapResp.ResultCode,
Message: ResultCodeString(mapResp.ResultCode),
}
}
c.mu.Lock()
if c.updateEpochLocked(mapResp.Epoch) {
log.Warnf("PCP server epoch indicates state loss - mappings may need refresh")
}
c.cacheExternalIPLocked(mapResp.ExternalIP)
c.mu.Unlock()
return mapResp, nil
}
// DeletePortMapping removes a port mapping by requesting zero lifetime.
func (c *Client) DeletePortMapping(ctx context.Context, protocol string, internalPort int) error {
if _, err := c.addPortMappingWithHint(ctx, protocol, internalPort, 0, netip.Addr{}, 0); err != nil {
var pcpErr *Error
if errors.As(err, &pcpErr) && pcpErr.Code == ResultNotAuthorized {
return nil
}
return fmt.Errorf("delete mapping: %w", err)
}
return nil
}
// GetExternalAddress returns the external IP address.
// First checks for a cached value from previous MAP responses.
// If not cached, creates a short-lived mapping to discover the external IP.
func (c *Client) GetExternalAddress(ctx context.Context) (net.IP, error) {
c.mu.Lock()
if c.externalIP.IsValid() {
ip := c.externalIP.AsSlice()
c.mu.Unlock()
return ip, nil
}
c.mu.Unlock()
// Use an ephemeral port in the dynamic range (49152-65535).
// Port 0 is not valid with UDP/TCP protocols per RFC 6887.
ephemeralPort := 49152 + int(uint16(time.Now().UnixNano()))%(65535-49152)
// Use minimal lifetime (1 second) for discovery.
resp, err := c.AddPortMapping(ctx, "udp", ephemeralPort, time.Second)
if err != nil {
return nil, fmt.Errorf("create temporary mapping: %w", err)
}
if err := c.DeletePortMapping(ctx, "udp", ephemeralPort); err != nil {
log.Debugf("cleanup temporary PCP mapping: %v", err)
}
return resp.ExternalIP.AsSlice(), nil
}
// LastEpoch returns the last observed server epoch value.
// A decrease in epoch indicates the server may have restarted and mappings may be lost.
func (c *Client) LastEpoch() uint32 {
c.mu.Lock()
defer c.mu.Unlock()
return c.lastEpoch
}
// EpochStateLost returns true if epoch state loss was detected and clears the flag.
func (c *Client) EpochStateLost() bool {
c.mu.Lock()
defer c.mu.Unlock()
lost := c.epochStateLost
c.epochStateLost = false
return lost
}
// updateEpoch updates the epoch tracking and detects potential state loss.
// Returns true if state loss was detected (server likely restarted).
// Caller must hold c.mu.
func (c *Client) updateEpochLocked(newEpoch uint32) bool {
now := time.Now()
stateLost := false
// RFC 6887 Section 8.5: Detect invalid epoch indicating server state loss.
// client_delta = time since last response
// server_delta = epoch change since last response
// Invalid if: client_delta+2 < server_delta - server_delta/16
// OR: server_delta+2 < client_delta - client_delta/16
// The +2 handles quantization, /16 (6.25%) handles clock drift.
if !c.epochTime.IsZero() && c.lastEpoch > 0 {
clientDelta := uint32(now.Sub(c.epochTime).Seconds())
serverDelta := newEpoch - c.lastEpoch
// Check for epoch going backwards or jumping unexpectedly.
// Subtraction is safe: serverDelta/16 is always <= serverDelta.
if clientDelta+2 < serverDelta-(serverDelta/16) ||
serverDelta+2 < clientDelta-(clientDelta/16) {
stateLost = true
c.epochStateLost = true
}
}
c.lastEpoch = newEpoch
c.epochTime = now
return stateLost
}
// cacheExternalIP stores the external IP from a successful MAP response.
// Caller must hold c.mu.
func (c *Client) cacheExternalIPLocked(ip netip.Addr) {
if ip.IsValid() && !ip.IsUnspecified() {
c.externalIP = ip
}
}
// sendRequest sends a PCP request with retries per RFC 6887 Section 8.1.1.
func (c *Client) sendRequest(ctx context.Context, req []byte) ([]byte, error) {
addr := &net.UDPAddr{IP: c.gateway.AsSlice(), Port: Port}
var lastErr error
delay := initialRetryDelay
for range maxRetries {
resp, err := c.sendOnce(ctx, addr, req)
if err == nil {
return resp, nil
}
lastErr = err
if ctx.Err() != nil {
return nil, ctx.Err()
}
// RFC 6887 Section 8.1.1: RT = (1 + RAND) * MIN(2 * RTprev, MRT)
// RAND is random between -0.1 and +0.1
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-time.After(retryDelayWithJitter(delay)):
}
delay = min(delay*2, maxRetryDelay)
}
return nil, fmt.Errorf("PCP request failed after %d retries: %w", maxRetries, lastErr)
}
// retryDelayWithJitter applies RFC 6887 jitter: multiply by (1 + RAND) where RAND is [-0.1, +0.1].
func retryDelayWithJitter(d time.Duration) time.Duration {
var b [1]byte
_, _ = rand.Read(b[:])
// Convert byte to range [-0.1, +0.1]: (b/255 * 0.2) - 0.1
jitter := (float64(b[0])/255.0)*0.2 - 0.1
return time.Duration(float64(d) * (1 + jitter))
}
func (c *Client) sendOnce(ctx context.Context, addr *net.UDPAddr, req []byte) ([]byte, error) {
// Use ListenUDP instead of DialUDP to validate response source address per RFC 6887 §8.3.
conn, err := net.ListenUDP("udp", nil)
if err != nil {
return nil, fmt.Errorf("listen: %w", err)
}
defer func() {
if err := conn.Close(); err != nil {
log.Debugf("close UDP connection: %v", err)
}
}()
timeout := c.timeout
if deadline, ok := ctx.Deadline(); ok {
if remaining := time.Until(deadline); remaining < timeout {
timeout = remaining
}
}
if err := conn.SetDeadline(time.Now().Add(timeout)); err != nil {
return nil, fmt.Errorf("set deadline: %w", err)
}
if _, err := conn.WriteToUDP(req, addr); err != nil {
return nil, fmt.Errorf("write: %w", err)
}
resp := make([]byte, responseBufferSize)
n, from, err := conn.ReadFromUDP(resp)
if err != nil {
return nil, fmt.Errorf("read: %w", err)
}
// RFC 6887 §8.3: Validate response came from expected PCP server.
if !from.IP.Equal(addr.IP) {
return nil, fmt.Errorf("response from unexpected source %s (expected %s)", from.IP, addr.IP)
}
return resp[:n], nil
}
func (c *Client) getLocalIP() (netip.Addr, error) {
c.mu.Lock()
defer c.mu.Unlock()
if !c.localIP.IsValid() {
return netip.Addr{}, fmt.Errorf("local IP not set for gateway %s", c.gateway)
}
return c.localIP, nil
}
func protocolNumber(protocol string) (uint8, error) {
switch protocol {
case "udp", "UDP":
return ProtoUDP, nil
case "tcp", "TCP":
return ProtoTCP, nil
default:
return 0, fmt.Errorf("unsupported protocol: %s", protocol)
}
}
// Error represents a PCP error response.
type Error struct {
Code uint8
Message string
}
func (e *Error) Error() string {
return fmt.Sprintf("PCP error: %s (%d)", e.Message, e.Code)
}

View File

@@ -1,187 +0,0 @@
package pcp
import (
"context"
"net"
"net/netip"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestAddrConversion(t *testing.T) {
tests := []struct {
name string
addr netip.Addr
}{
{"IPv4", netip.MustParseAddr("192.168.1.100")},
{"IPv4 loopback", netip.MustParseAddr("127.0.0.1")},
{"IPv6", netip.MustParseAddr("2001:db8::1")},
{"IPv6 loopback", netip.MustParseAddr("::1")},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
b16 := addrTo16(tt.addr)
recovered := addrFrom16(b16)
assert.Equal(t, tt.addr, recovered, "address should round-trip")
})
}
}
func TestBuildAnnounceRequest(t *testing.T) {
clientIP := netip.MustParseAddr("192.168.1.100")
req := buildAnnounceRequest(clientIP)
require.Len(t, req, headerSize)
assert.Equal(t, byte(Version), req[0], "version")
assert.Equal(t, byte(OpAnnounce), req[1], "opcode")
// Check client IP is properly encoded as IPv4-mapped IPv6
assert.Equal(t, byte(0xff), req[18], "IPv4-mapped prefix byte 10")
assert.Equal(t, byte(0xff), req[19], "IPv4-mapped prefix byte 11")
assert.Equal(t, byte(192), req[20], "IP octet 1")
assert.Equal(t, byte(168), req[21], "IP octet 2")
assert.Equal(t, byte(1), req[22], "IP octet 3")
assert.Equal(t, byte(100), req[23], "IP octet 4")
}
func TestBuildMapRequest(t *testing.T) {
clientIP := netip.MustParseAddr("192.168.1.100")
nonce := [12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}
req := buildMapRequest(clientIP, nonce, ProtoUDP, 51820, 51820, netip.Addr{}, 3600)
require.Len(t, req, mapRequestSize)
assert.Equal(t, byte(Version), req[0], "version")
assert.Equal(t, byte(OpMap), req[1], "opcode")
// Lifetime at bytes 4-7
assert.Equal(t, uint32(3600), (uint32(req[4])<<24)|(uint32(req[5])<<16)|(uint32(req[6])<<8)|uint32(req[7]), "lifetime")
// Nonce at bytes 24-35
assert.Equal(t, nonce[:], req[24:36], "nonce")
// Protocol at byte 36
assert.Equal(t, byte(ProtoUDP), req[36], "protocol")
// Internal port at bytes 40-41
assert.Equal(t, uint16(51820), (uint16(req[40])<<8)|uint16(req[41]), "internal port")
// External port at bytes 42-43
assert.Equal(t, uint16(51820), (uint16(req[42])<<8)|uint16(req[43]), "external port")
}
func TestParseResponse(t *testing.T) {
// Construct a valid ANNOUNCE response
resp := make([]byte, headerSize)
resp[0] = Version
resp[1] = OpAnnounce | OpReply
// Result code = 0 (success)
// Lifetime = 0
// Epoch = 12345
resp[8] = 0
resp[9] = 0
resp[10] = 0x30
resp[11] = 0x39
parsed, err := parseResponse(resp)
require.NoError(t, err)
assert.Equal(t, uint8(Version), parsed.Version)
assert.Equal(t, uint8(OpAnnounce|OpReply), parsed.Opcode)
assert.Equal(t, uint8(ResultSuccess), parsed.ResultCode)
assert.Equal(t, uint32(12345), parsed.Epoch)
}
func TestParseResponseErrors(t *testing.T) {
t.Run("too short", func(t *testing.T) {
_, err := parseResponse([]byte{1, 2, 3})
assert.Error(t, err)
})
t.Run("wrong version", func(t *testing.T) {
resp := make([]byte, headerSize)
resp[0] = 1 // Wrong version
resp[1] = OpReply
_, err := parseResponse(resp)
assert.Error(t, err)
})
t.Run("missing reply bit", func(t *testing.T) {
resp := make([]byte, headerSize)
resp[0] = Version
resp[1] = OpAnnounce // Missing OpReply bit
_, err := parseResponse(resp)
assert.Error(t, err)
})
}
func TestResultCodeString(t *testing.T) {
assert.Equal(t, "SUCCESS", ResultCodeString(ResultSuccess))
assert.Equal(t, "NOT_AUTHORIZED", ResultCodeString(ResultNotAuthorized))
assert.Equal(t, "ADDRESS_MISMATCH", ResultCodeString(ResultAddressMismatch))
assert.Contains(t, ResultCodeString(255), "UNKNOWN")
}
func TestProtocolNumber(t *testing.T) {
proto, err := protocolNumber("udp")
require.NoError(t, err)
assert.Equal(t, uint8(ProtoUDP), proto)
proto, err = protocolNumber("tcp")
require.NoError(t, err)
assert.Equal(t, uint8(ProtoTCP), proto)
proto, err = protocolNumber("UDP")
require.NoError(t, err)
assert.Equal(t, uint8(ProtoUDP), proto)
_, err = protocolNumber("icmp")
assert.Error(t, err)
}
func TestClientCreation(t *testing.T) {
gateway := netip.MustParseAddr("192.168.1.1").AsSlice()
client := NewClient(gateway)
assert.Equal(t, net.IP(gateway), client.Gateway())
assert.Equal(t, defaultTimeout, client.timeout)
clientWithTimeout := NewClientWithTimeout(gateway, 5*time.Second)
assert.Equal(t, 5*time.Second, clientWithTimeout.timeout)
}
func TestNATType(t *testing.T) {
n := NewNAT(netip.MustParseAddr("192.168.1.1").AsSlice(), netip.MustParseAddr("192.168.1.100").AsSlice())
assert.Equal(t, "PCP", n.Type())
}
// Integration test - skipped unless PCP_TEST_GATEWAY env is set
func TestClientIntegration(t *testing.T) {
t.Skip("Integration test - run manually with PCP_TEST_GATEWAY=<gateway-ip>")
gateway := netip.MustParseAddr("10.0.1.1").AsSlice() // Change to your test gateway
localIP := netip.MustParseAddr("10.0.1.100").AsSlice() // Change to your local IP
client := NewClient(gateway)
client.SetLocalIP(localIP)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
// Test ANNOUNCE
epoch, err := client.Announce(ctx)
require.NoError(t, err)
t.Logf("Server epoch: %d", epoch)
// Test MAP
resp, err := client.AddPortMapping(ctx, "udp", 51820, 1*time.Hour)
require.NoError(t, err)
t.Logf("Mapping: internal=%d external=%d externalIP=%s",
resp.InternalPort, resp.ExternalPort, resp.ExternalIP)
// Cleanup
err = client.DeletePortMapping(ctx, "udp", 51820)
require.NoError(t, err)
}

View File

@@ -1,209 +0,0 @@
package pcp
import (
"context"
"fmt"
"net"
"net/netip"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/libp2p/go-nat"
"github.com/libp2p/go-netroute"
)
var _ nat.NAT = (*NAT)(nil)
// NAT implements the go-nat NAT interface using PCP.
// Supports dual-stack (IPv4 and IPv6) when available.
// All methods are safe for concurrent use.
//
// TODO: IPv6 pinholes use the local IPv6 address. If the address changes
// (e.g., due to SLAAC rotation or network change), the pinhole becomes stale
// and needs to be recreated with the new address.
type NAT struct {
client *Client
mu sync.RWMutex
// client6 is the IPv6 PCP client, nil if IPv6 is unavailable.
client6 *Client
// localIP6 caches the local IPv6 address used for PCP requests.
localIP6 netip.Addr
}
// NewNAT creates a new NAT instance backed by PCP.
func NewNAT(gateway, localIP net.IP) *NAT {
client := NewClient(gateway)
client.SetLocalIP(localIP)
return &NAT{
client: client,
}
}
// Type returns "PCP" as the NAT type.
func (n *NAT) Type() string {
return "PCP"
}
// GetDeviceAddress returns the gateway IP address.
func (n *NAT) GetDeviceAddress() (net.IP, error) {
return n.client.Gateway(), nil
}
// GetExternalAddress returns the external IP address.
func (n *NAT) GetExternalAddress() (net.IP, error) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
return n.client.GetExternalAddress(ctx)
}
// GetInternalAddress returns the local IP address used to communicate with the gateway.
func (n *NAT) GetInternalAddress() (net.IP, error) {
addr, err := n.client.getLocalIP()
if err != nil {
return nil, err
}
return addr.AsSlice(), nil
}
// AddPortMapping creates a port mapping on both IPv4 and IPv6 (if available).
func (n *NAT) AddPortMapping(ctx context.Context, protocol string, internalPort int, _ string, timeout time.Duration) (int, error) {
resp, err := n.client.AddPortMapping(ctx, protocol, internalPort, timeout)
if err != nil {
return 0, fmt.Errorf("add mapping: %w", err)
}
n.mu.RLock()
client6 := n.client6
localIP6 := n.localIP6
n.mu.RUnlock()
if client6 == nil {
return int(resp.ExternalPort), nil
}
if _, err := client6.AddPortMapping(ctx, protocol, internalPort, timeout); err != nil {
log.Warnf("IPv6 PCP mapping failed (continuing with IPv4): %v", err)
return int(resp.ExternalPort), nil
}
log.Infof("created IPv6 PCP pinhole: %s:%d", localIP6, internalPort)
return int(resp.ExternalPort), nil
}
// DeletePortMapping removes a port mapping from both IPv4 and IPv6.
func (n *NAT) DeletePortMapping(ctx context.Context, protocol string, internalPort int) error {
err := n.client.DeletePortMapping(ctx, protocol, internalPort)
n.mu.RLock()
client6 := n.client6
n.mu.RUnlock()
if client6 != nil {
if err6 := client6.DeletePortMapping(ctx, protocol, internalPort); err6 != nil {
log.Warnf("IPv6 PCP delete mapping failed: %v", err6)
}
}
if err != nil {
return fmt.Errorf("delete mapping: %w", err)
}
return nil
}
// CheckServerHealth sends an ANNOUNCE to verify the server is still responsive.
// Returns the current epoch and whether the server may have restarted (epoch state loss detected).
func (n *NAT) CheckServerHealth(ctx context.Context) (epoch uint32, serverRestarted bool, err error) {
epoch, err = n.client.Announce(ctx)
if err != nil {
return 0, false, fmt.Errorf("announce: %w", err)
}
return epoch, n.client.EpochStateLost(), nil
}
// DiscoverPCP attempts to discover a PCP-capable gateway.
// Returns a NAT interface if PCP is supported, or an error otherwise.
// Discovers both IPv4 and IPv6 gateways when available.
func DiscoverPCP(ctx context.Context) (nat.NAT, error) {
gateway, localIP, err := getDefaultGateway()
if err != nil {
return nil, fmt.Errorf("get default gateway: %w", err)
}
client := NewClient(gateway)
client.SetLocalIP(localIP)
if _, err := client.Announce(ctx); err != nil {
return nil, fmt.Errorf("PCP announce: %w", err)
}
result := &NAT{client: client}
discoverIPv6(ctx, result)
return result, nil
}
func discoverIPv6(ctx context.Context, result *NAT) {
gateway6, localIP6, err := getDefaultGateway6()
if err != nil {
log.Debugf("IPv6 gateway discovery failed: %v", err)
return
}
client6 := NewClient(gateway6)
client6.SetLocalIP(localIP6)
if _, err := client6.Announce(ctx); err != nil {
log.Debugf("PCP IPv6 announce failed: %v", err)
return
}
addr, ok := netip.AddrFromSlice(localIP6)
if !ok {
log.Debugf("invalid IPv6 local IP: %v", localIP6)
return
}
result.mu.Lock()
result.client6 = client6
result.localIP6 = addr
result.mu.Unlock()
log.Debugf("PCP IPv6 gateway discovered: %s (local: %s)", gateway6, localIP6)
}
// getDefaultGateway returns the default IPv4 gateway and local IP using the system routing table.
func getDefaultGateway() (gateway net.IP, localIP net.IP, err error) {
router, err := netroute.New()
if err != nil {
return nil, nil, err
}
_, gateway, localIP, err = router.Route(net.IPv4zero)
if err != nil {
return nil, nil, err
}
if gateway == nil {
return nil, nil, nat.ErrNoNATFound
}
return gateway, localIP, nil
}
// getDefaultGateway6 returns the default IPv6 gateway IP address using the system routing table.
func getDefaultGateway6() (gateway net.IP, localIP net.IP, err error) {
router, err := netroute.New()
if err != nil {
return nil, nil, err
}
_, gateway, localIP, err = router.Route(net.IPv6zero)
if err != nil {
return nil, nil, err
}
if gateway == nil {
return nil, nil, nat.ErrNoNATFound
}
return gateway, localIP, nil
}

View File

@@ -1,225 +0,0 @@
// Package pcp implements the Port Control Protocol (RFC 6887).
//
// # Implemented Features
//
// - ANNOUNCE opcode: Discovers PCP server support
// - MAP opcode: Creates/deletes port mappings (IPv4 NAT) and firewall pinholes (IPv6)
// - Dual-stack: Simultaneous IPv4 and IPv6 support via separate clients
// - Nonce validation: Prevents response spoofing
// - Epoch tracking: Detects server restarts per Section 8.5
// - RFC-compliant retry timing: 3s initial, exponential backoff to 1024s max (Section 8.1.1)
//
// # Not Implemented
//
// - PEER opcode: For outbound peer connections (not needed for inbound NAT traversal)
// - THIRD_PARTY option: For managing mappings on behalf of other devices
// - PREFER_FAILURE option: Requires exact external port or fail (IPv4 NAT only, not needed for IPv6 pinholing)
// - FILTER option: To restrict remote peer addresses
//
// These optional features are omitted because the primary use case is simple
// port forwarding for WireGuard, which only requires MAP with default behavior.
package pcp
import (
"encoding/binary"
"fmt"
"net/netip"
)
const (
// Version is the PCP protocol version (RFC 6887).
Version = 2
// Port is the standard PCP server port.
Port = 5351
// DefaultLifetime is the default requested mapping lifetime in seconds.
DefaultLifetime = 7200 // 2 hours
// Header sizes
headerSize = 24
mapPayloadSize = 36
mapRequestSize = headerSize + mapPayloadSize // 60 bytes
)
// Opcodes
const (
OpAnnounce = 0
OpMap = 1
OpPeer = 2
OpReply = 0x80 // OR'd with opcode in responses
)
// Protocol numbers for MAP requests
const (
ProtoUDP = 17
ProtoTCP = 6
)
// Result codes (RFC 6887 Section 7.4)
const (
ResultSuccess = 0
ResultUnsuppVersion = 1
ResultNotAuthorized = 2
ResultMalformedRequest = 3
ResultUnsuppOpcode = 4
ResultUnsuppOption = 5
ResultMalformedOption = 6
ResultNetworkFailure = 7
ResultNoResources = 8
ResultUnsuppProtocol = 9
ResultUserExQuota = 10
ResultCannotProvideExt = 11
ResultAddressMismatch = 12
ResultExcessiveRemotePeers = 13
)
// ResultCodeString returns a human-readable string for a result code.
func ResultCodeString(code uint8) string {
switch code {
case ResultSuccess:
return "SUCCESS"
case ResultUnsuppVersion:
return "UNSUPP_VERSION"
case ResultNotAuthorized:
return "NOT_AUTHORIZED"
case ResultMalformedRequest:
return "MALFORMED_REQUEST"
case ResultUnsuppOpcode:
return "UNSUPP_OPCODE"
case ResultUnsuppOption:
return "UNSUPP_OPTION"
case ResultMalformedOption:
return "MALFORMED_OPTION"
case ResultNetworkFailure:
return "NETWORK_FAILURE"
case ResultNoResources:
return "NO_RESOURCES"
case ResultUnsuppProtocol:
return "UNSUPP_PROTOCOL"
case ResultUserExQuota:
return "USER_EX_QUOTA"
case ResultCannotProvideExt:
return "CANNOT_PROVIDE_EXTERNAL"
case ResultAddressMismatch:
return "ADDRESS_MISMATCH"
case ResultExcessiveRemotePeers:
return "EXCESSIVE_REMOTE_PEERS"
default:
return fmt.Sprintf("UNKNOWN(%d)", code)
}
}
// Response represents a parsed PCP response header.
type Response struct {
Version uint8
Opcode uint8
ResultCode uint8
Lifetime uint32
Epoch uint32
}
// MapResponse contains the full response to a MAP request.
type MapResponse struct {
Response
Nonce [12]byte
Protocol uint8
InternalPort uint16
ExternalPort uint16
ExternalIP netip.Addr
}
// addrTo16 converts an address to its 16-byte IPv4-mapped IPv6 representation.
func addrTo16(addr netip.Addr) [16]byte {
if addr.Is4() {
return netip.AddrFrom4(addr.As4()).As16()
}
return addr.As16()
}
// addrFrom16 extracts an address from a 16-byte representation, unmapping IPv4.
func addrFrom16(b [16]byte) netip.Addr {
return netip.AddrFrom16(b).Unmap()
}
// buildAnnounceRequest creates a PCP ANNOUNCE request packet.
func buildAnnounceRequest(clientIP netip.Addr) []byte {
req := make([]byte, headerSize)
req[0] = Version
req[1] = OpAnnounce
mapped := addrTo16(clientIP)
copy(req[8:24], mapped[:])
return req
}
// buildMapRequest creates a PCP MAP request packet.
func buildMapRequest(clientIP netip.Addr, nonce [12]byte, protocol uint8, internalPort, suggestedExtPort uint16, suggestedExtIP netip.Addr, lifetime uint32) []byte {
req := make([]byte, mapRequestSize)
// Header
req[0] = Version
req[1] = OpMap
binary.BigEndian.PutUint32(req[4:8], lifetime)
mapped := addrTo16(clientIP)
copy(req[8:24], mapped[:])
// MAP payload
copy(req[24:36], nonce[:])
req[36] = protocol
binary.BigEndian.PutUint16(req[40:42], internalPort)
binary.BigEndian.PutUint16(req[42:44], suggestedExtPort)
if suggestedExtIP.IsValid() {
extMapped := addrTo16(suggestedExtIP)
copy(req[44:60], extMapped[:])
}
return req
}
// parseResponse parses the common PCP response header.
func parseResponse(data []byte) (*Response, error) {
if len(data) < headerSize {
return nil, fmt.Errorf("response too short: %d bytes", len(data))
}
resp := &Response{
Version: data[0],
Opcode: data[1],
ResultCode: data[3], // Byte 2 is reserved, byte 3 is result code (RFC 6887 §7.2)
Lifetime: binary.BigEndian.Uint32(data[4:8]),
Epoch: binary.BigEndian.Uint32(data[8:12]),
}
if resp.Version != Version {
return nil, fmt.Errorf("unsupported PCP version: %d", resp.Version)
}
if resp.Opcode&OpReply == 0 {
return nil, fmt.Errorf("response missing reply bit: opcode=0x%02x", resp.Opcode)
}
return resp, nil
}
// parseMapResponse parses a complete MAP response.
func parseMapResponse(data []byte) (*MapResponse, error) {
if len(data) < mapRequestSize {
return nil, fmt.Errorf("MAP response too short: %d bytes", len(data))
}
resp, err := parseResponse(data)
if err != nil {
return nil, fmt.Errorf("parse header: %w", err)
}
mapResp := &MapResponse{
Response: *resp,
Protocol: data[36],
InternalPort: binary.BigEndian.Uint16(data[40:42]),
ExternalPort: binary.BigEndian.Uint16(data[42:44]),
ExternalIP: addrFrom16([16]byte(data[44:60])),
}
copy(mapResp.Nonce[:], data[24:36])
return mapResp, nil
}

View File

@@ -1,63 +0,0 @@
//go:build !js
package portforward
import (
"context"
"fmt"
"github.com/libp2p/go-nat"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/portforward/pcp"
)
// discoverGateway is the function used for NAT gateway discovery.
// It can be replaced in tests to avoid real network operations.
// Tries PCP first, then falls back to NAT-PMP/UPnP.
var discoverGateway = defaultDiscoverGateway
func defaultDiscoverGateway(ctx context.Context) (nat.NAT, error) {
pcpGateway, err := pcp.DiscoverPCP(ctx)
if err == nil {
return pcpGateway, nil
}
log.Debugf("PCP discovery failed: %v, trying NAT-PMP/UPnP", err)
return nat.DiscoverGateway(ctx)
}
// State is persisted only for crash recovery cleanup
type State struct {
InternalPort uint16 `json:"internal_port,omitempty"`
Protocol string `json:"protocol,omitempty"`
}
func (s *State) Name() string {
return "port_forward_state"
}
// Cleanup implements statemanager.CleanableState for crash recovery
func (s *State) Cleanup() error {
if s.InternalPort == 0 {
return nil
}
log.Infof("cleaning up stale port mapping for port %d", s.InternalPort)
ctx, cancel := context.WithTimeout(context.Background(), discoveryTimeout)
defer cancel()
gateway, err := discoverGateway(ctx)
if err != nil {
// Discovery failure is not an error - gateway may not exist
log.Debugf("cleanup: no gateway found: %v", err)
return nil
}
if err := gateway.DeletePortMapping(ctx, s.Protocol, int(s.InternalPort)); err != nil {
return fmt.Errorf("delete port mapping: %w", err)
}
return nil
}

View File

@@ -168,7 +168,6 @@ func (m *DefaultManager) setupAndroidRoutes(config ManagerConfig) {
NetworkType: route.IPv4Network,
}
cr = append(cr, fakeIPRoute)
m.notifier.SetFakeIPRoute(fakeIPRoute)
}
m.notifier.SetInitialClientRoutes(cr, routesForComparison)

View File

@@ -16,7 +16,6 @@ import (
type Notifier struct {
initialRoutes []*route.Route
currentRoutes []*route.Route
fakeIPRoute *route.Route
listener listener.NetworkChangeListener
listenerMux sync.Mutex
@@ -32,17 +31,13 @@ func (n *Notifier) SetListener(listener listener.NetworkChangeListener) {
n.listener = listener
}
// SetInitialClientRoutes stores the initial route sets for TUN configuration.
// SetInitialClientRoutes stores the full initial route set (including fake IP blocks)
// and a separate comparison set (without fake IP blocks) for diff detection.
func (n *Notifier) SetInitialClientRoutes(initialRoutes []*route.Route, routesForComparison []*route.Route) {
n.initialRoutes = filterStatic(initialRoutes)
n.currentRoutes = filterStatic(routesForComparison)
}
// SetFakeIPRoute stores the fake IP route to be included in every TUN rebuild.
func (n *Notifier) SetFakeIPRoute(r *route.Route) {
n.fakeIPRoute = r
}
func (n *Notifier) OnNewRoutes(idMap route.HAMap) {
var newRoutes []*route.Route
for _, routes := range idMap {
@@ -74,9 +69,7 @@ func (n *Notifier) notify() {
}
allRoutes := slices.Clone(n.currentRoutes)
if n.fakeIPRoute != nil {
allRoutes = append(allRoutes, n.fakeIPRoute)
}
allRoutes = append(allRoutes, n.extraInitialRoutes()...)
routeStrings := n.routesToStrings(allRoutes)
sort.Strings(routeStrings)
@@ -85,6 +78,23 @@ func (n *Notifier) notify() {
}(n.listener)
}
// extraInitialRoutes returns initialRoutes whose network prefix is absent
// from currentRoutes (e.g. the fake IP block added at setup time).
func (n *Notifier) extraInitialRoutes() []*route.Route {
currentNets := make(map[netip.Prefix]struct{}, len(n.currentRoutes))
for _, r := range n.currentRoutes {
currentNets[r.Network] = struct{}{}
}
var extra []*route.Route
for _, r := range n.initialRoutes {
if _, ok := currentNets[r.Network]; !ok {
extra = append(extra, r)
}
}
return extra
}
func filterStatic(routes []*route.Route) []*route.Route {
out := make([]*route.Route, 0, len(routes))
for _, r := range routes {

View File

@@ -34,10 +34,6 @@ func (n *Notifier) SetInitialClientRoutes([]*route.Route, []*route.Route) {
// iOS doesn't care about initial routes
}
func (n *Notifier) SetFakeIPRoute(*route.Route) {
// Not used on iOS
}
func (n *Notifier) OnNewRoutes(route.HAMap) {
// Not used on iOS
}

View File

@@ -23,10 +23,6 @@ func (n *Notifier) SetInitialClientRoutes([]*route.Route, []*route.Route) {
// Not used on non-mobile platforms
}
func (n *Notifier) SetFakeIPRoute(*route.Route) {
// Not used on non-mobile platforms
}
func (n *Notifier) OnNewRoutes(idMap route.HAMap) {
// Not used on non-mobile platforms
}

View File

@@ -1,10 +0,0 @@
//go:build (dragonfly || freebsd || netbsd || openbsd) && !darwin
package systemops
// Non-darwin BSDs don't support the IP_BOUND_IF + scoped default model. They
// always fall through to the ref-counter exclusion-route path; these stubs
// exist only so systemops_unix.go compiles.
func (r *SysOps) setupAdvancedRouting() error { return nil }
func (r *SysOps) cleanupAdvancedRouting() error { return nil }
func (r *SysOps) flushPlatformExtras() error { return nil }

View File

@@ -1,241 +0,0 @@
//go:build darwin && !ios
package systemops
import (
"errors"
"fmt"
"net/netip"
"os"
"time"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
"golang.org/x/net/route"
"golang.org/x/sys/unix"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
nbnet "github.com/netbirdio/netbird/client/net"
)
// scopedRouteBudget bounds retries for the scoped default route. Installing or
// deleting it matters enough that we're willing to spend longer waiting for the
// kernel reply than for per-prefix exclusion routes.
const scopedRouteBudget = 5 * time.Second
// setupAdvancedRouting installs an RTF_IFSCOPE default route per address family
// pinned to the current physical egress, so IP_BOUND_IF scoped lookups can
// resolve gateway'd destinations while the VPN's split default owns the
// unscoped table.
//
// Timing note: this runs during routeManager.Init, which happens before the
// VPN interface is created and before any peer routes propagate. The initial
// mgmt / signal / relay TCP dials always fire before this runs, so those
// sockets miss the IP_BOUND_IF binding and rely on the kernel's normal route
// lookup, which at that point correctly picks the physical default. Those
// already-established TCP flows keep their originally-selected interface for
// their lifetime on Darwin because the kernel caches the egress route
// per-socket at connect time; adding the VPN's 0/1 + 128/1 split default
// afterwards does not migrate them since the original en0 default stays in
// the table. Any subsequent reconnect via nbnet.NewDialer picks up the
// populated bound-iface cache and gets IP_BOUND_IF set cleanly.
func (r *SysOps) setupAdvancedRouting() error {
// Drop any previously-cached egress interface before reinstalling. On a
// refresh, a family that no longer resolves would otherwise keep the stale
// binding, causing new sockets to scope to an interface without a matching
// scoped default.
nbnet.ClearBoundInterfaces()
if err := r.flushScopedDefaults(); err != nil {
log.Warnf("flush residual scoped defaults: %v", err)
}
var merr *multierror.Error
installed := 0
for _, unspec := range []netip.Addr{netip.IPv4Unspecified(), netip.IPv6Unspecified()} {
ok, err := r.installScopedDefaultFor(unspec)
if err != nil {
merr = multierror.Append(merr, err)
continue
}
if ok {
installed++
}
}
if installed == 0 && merr != nil {
return nberrors.FormatErrorOrNil(merr)
}
if merr != nil {
log.Warnf("advanced routing setup partially succeeded: %v", nberrors.FormatErrorOrNil(merr))
}
return nil
}
// installScopedDefaultFor resolves the physical default nexthop for the given
// address family, installs a scoped default via it, and caches the iface for
// subsequent IP_BOUND_IF / IPV6_BOUND_IF socket binds.
func (r *SysOps) installScopedDefaultFor(unspec netip.Addr) (bool, error) {
nexthop, err := GetNextHop(unspec)
if err != nil {
if errors.Is(err, vars.ErrRouteNotFound) {
return false, nil
}
return false, fmt.Errorf("get default nexthop for %s: %w", unspec, err)
}
if nexthop.Intf == nil {
return false, fmt.Errorf("unusable default nexthop for %s (no interface)", unspec)
}
if err := r.addScopedDefault(unspec, nexthop); err != nil {
return false, fmt.Errorf("add scoped default on %s: %w", nexthop.Intf.Name, err)
}
af := unix.AF_INET
if unspec.Is6() {
af = unix.AF_INET6
}
nbnet.SetBoundInterface(af, nexthop.Intf)
via := "point-to-point"
if nexthop.IP.IsValid() {
via = nexthop.IP.String()
}
log.Infof("installed scoped default route via %s on %s for %s", via, nexthop.Intf.Name, afOf(unspec))
return true, nil
}
func (r *SysOps) cleanupAdvancedRouting() error {
nbnet.ClearBoundInterfaces()
return r.flushScopedDefaults()
}
// flushPlatformExtras runs darwin-specific residual cleanup hooked into the
// generic FlushMarkedRoutes path, so a crashed daemon's scoped defaults get
// removed on the next boot regardless of whether a profile is brought up.
func (r *SysOps) flushPlatformExtras() error {
return r.flushScopedDefaults()
}
// flushScopedDefaults removes any scoped default routes tagged with routeProtoFlag.
// Safe to call at startup to clear residual entries from a prior session.
func (r *SysOps) flushScopedDefaults() error {
rib, err := retryFetchRIB()
if err != nil {
return fmt.Errorf("fetch routing table: %w", err)
}
msgs, err := route.ParseRIB(route.RIBTypeRoute, rib)
if err != nil {
return fmt.Errorf("parse routing table: %w", err)
}
var merr *multierror.Error
removed := 0
for _, msg := range msgs {
rtMsg, ok := msg.(*route.RouteMessage)
if !ok {
continue
}
if rtMsg.Flags&routeProtoFlag == 0 {
continue
}
if rtMsg.Flags&unix.RTF_IFSCOPE == 0 {
continue
}
info, err := MsgToRoute(rtMsg)
if err != nil {
log.Debugf("skip scoped flush: %v", err)
continue
}
if !info.Dst.IsValid() || info.Dst.Bits() != 0 {
continue
}
if err := r.deleteScopedRoute(rtMsg); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete scoped default %s on index %d: %w",
info.Dst, rtMsg.Index, err))
continue
}
removed++
log.Debugf("flushed residual scoped default %s on index %d", info.Dst, rtMsg.Index)
}
if removed > 0 {
log.Infof("flushed %d residual scoped default route(s)", removed)
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *SysOps) addScopedDefault(unspec netip.Addr, nexthop Nexthop) error {
return r.scopedRouteSocket(unix.RTM_ADD, unspec, nexthop)
}
func (r *SysOps) deleteScopedRoute(rtMsg *route.RouteMessage) error {
// Preserve identifying flags from the stored route (including RTF_GATEWAY
// only if present); kernel-set bits like RTF_DONE don't belong on RTM_DELETE.
keep := unix.RTF_UP | unix.RTF_STATIC | unix.RTF_GATEWAY | unix.RTF_IFSCOPE | routeProtoFlag
del := &route.RouteMessage{
Type: unix.RTM_DELETE,
Flags: rtMsg.Flags & keep,
Version: unix.RTM_VERSION,
Seq: r.getSeq(),
Index: rtMsg.Index,
Addrs: rtMsg.Addrs,
}
return r.writeRouteMessage(del, scopedRouteBudget)
}
func (r *SysOps) scopedRouteSocket(action int, unspec netip.Addr, nexthop Nexthop) error {
flags := unix.RTF_UP | unix.RTF_STATIC | unix.RTF_IFSCOPE | routeProtoFlag
msg := &route.RouteMessage{
Type: action,
Flags: flags,
Version: unix.RTM_VERSION,
ID: uintptr(os.Getpid()),
Seq: r.getSeq(),
Index: nexthop.Intf.Index,
}
const numAddrs = unix.RTAX_NETMASK + 1
addrs := make([]route.Addr, numAddrs)
dst, err := addrToRouteAddr(unspec)
if err != nil {
return fmt.Errorf("build destination: %w", err)
}
mask, err := prefixToRouteNetmask(netip.PrefixFrom(unspec, 0))
if err != nil {
return fmt.Errorf("build netmask: %w", err)
}
addrs[unix.RTAX_DST] = dst
addrs[unix.RTAX_NETMASK] = mask
if nexthop.IP.IsValid() {
msg.Flags |= unix.RTF_GATEWAY
gw, err := addrToRouteAddr(nexthop.IP.Unmap())
if err != nil {
return fmt.Errorf("build gateway: %w", err)
}
addrs[unix.RTAX_GATEWAY] = gw
} else {
addrs[unix.RTAX_GATEWAY] = &route.LinkAddr{
Index: nexthop.Intf.Index,
Name: nexthop.Intf.Name,
}
}
msg.Addrs = addrs
return r.writeRouteMessage(msg, scopedRouteBudget)
}
func afOf(a netip.Addr) string {
if a.Is4() {
return "IPv4"
}
return "IPv6"
}

View File

@@ -21,7 +21,6 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/util"
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/client/net/hooks"
)
@@ -32,6 +31,8 @@ var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1)
var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1)
var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1)
var ErrRoutingIsSeparate = errors.New("routing is separate")
func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemanager.Manager) error {
stateManager.RegisterState(&ShutdownState{})
@@ -396,16 +397,12 @@ func ipToAddr(ip net.IP, intf *net.Interface) (netip.Addr, error) {
}
// IsAddrRouted checks if the candidate address would route to the vpn, in which case it returns true and the matched prefix.
// When advanced routing is active the WG socket is bound to the physical interface (fwmark on linux,
// IP_UNICAST_IF on windows, IP_BOUND_IF on darwin) and bypasses the main routing table, so the check is skipped.
func IsAddrRouted(addr netip.Addr, vpnRoutes []netip.Prefix) (bool, netip.Prefix) {
if nbnet.AdvancedRouting() {
return false, netip.Prefix{}
}
localRoutes, err := GetRoutesFromTable()
localRoutes, err := hasSeparateRouting()
if err != nil {
log.Errorf("Failed to get routes: %v", err)
if !errors.Is(err, ErrRoutingIsSeparate) {
log.Errorf("Failed to get routes: %v", err)
}
return false, netip.Prefix{}
}

View File

@@ -22,6 +22,10 @@ func GetRoutesFromTable() ([]netip.Prefix, error) {
return []netip.Prefix{}, nil
}
func hasSeparateRouting() ([]netip.Prefix, error) {
return []netip.Prefix{}, nil
}
// GetDetailedRoutesFromTable returns empty routes for WASM.
func GetDetailedRoutesFromTable() ([]DetailedRoute, error) {
return []DetailedRoute{}, nil

View File

@@ -894,6 +894,13 @@ func getAddressFamily(prefix netip.Prefix) int {
return netlink.FAMILY_V6
}
func hasSeparateRouting() ([]netip.Prefix, error) {
if !nbnet.AdvancedRouting() {
return GetRoutesFromTable()
}
return nil, ErrRoutingIsSeparate
}
func isOpErr(err error) bool {
// EAFTNOSUPPORT when ipv6 is disabled via sysctl, EOPNOTSUPP when disabled in boot options or otherwise not supported
if errors.Is(err, syscall.EAFNOSUPPORT) || errors.Is(err, syscall.EOPNOTSUPP) {

View File

@@ -48,6 +48,10 @@ func EnableIPForwarding() error {
return nil
}
func hasSeparateRouting() ([]netip.Prefix, error) {
return GetRoutesFromTable()
}
// GetIPRules returns IP rules for debugging (not supported on non-Linux platforms)
func GetIPRules() ([]IPRule, error) {
log.Infof("IP rules collection is not supported on %s", runtime.GOOS)

View File

@@ -25,9 +25,6 @@ import (
const (
envRouteProtoFlag = "NB_ROUTE_PROTO_FLAG"
// routeBudget bounds retries for per-prefix exclusion route programming.
routeBudget = 1 * time.Second
)
var routeProtoFlag int
@@ -44,42 +41,26 @@ func init() {
}
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) error {
if advancedRouting {
return r.setupAdvancedRouting()
}
log.Infof("Using legacy routing setup with ref counters")
return r.setupRefCounter(initAddresses, stateManager)
}
func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRouting bool) error {
if advancedRouting {
return r.cleanupAdvancedRouting()
}
return r.cleanupRefCounter(stateManager)
}
// FlushMarkedRoutes removes single IP exclusion routes marked with the configured RTF_PROTO flag.
// On darwin it also flushes residual RTF_IFSCOPE scoped default routes so a
// crashed prior session can't leave crud in the table.
func (r *SysOps) FlushMarkedRoutes() error {
var merr *multierror.Error
if err := r.flushPlatformExtras(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("flush platform extras: %w", err))
}
rib, err := retryFetchRIB()
if err != nil {
return nberrors.FormatErrorOrNil(multierror.Append(merr, fmt.Errorf("fetch routing table: %w", err)))
return fmt.Errorf("fetch routing table: %w", err)
}
msgs, err := route.ParseRIB(route.RIBTypeRoute, rib)
if err != nil {
return nberrors.FormatErrorOrNil(multierror.Append(merr, fmt.Errorf("parse routing table: %w", err)))
return fmt.Errorf("parse routing table: %w", err)
}
var merr *multierror.Error
flushedCount := 0
for _, msg := range msgs {
@@ -136,12 +117,12 @@ func (r *SysOps) routeSocket(action int, prefix netip.Prefix, nexthop Nexthop) e
return fmt.Errorf("invalid prefix: %s", prefix)
}
msg, err := r.buildRouteMessage(action, prefix, nexthop)
if err != nil {
return fmt.Errorf("build route message: %w", err)
}
expBackOff := backoff.NewExponentialBackOff()
expBackOff.InitialInterval = 50 * time.Millisecond
expBackOff.MaxInterval = 500 * time.Millisecond
expBackOff.MaxElapsedTime = 1 * time.Second
if err := r.writeRouteMessage(msg, routeBudget); err != nil {
if err := backoff.Retry(r.routeOp(action, prefix, nexthop), expBackOff); err != nil {
a := "add"
if action == unix.RTM_DELETE {
a = "remove"
@@ -151,91 +132,50 @@ func (r *SysOps) routeSocket(action int, prefix netip.Prefix, nexthop Nexthop) e
return nil
}
// writeRouteMessage sends a route message over AF_ROUTE and waits for the
// kernel's matching reply, retrying transient failures until budget elapses.
// Callers do not need to manage sockets or seq numbers themselves.
func (r *SysOps) writeRouteMessage(msg *route.RouteMessage, budget time.Duration) error {
expBackOff := backoff.NewExponentialBackOff()
expBackOff.InitialInterval = 50 * time.Millisecond
expBackOff.MaxInterval = 500 * time.Millisecond
expBackOff.MaxElapsedTime = budget
return backoff.Retry(func() error { return routeMessageRoundtrip(msg) }, expBackOff)
}
func routeMessageRoundtrip(msg *route.RouteMessage) error {
fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC)
if err != nil {
return fmt.Errorf("open routing socket: %w", err)
}
defer func() {
if err := unix.Close(fd); err != nil && !errors.Is(err, unix.EBADF) {
log.Warnf("close routing socket: %v", err)
}
}()
tv := unix.Timeval{Sec: 1}
if err := unix.SetsockoptTimeval(fd, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &tv); err != nil {
return backoff.Permanent(fmt.Errorf("set recv timeout: %w", err))
}
// AF_ROUTE is a broadcast channel: every route socket on the host sees
// every RTM_* event. With concurrent route programming the default
// per-socket queue overflows and our own reply gets dropped.
if err := unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_RCVBUF, 1<<20); err != nil {
log.Debugf("set SO_RCVBUF on route socket: %v", err)
}
bytes, err := msg.Marshal()
if err != nil {
return backoff.Permanent(fmt.Errorf("marshal: %w", err))
}
if _, err = unix.Write(fd, bytes); err != nil {
if errors.Is(err, unix.ENOBUFS) || errors.Is(err, unix.EAGAIN) {
return fmt.Errorf("write: %w", err)
}
return backoff.Permanent(fmt.Errorf("write: %w", err))
}
return readRouteResponse(fd, msg.Type, msg.Seq)
}
// readRouteResponse reads from the AF_ROUTE socket until it sees a reply
// matching our write (same type, seq, and pid). AF_ROUTE SOCK_RAW is a
// broadcast channel: interface up/down, third-party route changes and neighbor
// discovery events can all land between our write and read, so we must filter.
func readRouteResponse(fd, wantType, wantSeq int) error {
pid := int32(os.Getpid())
resp := make([]byte, 2048)
deadline := time.Now().Add(time.Second)
for {
if time.Now().After(deadline) {
// Transient: under concurrent pressure the kernel can drop our reply
// from the socket buffer. Let backoff.Retry re-send with a fresh seq.
return fmt.Errorf("read: timeout waiting for route reply type=%d seq=%d", wantType, wantSeq)
}
n, err := unix.Read(fd, resp)
func (r *SysOps) routeOp(action int, prefix netip.Prefix, nexthop Nexthop) func() error {
operation := func() error {
fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC)
if err != nil {
if errors.Is(err, unix.EAGAIN) || errors.Is(err, unix.EWOULDBLOCK) {
// SO_RCVTIMEO fired while waiting; loop to re-check the absolute deadline.
continue
return fmt.Errorf("open routing socket: %w", err)
}
defer func() {
if err := unix.Close(fd); err != nil && !errors.Is(err, unix.EBADF) {
log.Warnf("failed to close routing socket: %v", err)
}
return backoff.Permanent(fmt.Errorf("read: %w", err))
}()
msg, err := r.buildRouteMessage(action, prefix, nexthop)
if err != nil {
return backoff.Permanent(fmt.Errorf("build route message: %w", err))
}
if n < int(unsafe.Sizeof(unix.RtMsghdr{})) {
continue
msgBytes, err := msg.Marshal()
if err != nil {
return backoff.Permanent(fmt.Errorf("marshal route message: %w", err))
}
hdr := (*unix.RtMsghdr)(unsafe.Pointer(&resp[0]))
// Darwin reflects the sender's pid on replies; matching (Type, Seq, Pid)
// uniquely identifies our own reply among broadcast traffic.
if int(hdr.Type) != wantType || int(hdr.Seq) != wantSeq || hdr.Pid != pid {
continue
if _, err = unix.Write(fd, msgBytes); err != nil {
if errors.Is(err, unix.ENOBUFS) || errors.Is(err, unix.EAGAIN) {
return fmt.Errorf("write: %w", err)
}
return backoff.Permanent(fmt.Errorf("write: %w", err))
}
if hdr.Errno != 0 {
return backoff.Permanent(fmt.Errorf("kernel: %w", syscall.Errno(hdr.Errno)))
respBuf := make([]byte, 2048)
n, err := unix.Read(fd, respBuf)
if err != nil {
return backoff.Permanent(fmt.Errorf("read route response: %w", err))
}
if n > 0 {
if err := r.parseRouteResponse(respBuf[:n]); err != nil {
return backoff.Permanent(err)
}
}
return nil
}
return operation
}
func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Nexthop) (msg *route.RouteMessage, err error) {
@@ -243,7 +183,6 @@ func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Next
Type: action,
Flags: unix.RTF_UP | routeProtoFlag,
Version: unix.RTM_VERSION,
ID: uintptr(os.Getpid()),
Seq: r.getSeq(),
}
@@ -282,6 +221,19 @@ func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Next
return msg, nil
}
func (r *SysOps) parseRouteResponse(buf []byte) error {
if len(buf) < int(unsafe.Sizeof(unix.RtMsghdr{})) {
return nil
}
rtMsg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0]))
if rtMsg.Errno != 0 {
return fmt.Errorf("parse: %d", rtMsg.Errno)
}
return nil
}
// addrToRouteAddr converts a netip.Addr to the appropriate route.Addr (*route.Inet4Addr or *route.Inet6Addr).
func addrToRouteAddr(addr netip.Addr) (route.Addr, error) {
if addr.Is4() {

View File

@@ -1,5 +0,0 @@
package net
func (d *Dialer) init() {
d.Dialer.Control = applyBoundIfToSocket
}

View File

@@ -1,4 +1,4 @@
//go:build !linux && !windows && !darwin
//go:build !linux && !windows
package net

24
client/net/env_android.go Normal file
View File

@@ -0,0 +1,24 @@
//go:build android
package net
// Init initializes the network environment for Android
func Init() {
// No initialization needed on Android
}
// AdvancedRouting reports whether routing loops can be avoided without using exclusion routes.
// Always returns true on Android since we cannot handle routes dynamically.
func AdvancedRouting() bool {
return true
}
// SetVPNInterfaceName is a no-op on Android
func SetVPNInterfaceName(name string) {
// No-op on Android - not needed for Android VPN service
}
// GetVPNInterfaceName returns empty string on Android
func GetVPNInterfaceName() string {
return ""
}

View File

@@ -1,4 +1,4 @@
//go:build !linux && !windows && !darwin
//go:build !linux && !windows && !android
package net

View File

@@ -1,25 +0,0 @@
//go:build ios || android
package net
// Init initializes the network environment for mobile platforms.
func Init() {
// no-op on mobile: routing scope is owned by the VPN extension.
}
// AdvancedRouting reports whether routing loops can be avoided without using exclusion routes.
// Always returns true on mobile since routes cannot be handled dynamically and the VPN extension
// owns the routing scope.
func AdvancedRouting() bool {
return true
}
// SetVPNInterfaceName is a no-op on mobile.
func SetVPNInterfaceName(string) {
// no-op on mobile: the VPN extension manages the interface.
}
// GetVPNInterfaceName returns an empty string on mobile.
func GetVPNInterfaceName() string {
return ""
}

View File

@@ -1,4 +1,4 @@
//go:build (darwin && !ios) || windows
//go:build windows
package net
@@ -24,22 +24,17 @@ func Init() {
}
func checkAdvancedRoutingSupport() bool {
legacyRouting := false
var err error
var legacyRouting bool
if val := os.Getenv(envUseLegacyRouting); val != "" {
parsed, err := strconv.ParseBool(val)
legacyRouting, err = strconv.ParseBool(val)
if err != nil {
log.Warnf("ignoring unparsable %s=%q: %v", envUseLegacyRouting, val, err)
} else {
legacyRouting = parsed
log.Warnf("failed to parse %s: %v", envUseLegacyRouting, err)
}
}
if legacyRouting {
log.Infof("advanced routing disabled: legacy routing requested via %s", envUseLegacyRouting)
return false
}
if netstack.IsEnabled() {
log.Info("advanced routing disabled: netstack mode is enabled")
if legacyRouting || netstack.IsEnabled() {
log.Info("advanced routing has been requested to be disabled")
return false
}

View File

@@ -1,5 +0,0 @@
package net
func (l *ListenerConfig) init() {
l.ListenConfig.Control = applyBoundIfToSocket
}

View File

@@ -1,4 +1,4 @@
//go:build !linux && !windows && !darwin
//go:build !linux && !windows
package net

View File

@@ -1,160 +0,0 @@
package net
import (
"fmt"
"net"
"net/netip"
"strconv"
"strings"
"sync"
"syscall"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
)
// On darwin IPV6_BOUND_IF also scopes v4-mapped egress from dual-stack
// (IPV6_V6ONLY=0) AF_INET6 sockets, so a single setsockopt on "udp6"/"tcp6"
// covers both families. Setting IP_BOUND_IF on an AF_INET6 socket returns
// EINVAL regardless of V6ONLY because the IPPROTO_IP ctloutput path is
// dispatched by socket domain (AF_INET only) not by inp_vflag.
// boundIface holds the physical interface chosen at routing setup time. Sockets
// created via nbnet.NewDialer / nbnet.NewListener bind to it via IP_BOUND_IF
// (IPv4) or IPV6_BOUND_IF (IPv6 / dual-stack) so their scoped route lookup
// hits the RTF_IFSCOPE default installed by the routemanager, rather than
// following the VPN's split default.
var (
boundIfaceMu sync.RWMutex
boundIface4 *net.Interface
boundIface6 *net.Interface
)
// SetBoundInterface records the egress interface for an address family. Called
// by the routemanager after a scoped default route has been installed.
// af must be unix.AF_INET or unix.AF_INET6; other values are ignored.
// nil iface is rejected — use ClearBoundInterfaces to clear all slots.
func SetBoundInterface(af int, iface *net.Interface) {
if iface == nil {
log.Warnf("SetBoundInterface: nil iface for AF %d, ignored", af)
return
}
boundIfaceMu.Lock()
defer boundIfaceMu.Unlock()
switch af {
case unix.AF_INET:
boundIface4 = iface
case unix.AF_INET6:
boundIface6 = iface
default:
log.Warnf("SetBoundInterface: unsupported address family %d", af)
}
}
// ClearBoundInterfaces resets the cached egress interfaces. Called by the
// routemanager during cleanup.
func ClearBoundInterfaces() {
boundIfaceMu.Lock()
defer boundIfaceMu.Unlock()
boundIface4 = nil
boundIface6 = nil
}
// boundInterfaceFor returns the cached egress interface for a socket's address
// family, falling back to the other family if the preferred slot is empty.
// The kernel stores both IP_BOUND_IF and IPV6_BOUND_IF in inp_boundifp, so
// either setsockopt scopes the socket; preferring same-family still matters
// when v4 and v6 defaults egress different NICs.
func boundInterfaceFor(network, address string) *net.Interface {
if iface := zoneInterface(address); iface != nil {
return iface
}
boundIfaceMu.RLock()
defer boundIfaceMu.RUnlock()
primary, secondary := boundIface4, boundIface6
if isV6Network(network) {
primary, secondary = boundIface6, boundIface4
}
if primary != nil {
return primary
}
return secondary
}
func isV6Network(network string) bool {
return strings.HasSuffix(network, "6")
}
// zoneInterface extracts an explicit interface from an IPv6 link-local zone (e.g. fe80::1%en0).
func zoneInterface(address string) *net.Interface {
if address == "" {
return nil
}
addr, err := netip.ParseAddrPort(address)
if err != nil {
a, err := netip.ParseAddr(address)
if err != nil {
return nil
}
addr = netip.AddrPortFrom(a, 0)
}
zone := addr.Addr().Zone()
if zone == "" {
return nil
}
if iface, err := net.InterfaceByName(zone); err == nil {
return iface
}
if idx, err := strconv.Atoi(zone); err == nil {
if iface, err := net.InterfaceByIndex(idx); err == nil {
return iface
}
}
return nil
}
func setIPv4BoundIf(fd uintptr, iface *net.Interface) error {
if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_BOUND_IF, iface.Index); err != nil {
return fmt.Errorf("set IP_BOUND_IF: %w (interface: %s, index: %d)", err, iface.Name, iface.Index)
}
return nil
}
func setIPv6BoundIf(fd uintptr, iface *net.Interface) error {
if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_BOUND_IF, iface.Index); err != nil {
return fmt.Errorf("set IPV6_BOUND_IF: %w (interface: %s, index: %d)", err, iface.Name, iface.Index)
}
return nil
}
// applyBoundIfToSocket binds the socket to the cached physical egress interface
// so scoped route lookup avoids the VPN utun and egresses the underlay directly.
func applyBoundIfToSocket(network, address string, c syscall.RawConn) error {
if !AdvancedRouting() {
return nil
}
iface := boundInterfaceFor(network, address)
if iface == nil {
log.Debugf("no bound iface cached for %s to %s, skipping BOUND_IF", network, address)
return nil
}
isV6 := isV6Network(network)
var controlErr error
if err := c.Control(func(fd uintptr) {
if isV6 {
controlErr = setIPv6BoundIf(fd, iface)
} else {
controlErr = setIPv4BoundIf(fd, iface)
}
if controlErr == nil {
log.Debugf("set BOUND_IF=%d on %s for %s to %s", iface.Index, iface.Name, network, address)
}
}); err != nil {
return fmt.Errorf("control: %w", err)
}
return controlErr
}

View File

@@ -4979,7 +4979,6 @@ type GetFeaturesResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
DisableProfiles bool `protobuf:"varint,1,opt,name=disable_profiles,json=disableProfiles,proto3" json:"disable_profiles,omitempty"`
DisableUpdateSettings bool `protobuf:"varint,2,opt,name=disable_update_settings,json=disableUpdateSettings,proto3" json:"disable_update_settings,omitempty"`
DisableNetworks bool `protobuf:"varint,3,opt,name=disable_networks,json=disableNetworks,proto3" json:"disable_networks,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@@ -5028,13 +5027,6 @@ func (x *GetFeaturesResponse) GetDisableUpdateSettings() bool {
return false
}
func (x *GetFeaturesResponse) GetDisableNetworks() bool {
if x != nil {
return x.DisableNetworks
}
return false
}
type TriggerUpdateRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
unknownFields protoimpl.UnknownFields
@@ -6480,11 +6472,10 @@ const file_daemon_proto_rawDesc = "" +
"\f_profileNameB\v\n" +
"\t_username\"\x10\n" +
"\x0eLogoutResponse\"\x14\n" +
"\x12GetFeaturesRequest\"\xa3\x01\n" +
"\x12GetFeaturesRequest\"x\n" +
"\x13GetFeaturesResponse\x12)\n" +
"\x10disable_profiles\x18\x01 \x01(\bR\x0fdisableProfiles\x126\n" +
"\x17disable_update_settings\x18\x02 \x01(\bR\x15disableUpdateSettings\x12)\n" +
"\x10disable_networks\x18\x03 \x01(\bR\x0fdisableNetworks\"\x16\n" +
"\x17disable_update_settings\x18\x02 \x01(\bR\x15disableUpdateSettings\"\x16\n" +
"\x14TriggerUpdateRequest\"M\n" +
"\x15TriggerUpdateResponse\x12\x18\n" +
"\asuccess\x18\x01 \x01(\bR\asuccess\x12\x1a\n" +

View File

@@ -727,7 +727,6 @@ message GetFeaturesRequest{}
message GetFeaturesResponse{
bool disable_profiles = 1;
bool disable_update_settings = 2;
bool disable_networks = 3;
}
message TriggerUpdateRequest {}

View File

@@ -9,8 +9,6 @@ import (
"strings"
"golang.org/x/exp/maps"
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/route"
@@ -29,10 +27,6 @@ func (s *Server) ListNetworks(context.Context, *proto.ListNetworksRequest) (*pro
s.mutex.Lock()
defer s.mutex.Unlock()
if s.networksDisabled {
return nil, gstatus.Errorf(codes.Unavailable, errNetworksDisabled)
}
if s.connectClient == nil {
return nil, fmt.Errorf("not connected")
}
@@ -124,10 +118,6 @@ func (s *Server) SelectNetworks(_ context.Context, req *proto.SelectNetworksRequ
s.mutex.Lock()
defer s.mutex.Unlock()
if s.networksDisabled {
return nil, gstatus.Errorf(codes.Unavailable, errNetworksDisabled)
}
if s.connectClient == nil {
return nil, fmt.Errorf("not connected")
}
@@ -174,10 +164,6 @@ func (s *Server) DeselectNetworks(_ context.Context, req *proto.SelectNetworksRe
s.mutex.Lock()
defer s.mutex.Unlock()
if s.networksDisabled {
return nil, gstatus.Errorf(codes.Unavailable, errNetworksDisabled)
}
if s.connectClient == nil {
return nil, fmt.Errorf("not connected")
}

View File

@@ -53,7 +53,6 @@ const (
errRestoreResidualState = "failed to restore residual state: %v"
errProfilesDisabled = "profiles are disabled, you cannot use this feature without profiles enabled"
errUpdateSettingsDisabled = "update settings are disabled, you cannot use this feature without update settings enabled"
errNetworksDisabled = "network selection is disabled by the administrator"
)
var ErrServiceNotUp = errors.New("service is not up")
@@ -89,7 +88,6 @@ type Server struct {
profileManager *profilemanager.ServiceManager
profilesDisabled bool
updateSettingsDisabled bool
networksDisabled bool
sleepHandler *sleephandler.SleepHandler
@@ -106,7 +104,7 @@ type oauthAuthFlow struct {
}
// New server instance constructor.
func New(ctx context.Context, logFile string, configFile string, profilesDisabled bool, updateSettingsDisabled bool, networksDisabled bool) *Server {
func New(ctx context.Context, logFile string, configFile string, profilesDisabled bool, updateSettingsDisabled bool) *Server {
s := &Server{
rootCtx: ctx,
logFile: logFile,
@@ -115,7 +113,6 @@ func New(ctx context.Context, logFile string, configFile string, profilesDisable
profileManager: profilemanager.NewServiceManager(configFile),
profilesDisabled: profilesDisabled,
updateSettingsDisabled: updateSettingsDisabled,
networksDisabled: networksDisabled,
jwtCache: newJWTCache(),
}
agent := &serverAgent{s}
@@ -1631,7 +1628,6 @@ func (s *Server) GetFeatures(ctx context.Context, msg *proto.GetFeaturesRequest)
features := &proto.GetFeaturesResponse{
DisableProfiles: s.checkProfilesDisabled(),
DisableUpdateSettings: s.checkUpdateSettingsDisabled(),
DisableNetworks: s.networksDisabled,
}
return features, nil

View File

@@ -36,7 +36,6 @@ import (
daemonProto "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
nbcache "github.com/netbirdio/netbird/management/server/cache"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
@@ -104,7 +103,7 @@ func TestConnectWithRetryRuns(t *testing.T) {
t.Fatalf("failed to set active profile state: %v", err)
}
s := New(ctx, "debug", "", false, false, false)
s := New(ctx, "debug", "", false, false)
s.config = config
@@ -165,7 +164,7 @@ func TestServer_Up(t *testing.T) {
t.Fatalf("failed to set active profile state: %v", err)
}
s := New(ctx, "console", "", false, false, false)
s := New(ctx, "console", "", false, false)
err = s.Start()
require.NoError(t, err)
@@ -235,7 +234,7 @@ func TestServer_SubcribeEvents(t *testing.T) {
t.Fatalf("failed to set active profile state: %v", err)
}
s := New(ctx, "console", "", false, false, false)
s := New(ctx, "console", "", false, false)
err = s.Start()
require.NoError(t, err)
@@ -310,12 +309,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
jobManager := job.NewJobManager(nil, store, peersManager)
cacheStore, err := nbcache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)
if err != nil {
return nil, "", err
}
ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, settingsManagerMock, eventStore, cacheStore)
ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, settingsManagerMock, eventStore)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
@@ -326,7 +320,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
peersUpdateManager := update_channel.NewPeersUpdateManager(metrics)
networkMapController := controller.NewController(context.Background(), store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false, cacheStore)
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
if err != nil {
return nil, "", err
}

View File

@@ -53,7 +53,7 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
require.NoError(t, err)
ctx := context.Background()
s := New(ctx, "console", "", false, false, false)
s := New(ctx, "console", "", false, false)
rosenpassEnabled := true
rosenpassPermissive := true

View File

@@ -12,6 +12,7 @@ import (
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/client/proto"
)
@@ -137,8 +138,10 @@ func restoreResidualState(ctx context.Context, statePath string) error {
}
// clean up any remaining routes independently of the state file
if err := systemops.New(nil, nil).FlushMarkedRoutes(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("flush marked routes: %w", err))
if !nbnet.AdvancedRouting() {
if err := systemops.New(nil, nil).FlushMarkedRoutes(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("flush marked routes: %w", err))
}
}
return nberrors.FormatErrorOrNil(merr)

View File

@@ -25,7 +25,6 @@ import (
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/ssh/detection"
"github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/util/netrelay"
)
const (
@@ -537,7 +536,7 @@ func (c *Client) LocalPortForward(ctx context.Context, localAddr, remoteAddr str
continue
}
go c.handleLocalForward(ctx, localConn, remoteAddr)
go c.handleLocalForward(localConn, remoteAddr)
}
}()
@@ -549,7 +548,7 @@ func (c *Client) LocalPortForward(ctx context.Context, localAddr, remoteAddr str
}
// handleLocalForward handles a single local port forwarding connection
func (c *Client) handleLocalForward(ctx context.Context, localConn net.Conn, remoteAddr string) {
func (c *Client) handleLocalForward(localConn net.Conn, remoteAddr string) {
defer func() {
if err := localConn.Close(); err != nil {
log.Debugf("local port forwarding: close local connection: %v", err)
@@ -572,7 +571,7 @@ func (c *Client) handleLocalForward(ctx context.Context, localConn net.Conn, rem
}
}()
netrelay.Relay(ctx, localConn, channel, netrelay.Options{Logger: log.NewEntry(log.StandardLogger())})
nbssh.BidirectionalCopy(log.NewEntry(log.StandardLogger()), localConn, channel)
}
// RemotePortForward sets up remote port forwarding, binding on remote and forwarding to localAddr
@@ -654,19 +653,16 @@ func (c *Client) handleRemoteForwardChannels(ctx context.Context, localAddr stri
select {
case <-ctx.Done():
return
case newChan, ok := <-channelRequests:
if !ok {
return
}
case newChan := <-channelRequests:
if newChan != nil {
go c.handleRemoteForwardChannel(ctx, newChan, localAddr)
go c.handleRemoteForwardChannel(newChan, localAddr)
}
}
}
}
// handleRemoteForwardChannel handles a single forwarded-tcpip channel
func (c *Client) handleRemoteForwardChannel(ctx context.Context, newChan ssh.NewChannel, localAddr string) {
func (c *Client) handleRemoteForwardChannel(newChan ssh.NewChannel, localAddr string) {
channel, reqs, err := newChan.Accept()
if err != nil {
return
@@ -679,14 +675,8 @@ func (c *Client) handleRemoteForwardChannel(ctx context.Context, newChan ssh.New
go ssh.DiscardRequests(reqs)
// Bound the dial so a black-holed localAddr can't pin the accepted SSH
// channel open indefinitely; the relay itself runs under the outer ctx.
dialCtx, cancelDial := context.WithTimeout(ctx, 10*time.Second)
var dialer net.Dialer
localConn, err := dialer.DialContext(dialCtx, "tcp", localAddr)
cancelDial()
localConn, err := net.Dial("tcp", localAddr)
if err != nil {
log.Debugf("remote port forwarding: dial %s: %v", localAddr, err)
return
}
defer func() {
@@ -695,7 +685,7 @@ func (c *Client) handleRemoteForwardChannel(ctx context.Context, newChan ssh.New
}
}()
netrelay.Relay(ctx, localConn, channel, netrelay.Options{Logger: log.NewEntry(log.StandardLogger())})
nbssh.BidirectionalCopy(log.NewEntry(log.StandardLogger()), localConn, channel)
}
// tcpipForwardMsg represents the structure for tcpip-forward requests

View File

@@ -194,3 +194,63 @@ func buildAddressList(hostname string, remote net.Addr) []string {
return addresses
}
// BidirectionalCopy copies data bidirectionally between two io.ReadWriter connections.
// It waits for both directions to complete before returning.
// The caller is responsible for closing the connections.
func BidirectionalCopy(logger *log.Entry, rw1, rw2 io.ReadWriter) {
done := make(chan struct{}, 2)
go func() {
if _, err := io.Copy(rw2, rw1); err != nil && !isExpectedCopyError(err) {
logger.Debugf("copy error (1->2): %v", err)
}
done <- struct{}{}
}()
go func() {
if _, err := io.Copy(rw1, rw2); err != nil && !isExpectedCopyError(err) {
logger.Debugf("copy error (2->1): %v", err)
}
done <- struct{}{}
}()
<-done
<-done
}
func isExpectedCopyError(err error) bool {
return errors.Is(err, io.EOF) || errors.Is(err, context.Canceled)
}
// BidirectionalCopyWithContext copies data bidirectionally between two io.ReadWriteCloser connections.
// It waits for both directions to complete or for context cancellation before returning.
// Both connections are closed when the function returns.
func BidirectionalCopyWithContext(logger *log.Entry, ctx context.Context, conn1, conn2 io.ReadWriteCloser) {
done := make(chan struct{}, 2)
go func() {
if _, err := io.Copy(conn2, conn1); err != nil && !isExpectedCopyError(err) {
logger.Debugf("copy error (1->2): %v", err)
}
done <- struct{}{}
}()
go func() {
if _, err := io.Copy(conn1, conn2); err != nil && !isExpectedCopyError(err) {
logger.Debugf("copy error (2->1): %v", err)
}
done <- struct{}{}
}()
select {
case <-ctx.Done():
case <-done:
select {
case <-ctx.Done():
case <-done:
}
}
_ = conn1.Close()
_ = conn2.Close()
}

View File

@@ -187,23 +187,24 @@ func (m *Manager) buildPeerConfig(allHostPatterns []string) (string, error) {
return "", fmt.Errorf("get NetBird executable path: %w", err)
}
hostList := strings.Join(deduplicatedPatterns, ",")
config := fmt.Sprintf("Match host \"%s\" exec \"%s ssh detect %%h %%p\"\n", hostList, execPath)
config += " PreferredAuthentications password,publickey,keyboard-interactive\n"
config += " PasswordAuthentication yes\n"
config += " PubkeyAuthentication yes\n"
config += " BatchMode no\n"
config += fmt.Sprintf(" ProxyCommand %s ssh proxy %%h %%p\n", execPath)
config += " StrictHostKeyChecking no\n"
hostLine := strings.Join(deduplicatedPatterns, " ")
config := fmt.Sprintf("Host %s\n", hostLine)
config += fmt.Sprintf(" Match exec \"%s ssh detect %%h %%p\"\n", execPath)
config += " PreferredAuthentications password,publickey,keyboard-interactive\n"
config += " PasswordAuthentication yes\n"
config += " PubkeyAuthentication yes\n"
config += " BatchMode no\n"
config += fmt.Sprintf(" ProxyCommand %s ssh proxy %%h %%p\n", execPath)
config += " StrictHostKeyChecking no\n"
if runtime.GOOS == "windows" {
config += " UserKnownHostsFile NUL\n"
config += " UserKnownHostsFile NUL\n"
} else {
config += " UserKnownHostsFile /dev/null\n"
config += " UserKnownHostsFile /dev/null\n"
}
config += " CheckHostIP no\n"
config += " LogLevel ERROR\n\n"
config += " CheckHostIP no\n"
config += " LogLevel ERROR\n\n"
return config, nil
}

View File

@@ -116,37 +116,6 @@ func TestManager_PeerLimit(t *testing.T) {
assert.True(t, os.IsNotExist(err), "SSH config should not be created with too many peers")
}
func TestManager_MatchHostFormat(t *testing.T) {
tempDir, err := os.MkdirTemp("", "netbird-ssh-config-test")
require.NoError(t, err)
defer func() { assert.NoError(t, os.RemoveAll(tempDir)) }()
manager := &Manager{
sshConfigDir: filepath.Join(tempDir, "ssh_config.d"),
sshConfigFile: "99-netbird.conf",
}
peers := []PeerSSHInfo{
{Hostname: "peer1", IP: "100.125.1.1", FQDN: "peer1.nb.internal"},
{Hostname: "peer2", IP: "100.125.1.2", FQDN: "peer2.nb.internal"},
}
err = manager.SetupSSHClientConfig(peers)
require.NoError(t, err)
configPath := filepath.Join(manager.sshConfigDir, manager.sshConfigFile)
content, err := os.ReadFile(configPath)
require.NoError(t, err)
configStr := string(content)
// Must use "Match host" with comma-separated patterns, not a bare "Host" directive.
// A bare "Host" followed by "Match exec" is incorrect per ssh_config(5): the Host block
// ends at the next Match keyword, making it a no-op and leaving the Match exec unscoped.
assert.NotContains(t, configStr, "\nHost ", "should not use bare Host directive")
assert.Contains(t, configStr, "Match host \"100.125.1.1,peer1.nb.internal,peer1,100.125.1.2,peer2.nb.internal,peer2\"",
"should use Match host with comma-separated patterns")
}
func TestManager_ForcedSSHConfig(t *testing.T) {
// Set force environment variable
t.Setenv(EnvForceSSHConfig, "true")

View File

@@ -23,7 +23,6 @@ import (
"github.com/netbirdio/netbird/client/proto"
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/ssh/detection"
"github.com/netbirdio/netbird/util/netrelay"
"github.com/netbirdio/netbird/version"
)
@@ -353,7 +352,7 @@ func (p *SSHProxy) directTCPIPHandler(_ *ssh.Server, _ *cryptossh.ServerConn, ne
}
go cryptossh.DiscardRequests(clientReqs)
netrelay.Relay(sshCtx, clientChan, backendChan, netrelay.Options{Logger: log.NewEntry(log.StandardLogger())})
nbssh.BidirectionalCopyWithContext(log.NewEntry(log.StandardLogger()), sshCtx, clientChan, backendChan)
}
func (p *SSHProxy) sftpSubsystemHandler(s ssh.Session, jwtToken string) {
@@ -592,7 +591,7 @@ func (p *SSHProxy) handleForwardedChannel(sshCtx ssh.Context, sshConn *cryptossh
}
go cryptossh.DiscardRequests(clientReqs)
netrelay.Relay(sshCtx, clientChan, backendChan, netrelay.Options{Logger: log.NewEntry(log.StandardLogger())})
nbssh.BidirectionalCopyWithContext(log.NewEntry(log.StandardLogger()), sshCtx, clientChan, backendChan)
}
func (p *SSHProxy) dialBackend(ctx context.Context, addr, user, jwtToken string) (*cryptossh.Client, error) {

View File

@@ -17,7 +17,7 @@ import (
log "github.com/sirupsen/logrus"
cryptossh "golang.org/x/crypto/ssh"
"github.com/netbirdio/netbird/util/netrelay"
nbssh "github.com/netbirdio/netbird/client/ssh"
)
const privilegedPortThreshold = 1024
@@ -356,7 +356,7 @@ func (s *Server) handleRemoteForwardConnection(ctx ssh.Context, conn net.Conn, h
return
}
netrelay.Relay(ctx, conn, channel, netrelay.Options{Logger: logger})
nbssh.BidirectionalCopyWithContext(logger, ctx, conn, channel)
}
// openForwardChannel creates an SSH forwarded-tcpip channel

View File

@@ -10,7 +10,6 @@ import (
"net"
"net/netip"
"slices"
"strconv"
"strings"
"sync"
"time"
@@ -27,7 +26,6 @@ import (
"github.com/netbirdio/netbird/client/ssh/detection"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/auth/jwt"
"github.com/netbirdio/netbird/util/netrelay"
"github.com/netbirdio/netbird/version"
)
@@ -54,10 +52,6 @@ const (
DefaultJWTMaxTokenAge = 10 * 60
)
// directTCPIPDialTimeout bounds how long relayDirectTCPIP waits on a dial to
// the forwarded destination before rejecting the SSH channel.
const directTCPIPDialTimeout = 30 * time.Second
var (
ErrPrivilegedUserDisabled = errors.New(msgPrivilegedUserDisabled)
ErrUserNotFound = errors.New("user not found")
@@ -897,29 +891,5 @@ func (s *Server) directTCPIPHandler(srv *ssh.Server, conn *cryptossh.ServerConn,
s.addConnectionPortForward(ctx.User(), ctx.RemoteAddr(), forwardAddr)
logger.Infof("local port forwarding: %s:%d", payload.Host, payload.Port)
s.relayDirectTCPIP(ctx, newChan, payload.Host, int(payload.Port), logger)
}
// relayDirectTCPIP is a netrelay-based replacement for gliderlabs'
// DirectTCPIPHandler. The upstream handler closes both sides on the first
// EOF; netrelay.Relay propagates CloseWrite so each direction drains on its
// own terms.
func (s *Server) relayDirectTCPIP(ctx ssh.Context, newChan cryptossh.NewChannel, host string, port int, logger *log.Entry) {
dest := net.JoinHostPort(host, strconv.Itoa(port))
dialer := net.Dialer{Timeout: directTCPIPDialTimeout}
dconn, err := dialer.DialContext(ctx, "tcp", dest)
if err != nil {
_ = newChan.Reject(cryptossh.ConnectionFailed, err.Error())
return
}
ch, reqs, err := newChan.Accept()
if err != nil {
_ = dconn.Close()
return
}
go cryptossh.DiscardRequests(reqs)
netrelay.Relay(ctx, dconn, ch, netrelay.Options{Logger: logger})
ssh.DirectTCPIPHandler(srv, conn, newChan, ctx)
}

View File

@@ -2,6 +2,7 @@ package system
import (
"context"
"net"
"net/netip"
"strings"
@@ -144,6 +145,59 @@ func extractDeviceName(ctx context.Context, defaultName string) string {
return v
}
func networkAddresses() ([]NetworkAddress, error) {
interfaces, err := net.Interfaces()
if err != nil {
return nil, err
}
var netAddresses []NetworkAddress
for _, iface := range interfaces {
if iface.Flags&net.FlagUp == 0 {
continue
}
if iface.HardwareAddr.String() == "" {
continue
}
addrs, err := iface.Addrs()
if err != nil {
continue
}
for _, address := range addrs {
ipNet, ok := address.(*net.IPNet)
if !ok {
continue
}
if ipNet.IP.IsLoopback() {
continue
}
netAddr := NetworkAddress{
NetIP: netip.MustParsePrefix(ipNet.String()),
Mac: iface.HardwareAddr.String(),
}
if isDuplicated(netAddresses, netAddr) {
continue
}
netAddresses = append(netAddresses, netAddr)
}
}
return netAddresses, nil
}
func isDuplicated(addresses []NetworkAddress, addr NetworkAddress) bool {
for _, duplicated := range addresses {
if duplicated.NetIP == addr.NetIP {
return true
}
}
return false
}
// GetInfoWithChecks retrieves and parses the system information with applied checks.
func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks) (*Info, error) {
log.Debugf("gathering system information with checks: %d", len(checks))

View File

@@ -2,16 +2,12 @@ package system
import (
"context"
"net"
"net/netip"
"runtime"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/version"
)
// UpdateStaticInfoAsync is a no-op on iOS as there is no static info to update
// UpdateStaticInfoAsync is a no-op on Android as there is no static info to update
func UpdateStaticInfoAsync() {
// do nothing
}
@@ -19,24 +15,11 @@ func UpdateStaticInfoAsync() {
// GetInfo retrieves and parses the system information
func GetInfo(ctx context.Context) *Info {
// Convert fixed-size byte arrays to Go strings
sysName := extractOsName(ctx, "sysName")
swVersion := extractOsVersion(ctx, "swVersion")
addrs, err := networkAddresses()
if err != nil {
log.Warnf("failed to discover network addresses: %s", err)
}
gio := &Info{
Kernel: sysName,
OSVersion: swVersion,
Platform: "unknown",
OS: sysName,
GoOS: runtime.GOOS,
CPUs: runtime.NumCPU(),
KernelVersion: swVersion,
NetworkAddresses: addrs,
}
gio := &Info{Kernel: sysName, OSVersion: swVersion, Platform: "unknown", OS: sysName, GoOS: runtime.GOOS, CPUs: runtime.NumCPU(), KernelVersion: swVersion}
gio.Hostname = extractDeviceName(ctx, "hostname")
gio.NetbirdVersion = version.NetbirdVersion()
gio.UIVersion = extractUserAgent(ctx)
@@ -44,66 +27,6 @@ func GetInfo(ctx context.Context) *Info {
return gio
}
// networkAddresses returns the list of network addresses on iOS.
// On iOS, hardware (MAC) addresses are not available due to Apple's privacy
// restrictions (iOS returns a fixed 02:00:00:00:00:00 placeholder), so we
// leave Mac empty to match Android's behavior. We also skip the HardwareAddr
// check that other platforms use and filter out link-local addresses as they
// are not useful for posture checks.
func networkAddresses() ([]NetworkAddress, error) {
interfaces, err := net.Interfaces()
if err != nil {
return nil, err
}
var netAddresses []NetworkAddress
for _, iface := range interfaces {
if iface.Flags&net.FlagUp == 0 {
continue
}
addrs, err := iface.Addrs()
if err != nil {
continue
}
for _, address := range addrs {
netAddr, ok := toNetworkAddress(address)
if !ok {
continue
}
if isDuplicated(netAddresses, netAddr) {
continue
}
netAddresses = append(netAddresses, netAddr)
}
}
return netAddresses, nil
}
func toNetworkAddress(address net.Addr) (NetworkAddress, bool) {
ipNet, ok := address.(*net.IPNet)
if !ok {
return NetworkAddress{}, false
}
if ipNet.IP.IsLoopback() || ipNet.IP.IsLinkLocalUnicast() || ipNet.IP.IsMulticast() {
return NetworkAddress{}, false
}
prefix, err := netip.ParsePrefix(ipNet.String())
if err != nil {
return NetworkAddress{}, false
}
return NetworkAddress{NetIP: prefix, Mac: ""}, true
}
func isDuplicated(addresses []NetworkAddress, addr NetworkAddress) bool {
for _, duplicated := range addresses {
if duplicated.NetIP == addr.NetIP {
return true
}
}
return false
}
// checkFileAndProcess checks if the file path exists and if a process is running at that path.
func checkFileAndProcess(paths []string) ([]File, error) {
return []File{}, nil

View File

@@ -1,66 +0,0 @@
//go:build !ios
package system
import (
"net"
"net/netip"
)
func networkAddresses() ([]NetworkAddress, error) {
interfaces, err := net.Interfaces()
if err != nil {
return nil, err
}
var netAddresses []NetworkAddress
for _, iface := range interfaces {
if iface.Flags&net.FlagUp == 0 {
continue
}
if iface.HardwareAddr.String() == "" {
continue
}
addrs, err := iface.Addrs()
if err != nil {
continue
}
mac := iface.HardwareAddr.String()
for _, address := range addrs {
netAddr, ok := toNetworkAddress(address, mac)
if !ok {
continue
}
if isDuplicated(netAddresses, netAddr) {
continue
}
netAddresses = append(netAddresses, netAddr)
}
}
return netAddresses, nil
}
func toNetworkAddress(address net.Addr, mac string) (NetworkAddress, bool) {
ipNet, ok := address.(*net.IPNet)
if !ok {
return NetworkAddress{}, false
}
if ipNet.IP.IsLoopback() {
return NetworkAddress{}, false
}
prefix, err := netip.ParsePrefix(ipNet.String())
if err != nil {
return NetworkAddress{}, false
}
return NetworkAddress{NetIP: prefix, Mac: mac}, true
}
func isDuplicated(addresses []NetworkAddress, addr NetworkAddress) bool {
for _, duplicated := range addresses {
if duplicated.NetIP == addr.NetIP {
return true
}
}
return false
}

View File

@@ -314,7 +314,6 @@ type serviceClient struct {
lastNotifiedVersion string
settingsEnabled bool
profilesEnabled bool
networksEnabled bool
showNetworks bool
wNetworks fyne.Window
wProfiles fyne.Window
@@ -369,7 +368,6 @@ func newServiceClient(args *newServiceClientArgs) *serviceClient {
showAdvancedSettings: args.showSettings,
showNetworks: args.showNetworks,
networksEnabled: true,
}
s.eventHandler = newEventHandler(s)
@@ -922,10 +920,8 @@ func (s *serviceClient) updateStatus() error {
s.mStatus.SetIcon(s.icConnectedDot)
s.mUp.Disable()
s.mDown.Enable()
if s.networksEnabled {
s.mNetworks.Enable()
s.mExitNode.Enable()
}
s.mNetworks.Enable()
s.mExitNode.Enable()
s.startExitNodeRefresh()
systrayIconState = true
case status.Status == string(internal.StatusConnecting):
@@ -1097,14 +1093,14 @@ func (s *serviceClient) onTrayReady() {
s.getSrvConfig()
time.Sleep(100 * time.Millisecond) // To prevent race condition caused by systray not being fully initialized and ignoring setIcon
for {
// Check features before status so menus respect disable flags before being enabled
s.checkAndUpdateFeatures()
err := s.updateStatus()
if err != nil {
log.Errorf("error while updating status: %v", err)
}
// Check features periodically to handle daemon restarts
s.checkAndUpdateFeatures()
time.Sleep(2 * time.Second)
}
}()
@@ -1303,16 +1299,6 @@ func (s *serviceClient) checkAndUpdateFeatures() {
s.mProfile.setEnabled(profilesEnabled)
}
}
// Update networks and exit node menus based on current features
s.networksEnabled = features == nil || !features.DisableNetworks
if s.networksEnabled && s.connected {
s.mNetworks.Enable()
s.mExitNode.Enable()
} else {
s.mNetworks.Disable()
s.mExitNode.Disable()
}
}
// getFeatures from the daemon to determine which features are enabled/disabled.

View File

@@ -179,11 +179,9 @@ type StoreConfig struct {
// ReverseProxyConfig contains reverse proxy settings
type ReverseProxyConfig struct {
TrustedHTTPProxies []string `yaml:"trustedHTTPProxies"`
TrustedHTTPProxiesCount uint `yaml:"trustedHTTPProxiesCount"`
TrustedPeers []string `yaml:"trustedPeers"`
AccessLogRetentionDays int `yaml:"accessLogRetentionDays"`
AccessLogCleanupIntervalHours int `yaml:"accessLogCleanupIntervalHours"`
TrustedHTTPProxies []string `yaml:"trustedHTTPProxies"`
TrustedHTTPProxiesCount uint `yaml:"trustedHTTPProxiesCount"`
TrustedPeers []string `yaml:"trustedPeers"`
}
// DefaultConfig returns a CombinedConfig with default values
@@ -647,9 +645,7 @@ func (c *CombinedConfig) ToManagementConfig() (*nbconfig.Config, error) {
// Build reverse proxy config
reverseProxy := nbconfig.ReverseProxy{
TrustedHTTPProxiesCount: mgmt.ReverseProxy.TrustedHTTPProxiesCount,
AccessLogRetentionDays: mgmt.ReverseProxy.AccessLogRetentionDays,
AccessLogCleanupIntervalHours: mgmt.ReverseProxy.AccessLogCleanupIntervalHours,
TrustedHTTPProxiesCount: mgmt.ReverseProxy.TrustedHTTPProxiesCount,
}
for _, p := range mgmt.ReverseProxy.TrustedHTTPProxies {
if prefix, err := netip.ParsePrefix(p); err == nil {

View File

@@ -14,6 +14,7 @@ import (
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/keepalive"
@@ -25,22 +26,11 @@ import (
"github.com/netbirdio/netbird/util/wsproxy"
)
var ErrClientClosed = errors.New("client is closed")
// minHealthyDuration is the minimum time a stream must survive before a failure
// resets the backoff timer. Streams that fail faster are considered unhealthy and
// should not reset backoff, so that MaxElapsedTime can eventually stop retries.
const minHealthyDuration = 5 * time.Second
type GRPCClient struct {
realClient proto.FlowServiceClient
clientConn *grpc.ClientConn
stream proto.FlowService_EventsClient
target string
opts []grpc.DialOption
closed bool // prevent creating conn in the middle of the Close
receiving bool // prevent concurrent Receive calls
mu sync.Mutex // protects clientConn, realClient, stream, closed, and receiving
streamMu sync.Mutex
}
func NewClient(addr, payload, signature string, interval time.Duration) (*GRPCClient, error) {
@@ -75,8 +65,7 @@ func NewClient(addr, payload, signature string, interval time.Duration) (*GRPCCl
grpc.WithDefaultServiceConfig(`{"healthCheckConfig": {"serviceName": ""}}`),
)
target := parsedURL.Host
conn, err := grpc.NewClient(target, opts...)
conn, err := grpc.NewClient(fmt.Sprintf("%s:%s", parsedURL.Hostname(), parsedURL.Port()), opts...)
if err != nil {
return nil, fmt.Errorf("creating new grpc client: %w", err)
}
@@ -84,73 +73,30 @@ func NewClient(addr, payload, signature string, interval time.Duration) (*GRPCCl
return &GRPCClient{
realClient: proto.NewFlowServiceClient(conn),
clientConn: conn,
target: target,
opts: opts,
}, nil
}
func (c *GRPCClient) Close() error {
c.mu.Lock()
c.closed = true
c.streamMu.Lock()
defer c.streamMu.Unlock()
c.stream = nil
conn := c.clientConn
c.clientConn = nil
c.mu.Unlock()
if conn == nil {
return nil
}
if err := conn.Close(); err != nil && !errors.Is(err, context.Canceled) {
if err := c.clientConn.Close(); err != nil && !errors.Is(err, context.Canceled) {
return fmt.Errorf("close client connection: %w", err)
}
return nil
}
func (c *GRPCClient) Send(event *proto.FlowEvent) error {
c.mu.Lock()
stream := c.stream
c.mu.Unlock()
if stream == nil {
return errors.New("stream not initialized")
}
if err := stream.Send(event); err != nil {
return fmt.Errorf("send flow event: %w", err)
}
return nil
}
func (c *GRPCClient) Receive(ctx context.Context, interval time.Duration, msgHandler func(msg *proto.FlowEventAck) error) error {
c.mu.Lock()
if c.receiving {
c.mu.Unlock()
return errors.New("concurrent Receive calls are not supported")
}
c.receiving = true
c.mu.Unlock()
defer func() {
c.mu.Lock()
c.receiving = false
c.mu.Unlock()
}()
backOff := defaultBackoff(ctx, interval)
operation := func() error {
stream, err := c.establishStream(ctx)
if err != nil {
log.Errorf("failed to establish flow stream, retrying: %v", err)
return c.handleRetryableError(err, time.Time{}, backOff)
}
streamStart := time.Now()
if err := c.receive(stream, msgHandler); err != nil {
if err := c.establishStreamAndReceive(ctx, msgHandler); err != nil {
if s, ok := status.FromError(err); ok && s.Code() == codes.Canceled {
return fmt.Errorf("receive: %w: %w", err, context.Canceled)
}
log.Errorf("receive failed: %v", err)
return c.handleRetryableError(err, streamStart, backOff)
return fmt.Errorf("receive: %w", err)
}
return nil
}
@@ -162,106 +108,37 @@ func (c *GRPCClient) Receive(ctx context.Context, interval time.Duration, msgHan
return nil
}
// handleRetryableError resets the backoff timer if the stream was healthy long
// enough and recreates the underlying ClientConn so that gRPC's internal
// subchannel backoff does not accumulate and compete with our own retry timer.
// A zero streamStart means the stream was never established.
func (c *GRPCClient) handleRetryableError(err error, streamStart time.Time, backOff backoff.BackOff) error {
if isContextDone(err) {
return backoff.Permanent(err)
func (c *GRPCClient) establishStreamAndReceive(ctx context.Context, msgHandler func(msg *proto.FlowEventAck) error) error {
if c.clientConn.GetState() == connectivity.Shutdown {
return errors.New("connection to flow receiver has been shut down")
}
var permErr *backoff.PermanentError
if errors.As(err, &permErr) {
return err
}
// Reset the backoff so the next retry starts with a short delay instead of
// continuing the already-elapsed timer. Only do this if the stream was healthy
// long enough; short-lived connect/drop cycles must not defeat MaxElapsedTime.
if !streamStart.IsZero() && time.Since(streamStart) >= minHealthyDuration {
backOff.Reset()
}
if recreateErr := c.recreateConnection(); recreateErr != nil {
log.Errorf("recreate connection: %v", recreateErr)
return recreateErr
}
log.Infof("connection recreated, retrying stream")
return fmt.Errorf("retrying after error: %w", err)
}
func (c *GRPCClient) recreateConnection() error {
c.mu.Lock()
if c.closed {
c.mu.Unlock()
return backoff.Permanent(ErrClientClosed)
}
conn, err := grpc.NewClient(c.target, c.opts...)
stream, err := c.realClient.Events(ctx, grpc.WaitForReady(true))
if err != nil {
c.mu.Unlock()
return fmt.Errorf("create new connection: %w", err)
return fmt.Errorf("create event stream: %w", err)
}
old := c.clientConn
c.clientConn = conn
c.realClient = proto.NewFlowServiceClient(conn)
c.stream = nil
c.mu.Unlock()
_ = old.Close()
return nil
}
func (c *GRPCClient) establishStream(ctx context.Context) (proto.FlowService_EventsClient, error) {
c.mu.Lock()
if c.closed {
c.mu.Unlock()
return nil, backoff.Permanent(ErrClientClosed)
}
cl := c.realClient
c.mu.Unlock()
// open stream outside the lock — blocking operation
stream, err := cl.Events(ctx)
err = stream.Send(&proto.FlowEvent{IsInitiator: true})
if err != nil {
return nil, fmt.Errorf("create event stream: %w", err)
}
streamReady := false
defer func() {
if !streamReady {
_ = stream.CloseSend()
}
}()
if err = stream.Send(&proto.FlowEvent{IsInitiator: true}); err != nil {
return nil, fmt.Errorf("send initiator: %w", err)
log.Infof("failed to send initiator message to flow receiver but will attempt to continue. Error: %s", err)
}
if err = checkHeader(stream); err != nil {
return nil, fmt.Errorf("check header: %w", err)
return fmt.Errorf("check header: %w", err)
}
c.mu.Lock()
if c.closed {
c.mu.Unlock()
return nil, backoff.Permanent(ErrClientClosed)
}
c.streamMu.Lock()
c.stream = stream
c.mu.Unlock()
streamReady = true
c.streamMu.Unlock()
return stream, nil
return c.receive(stream, msgHandler)
}
func (c *GRPCClient) receive(stream proto.FlowService_EventsClient, msgHandler func(msg *proto.FlowEventAck) error) error {
for {
msg, err := stream.Recv()
if err != nil {
return err
return fmt.Errorf("receive from stream: %w", err)
}
if msg.IsInitiator {
@@ -292,7 +169,7 @@ func checkHeader(stream proto.FlowService_EventsClient) error {
func defaultBackoff(ctx context.Context, interval time.Duration) backoff.BackOff {
return backoff.WithContext(&backoff.ExponentialBackOff{
InitialInterval: 800 * time.Millisecond,
RandomizationFactor: 0.5,
RandomizationFactor: 1,
Multiplier: 1.7,
MaxInterval: interval / 2,
MaxElapsedTime: 3 * 30 * 24 * time.Hour, // 3 months
@@ -301,12 +178,18 @@ func defaultBackoff(ctx context.Context, interval time.Duration) backoff.BackOff
}, ctx)
}
func isContextDone(err error) bool {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return true
func (c *GRPCClient) Send(event *proto.FlowEvent) error {
c.streamMu.Lock()
stream := c.stream
c.streamMu.Unlock()
if stream == nil {
return errors.New("stream not initialized")
}
if s, ok := status.FromError(err); ok {
return s.Code() == codes.Canceled || s.Code() == codes.DeadlineExceeded
if err := stream.Send(event); err != nil {
return fmt.Errorf("send flow event: %w", err)
}
return false
return nil
}

View File

@@ -2,11 +2,8 @@ package client_test
import (
"context"
"encoding/binary"
"errors"
"net"
"sync"
"sync/atomic"
"testing"
"time"
@@ -14,8 +11,6 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
flow "github.com/netbirdio/netbird/flow/client"
"github.com/netbirdio/netbird/flow/proto"
@@ -23,89 +18,21 @@ import (
type testServer struct {
proto.UnimplementedFlowServiceServer
events chan *proto.FlowEvent
acks chan *proto.FlowEventAck
grpcSrv *grpc.Server
addr string
listener *connTrackListener
closeStream chan struct{} // signal server to close the stream
handlerDone chan struct{} // signaled each time Events() exits
handlerStarted chan struct{} // signaled each time Events() begins
}
// connTrackListener wraps a net.Listener to track accepted connections
// so tests can forcefully close them to simulate PROTOCOL_ERROR/RST_STREAM.
type connTrackListener struct {
net.Listener
mu sync.Mutex
conns []net.Conn
}
func (l *connTrackListener) Accept() (net.Conn, error) {
c, err := l.Listener.Accept()
if err != nil {
return nil, err
}
l.mu.Lock()
l.conns = append(l.conns, c)
l.mu.Unlock()
return c, nil
}
// sendRSTStream writes a raw HTTP/2 RST_STREAM frame with PROTOCOL_ERROR
// (error code 0x1) on every tracked connection. This produces the exact error:
//
// rpc error: code = Internal desc = stream terminated by RST_STREAM with error code: PROTOCOL_ERROR
//
// HTTP/2 RST_STREAM frame format (9-byte header + 4-byte payload):
//
// Length (3 bytes): 0x000004
// Type (1 byte): 0x03 (RST_STREAM)
// Flags (1 byte): 0x00
// Stream ID (4 bytes): target stream (must have bit 31 clear)
// Error Code (4 bytes): 0x00000001 (PROTOCOL_ERROR)
func (l *connTrackListener) connCount() int {
l.mu.Lock()
defer l.mu.Unlock()
return len(l.conns)
}
func (l *connTrackListener) sendRSTStream(streamID uint32) {
l.mu.Lock()
defer l.mu.Unlock()
frame := make([]byte, 13) // 9-byte header + 4-byte payload
// Length = 4 (3 bytes, big-endian)
frame[0], frame[1], frame[2] = 0, 0, 4
// Type = RST_STREAM (0x03)
frame[3] = 0x03
// Flags = 0
frame[4] = 0x00
// Stream ID (4 bytes, big-endian, bit 31 reserved = 0)
binary.BigEndian.PutUint32(frame[5:9], streamID)
// Error Code = PROTOCOL_ERROR (0x1)
binary.BigEndian.PutUint32(frame[9:13], 0x1)
for _, c := range l.conns {
_, _ = c.Write(frame)
}
events chan *proto.FlowEvent
acks chan *proto.FlowEventAck
grpcSrv *grpc.Server
addr string
}
func newTestServer(t *testing.T) *testServer {
rawListener, err := net.Listen("tcp", "127.0.0.1:0")
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
listener := &connTrackListener{Listener: rawListener}
s := &testServer{
events: make(chan *proto.FlowEvent, 100),
acks: make(chan *proto.FlowEventAck, 100),
grpcSrv: grpc.NewServer(),
addr: rawListener.Addr().String(),
listener: listener,
closeStream: make(chan struct{}, 1),
handlerDone: make(chan struct{}, 10),
handlerStarted: make(chan struct{}, 10),
events: make(chan *proto.FlowEvent, 100),
acks: make(chan *proto.FlowEventAck, 100),
grpcSrv: grpc.NewServer(),
addr: listener.Addr().String(),
}
proto.RegisterFlowServiceServer(s.grpcSrv, s)
@@ -124,23 +51,11 @@ func newTestServer(t *testing.T) *testServer {
}
func (s *testServer) Events(stream proto.FlowService_EventsServer) error {
defer func() {
select {
case s.handlerDone <- struct{}{}:
default:
}
}()
err := stream.Send(&proto.FlowEventAck{IsInitiator: true})
if err != nil {
return err
}
select {
case s.handlerStarted <- struct{}{}:
default:
}
ctx, cancel := context.WithCancel(stream.Context())
defer cancel()
@@ -176,8 +91,6 @@ func (s *testServer) Events(stream proto.FlowService_EventsServer) error {
if err := stream.Send(ack); err != nil {
return err
}
case <-s.closeStream:
return status.Errorf(codes.Internal, "server closing stream")
case <-ctx.Done():
return ctx.Err()
}
@@ -197,13 +110,16 @@ func TestReceive(t *testing.T) {
assert.NoError(t, err, "failed to close flow")
})
var ackCount atomic.Int32
receivedAcks := make(map[string]bool)
receiveDone := make(chan struct{})
go func() {
err := client.Receive(ctx, 1*time.Second, func(msg *proto.FlowEventAck) error {
if !msg.IsInitiator && len(msg.EventId) > 0 {
if ackCount.Add(1) >= 3 {
id := string(msg.EventId)
receivedAcks[id] = true
if len(receivedAcks) >= 3 {
close(receiveDone)
}
}
@@ -214,11 +130,7 @@ func TestReceive(t *testing.T) {
}
}()
select {
case <-server.handlerStarted:
case <-time.After(3 * time.Second):
t.Fatal("timeout waiting for stream to be established")
}
time.Sleep(500 * time.Millisecond)
for i := 0; i < 3; i++ {
eventID := uuid.New().String()
@@ -241,7 +153,7 @@ func TestReceive(t *testing.T) {
t.Fatal("timeout waiting for acks to be processed")
}
assert.Equal(t, int32(3), ackCount.Load())
assert.Equal(t, 3, len(receivedAcks))
}
func TestReceive_ContextCancellation(t *testing.T) {
@@ -342,195 +254,3 @@ func TestSend(t *testing.T) {
t.Fatal("timeout waiting for ack to be received by flow")
}
}
func TestNewClient_PermanentClose(t *testing.T) {
server := newTestServer(t)
client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second)
require.NoError(t, err)
err = client.Close()
require.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
t.Cleanup(cancel)
done := make(chan error, 1)
go func() {
done <- client.Receive(ctx, 1*time.Second, func(msg *proto.FlowEventAck) error {
return nil
})
}()
select {
case err := <-done:
require.ErrorIs(t, err, flow.ErrClientClosed)
case <-time.After(2 * time.Second):
t.Fatal("Receive did not return after Close — stuck in retry loop")
}
}
func TestNewClient_CloseVerify(t *testing.T) {
server := newTestServer(t)
client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second)
require.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
t.Cleanup(cancel)
done := make(chan error, 1)
go func() {
done <- client.Receive(ctx, 1*time.Second, func(msg *proto.FlowEventAck) error {
return nil
})
}()
closeDone := make(chan struct{}, 1)
go func() {
_ = client.Close()
closeDone <- struct{}{}
}()
select {
case err := <-done:
require.Error(t, err)
case <-time.After(2 * time.Second):
t.Fatal("Receive did not return after Close — stuck in retry loop")
}
select {
case <-closeDone:
return
case <-time.After(2 * time.Second):
t.Fatal("Close did not return — blocked in retry loop")
}
}
func TestClose_WhileReceiving(t *testing.T) {
server := newTestServer(t)
client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second)
require.NoError(t, err)
ctx := context.Background() // no timeout — intentional
receiveDone := make(chan struct{})
go func() {
_ = client.Receive(ctx, 1*time.Second, func(msg *proto.FlowEventAck) error {
return nil
})
close(receiveDone)
}()
// Wait for the server-side handler to confirm the stream is established.
select {
case <-server.handlerStarted:
case <-time.After(3 * time.Second):
t.Fatal("timeout waiting for stream to be established")
}
closeDone := make(chan struct{})
go func() {
_ = client.Close()
close(closeDone)
}()
select {
case <-closeDone:
// Close returned — good
case <-time.After(2 * time.Second):
t.Fatal("Close blocked forever — Receive stuck in retry loop")
}
select {
case <-receiveDone:
case <-time.After(2 * time.Second):
t.Fatal("Receive did not exit after Close")
}
}
func TestReceive_ProtocolErrorStreamReconnect(t *testing.T) {
server := newTestServer(t)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
t.Cleanup(cancel)
client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second)
require.NoError(t, err)
t.Cleanup(func() {
err := client.Close()
assert.NoError(t, err, "failed to close flow")
})
// Track acks received before and after server-side stream close
var ackCount atomic.Int32
receivedFirst := make(chan struct{})
receivedAfterReconnect := make(chan struct{})
go func() {
err := client.Receive(ctx, 1*time.Second, func(msg *proto.FlowEventAck) error {
if msg.IsInitiator || len(msg.EventId) == 0 {
return nil
}
n := ackCount.Add(1)
if n == 1 {
close(receivedFirst)
}
if n == 2 {
close(receivedAfterReconnect)
}
return nil
})
if err != nil && !errors.Is(err, context.Canceled) {
t.Logf("receive error: %v", err)
}
}()
// Wait for stream to be established, then send first ack
select {
case <-server.handlerStarted:
case <-time.After(3 * time.Second):
t.Fatal("timeout waiting for stream to be established")
}
server.acks <- &proto.FlowEventAck{EventId: []byte("before-close")}
select {
case <-receivedFirst:
case <-time.After(3 * time.Second):
t.Fatal("timeout waiting for first ack")
}
// Snapshot connection count before injecting the fault.
connsBefore := server.listener.connCount()
// Send a raw HTTP/2 RST_STREAM frame with PROTOCOL_ERROR on the TCP connection.
// gRPC multiplexes streams on stream IDs 1, 3, 5, ... (odd, client-initiated).
// Stream ID 1 is the client's first stream (our Events bidi stream).
// This produces the exact error the client sees in production:
// "stream terminated by RST_STREAM with error code: PROTOCOL_ERROR"
server.listener.sendRSTStream(1)
// Wait for the old Events() handler to fully exit so it can no longer
// drain s.acks and drop our injected ack on a broken stream.
select {
case <-server.handlerDone:
case <-time.After(5 * time.Second):
t.Fatal("old Events() handler did not exit after RST_STREAM")
}
require.Eventually(t, func() bool {
return server.listener.connCount() > connsBefore
}, 5*time.Second, 50*time.Millisecond, "client did not open a new TCP connection after RST_STREAM")
server.acks <- &proto.FlowEventAck{EventId: []byte("after-close")}
select {
case <-receivedAfterReconnect:
// Client successfully reconnected and received ack after server-side stream close
case <-time.After(5 * time.Second):
t.Fatal("timeout waiting for ack after server-side stream close — client did not reconnect")
}
assert.GreaterOrEqual(t, int(ackCount.Load()), 2, "should have received acks before and after stream close")
assert.GreaterOrEqual(t, server.listener.connCount(), 2, "client should have created at least 2 TCP connections (original + reconnect)")
}

Some files were not shown because too many files have changed in this diff Show More