mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 07:26:14 -04:00
Compare commits
46 Commits
refactor/f
...
macos-ip-b
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3c185b89e1 | ||
|
|
0a928f1dc4 | ||
|
|
8ae8f2098f | ||
|
|
a39787d679 | ||
|
|
53b04e512a | ||
|
|
633dde8d1f | ||
|
|
7e4542adde | ||
|
|
d4c61ed38b | ||
|
|
6b540d145c | ||
|
|
08f624507d | ||
|
|
95bc01e48f | ||
|
|
0d86de47df | ||
|
|
e804a705b7 | ||
|
|
46fc8c9f65 | ||
|
|
d7ad908962 | ||
|
|
c5623307cc | ||
|
|
7f666b8022 | ||
|
|
0a30b9b275 | ||
|
|
4eed459f27 | ||
|
|
13539543af | ||
|
|
7483fec048 | ||
|
|
5259e5df51 | ||
|
|
ebd78e0122 | ||
|
|
cf86b9a528 | ||
|
|
ee588e1536 | ||
|
|
2a8aacc5c9 | ||
|
|
15709bc666 | ||
|
|
789b4113fe | ||
|
|
d2cdc0efec | ||
|
|
ee343d5d77 | ||
|
|
099c493b18 | ||
|
|
c1d1229ae0 | ||
|
|
94a36cb53e | ||
|
|
c7ba931466 | ||
|
|
413d95b740 | ||
|
|
332c624c55 | ||
|
|
dc160aff36 | ||
|
|
96806bf55f | ||
|
|
d33cd4c95b | ||
|
|
e2c2f64be7 | ||
|
|
cb73b94ffb | ||
|
|
1d920d700c | ||
|
|
bb85eee40a | ||
|
|
aba5d6f0d2 | ||
|
|
0588d2dbe1 | ||
|
|
14b3b77bda |
62
.github/workflows/proto-version-check.yml
vendored
Normal file
62
.github/workflows/proto-version-check.yml
vendored
Normal file
@@ -0,0 +1,62 @@
|
||||
name: Proto Version Check
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
paths:
|
||||
- "**/*.pb.go"
|
||||
|
||||
jobs:
|
||||
check-proto-versions:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Check for proto tool version changes
|
||||
uses: actions/github-script@v7
|
||||
with:
|
||||
script: |
|
||||
const files = await github.paginate(github.rest.pulls.listFiles, {
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
pull_number: context.issue.number,
|
||||
per_page: 100,
|
||||
});
|
||||
|
||||
const pbFiles = files.filter(f => f.filename.endsWith('.pb.go'));
|
||||
const missingPatch = pbFiles.filter(f => !f.patch).map(f => f.filename);
|
||||
if (missingPatch.length > 0) {
|
||||
core.setFailed(
|
||||
`Cannot inspect patch data for:\n` +
|
||||
missingPatch.map(f => `- ${f}`).join('\n') +
|
||||
`\nThis can happen with very large PRs. Verify proto versions manually.`
|
||||
);
|
||||
return;
|
||||
}
|
||||
const versionPattern = /^[+-]\s*\/\/\s+protoc(?:-gen-go)?\s+v[\d.]+/;
|
||||
const violations = [];
|
||||
|
||||
for (const file of pbFiles) {
|
||||
const changed = file.patch
|
||||
.split('\n')
|
||||
.filter(line => versionPattern.test(line));
|
||||
if (changed.length > 0) {
|
||||
violations.push({
|
||||
file: file.filename,
|
||||
lines: changed,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if (violations.length > 0) {
|
||||
const details = violations.map(v =>
|
||||
`${v.file}:\n${v.lines.map(l => ' ' + l).join('\n')}`
|
||||
).join('\n\n');
|
||||
|
||||
core.setFailed(
|
||||
`Proto version strings changed in generated files.\n` +
|
||||
`This usually means the wrong protoc or protoc-gen-go version was used.\n` +
|
||||
`Regenerate with the matching tool versions.\n\n` +
|
||||
details
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
console.log('No proto version string changes detected');
|
||||
2
.github/workflows/release.yml
vendored
2
.github/workflows/release.yml
vendored
@@ -9,7 +9,7 @@ on:
|
||||
pull_request:
|
||||
|
||||
env:
|
||||
SIGN_PIPE_VER: "v0.1.1"
|
||||
SIGN_PIPE_VER: "v0.1.2"
|
||||
GORELEASER_VER: "v2.14.3"
|
||||
PRODUCT_NAME: "NetBird"
|
||||
COPYRIGHT: "NetBird GmbH"
|
||||
|
||||
@@ -199,9 +199,11 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
cmd.Println("Log level set to trace.")
|
||||
}
|
||||
|
||||
needsRestoreUp := false
|
||||
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
|
||||
cmd.PrintErrf("Failed to bring service down: %v\n", status.Convert(err).Message())
|
||||
} else {
|
||||
needsRestoreUp = !stateWasDown
|
||||
cmd.Println("netbird down")
|
||||
}
|
||||
|
||||
@@ -217,6 +219,7 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
|
||||
cmd.PrintErrf("Failed to bring service up: %v\n", status.Convert(err).Message())
|
||||
} else {
|
||||
needsRestoreUp = false
|
||||
cmd.Println("netbird up")
|
||||
}
|
||||
|
||||
@@ -264,6 +267,14 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
if needsRestoreUp {
|
||||
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
|
||||
cmd.PrintErrf("Failed to restore service up state: %v\n", status.Convert(err).Message())
|
||||
} else {
|
||||
cmd.Println("netbird up (restored)")
|
||||
}
|
||||
}
|
||||
|
||||
if stateWasDown {
|
||||
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
|
||||
cmd.PrintErrf("Failed to restore service down state: %v\n", status.Convert(err).Message())
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/expose"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
@@ -201,7 +202,7 @@ func exposeFn(cmd *cobra.Command, args []string) error {
|
||||
|
||||
stream, err := client.ExposeService(ctx, req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("expose service: %w", err)
|
||||
return fmt.Errorf("expose service: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
if err := handleExposeReady(cmd, stream, port); err != nil {
|
||||
@@ -236,7 +237,7 @@ func toExposeProtocol(exposeProtocol string) (proto.ExposeProtocol, error) {
|
||||
func handleExposeReady(cmd *cobra.Command, stream proto.DaemonService_ExposeServiceClient, port uint64) error {
|
||||
event, err := stream.Recv()
|
||||
if err != nil {
|
||||
return fmt.Errorf("receive expose event: %w", err)
|
||||
return fmt.Errorf("receive expose event: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
ready, ok := event.Event.(*proto.ExposeServiceEvent_Ready)
|
||||
|
||||
@@ -75,6 +75,7 @@ var (
|
||||
mtu uint16
|
||||
profilesDisabled bool
|
||||
updateSettingsDisabled bool
|
||||
networksDisabled bool
|
||||
|
||||
rootCmd = &cobra.Command{
|
||||
Use: "netbird",
|
||||
|
||||
@@ -44,10 +44,13 @@ func init() {
|
||||
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd, svcStatusCmd, installCmd, uninstallCmd, reconfigureCmd, resetParamsCmd)
|
||||
serviceCmd.PersistentFlags().BoolVar(&profilesDisabled, "disable-profiles", false, "Disables profiles feature. If enabled, the client will not be able to change or edit any profile. To persist this setting, use: netbird service install --disable-profiles")
|
||||
serviceCmd.PersistentFlags().BoolVar(&updateSettingsDisabled, "disable-update-settings", false, "Disables update settings feature. If enabled, the client will not be able to change or edit any settings. To persist this setting, use: netbird service install --disable-update-settings")
|
||||
serviceCmd.PersistentFlags().BoolVar(&networksDisabled, "disable-networks", false, "Disables network selection. If enabled, the client will not allow listing, selecting, or deselecting networks. To persist, use: netbird service install --disable-networks")
|
||||
|
||||
rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name")
|
||||
serviceEnvDesc := `Sets extra environment variables for the service. ` +
|
||||
`You can specify a comma-separated list of KEY=VALUE pairs. ` +
|
||||
`New keys are merged with previously saved env vars; existing keys are overwritten. ` +
|
||||
`Use --service-env "" to clear all saved env vars. ` +
|
||||
`E.g. --service-env NB_LOG_LEVEL=debug,CUSTOM_VAR=value`
|
||||
|
||||
installCmd.Flags().StringSliceVar(&serviceEnvVars, "service-env", nil, serviceEnvDesc)
|
||||
|
||||
@@ -61,7 +61,7 @@ func (p *program) Start(svc service.Service) error {
|
||||
}
|
||||
}
|
||||
|
||||
serverInstance := server.New(p.ctx, util.FindFirstLogPath(logFiles), configPath, profilesDisabled, updateSettingsDisabled)
|
||||
serverInstance := server.New(p.ctx, util.FindFirstLogPath(logFiles), configPath, profilesDisabled, updateSettingsDisabled, networksDisabled)
|
||||
if err := serverInstance.Start(); err != nil {
|
||||
log.Fatalf("failed to start daemon: %v", err)
|
||||
}
|
||||
|
||||
@@ -59,6 +59,10 @@ func buildServiceArguments() []string {
|
||||
args = append(args, "--disable-update-settings")
|
||||
}
|
||||
|
||||
if networksDisabled {
|
||||
args = append(args, "--disable-networks")
|
||||
}
|
||||
|
||||
return args
|
||||
}
|
||||
|
||||
|
||||
@@ -28,6 +28,7 @@ type serviceParams struct {
|
||||
LogFiles []string `json:"log_files,omitempty"`
|
||||
DisableProfiles bool `json:"disable_profiles,omitempty"`
|
||||
DisableUpdateSettings bool `json:"disable_update_settings,omitempty"`
|
||||
DisableNetworks bool `json:"disable_networks,omitempty"`
|
||||
ServiceEnvVars map[string]string `json:"service_env_vars,omitempty"`
|
||||
}
|
||||
|
||||
@@ -78,11 +79,12 @@ func currentServiceParams() *serviceParams {
|
||||
LogFiles: logFiles,
|
||||
DisableProfiles: profilesDisabled,
|
||||
DisableUpdateSettings: updateSettingsDisabled,
|
||||
DisableNetworks: networksDisabled,
|
||||
}
|
||||
|
||||
if len(serviceEnvVars) > 0 {
|
||||
parsed, err := parseServiceEnvVars(serviceEnvVars)
|
||||
if err == nil && len(parsed) > 0 {
|
||||
if err == nil {
|
||||
params.ServiceEnvVars = parsed
|
||||
}
|
||||
}
|
||||
@@ -142,31 +144,46 @@ func applyServiceParams(cmd *cobra.Command, params *serviceParams) {
|
||||
updateSettingsDisabled = params.DisableUpdateSettings
|
||||
}
|
||||
|
||||
if !serviceCmd.PersistentFlags().Changed("disable-networks") {
|
||||
networksDisabled = params.DisableNetworks
|
||||
}
|
||||
|
||||
applyServiceEnvParams(cmd, params)
|
||||
}
|
||||
|
||||
// applyServiceEnvParams merges saved service environment variables.
|
||||
// If --service-env was explicitly set, explicit values win on key conflict
|
||||
// but saved keys not in the explicit set are carried over.
|
||||
// If --service-env was explicitly set with values, explicit values win on key
|
||||
// conflict but saved keys not in the explicit set are carried over.
|
||||
// If --service-env was explicitly set to empty, all saved env vars are cleared.
|
||||
// If --service-env was not set, saved env vars are used entirely.
|
||||
func applyServiceEnvParams(cmd *cobra.Command, params *serviceParams) {
|
||||
if len(params.ServiceEnvVars) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
if !cmd.Flags().Changed("service-env") {
|
||||
// No explicit env vars: rebuild serviceEnvVars from saved params.
|
||||
serviceEnvVars = envMapToSlice(params.ServiceEnvVars)
|
||||
if len(params.ServiceEnvVars) > 0 {
|
||||
// No explicit env vars: rebuild serviceEnvVars from saved params.
|
||||
serviceEnvVars = envMapToSlice(params.ServiceEnvVars)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Explicit env vars were provided: merge saved values underneath.
|
||||
// Flag was explicitly set: parse what the user provided.
|
||||
explicit, err := parseServiceEnvVars(serviceEnvVars)
|
||||
if err != nil {
|
||||
cmd.PrintErrf("Warning: parse explicit service env vars for merge: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
// If the user passed an empty value (e.g. --service-env ""), clear all
|
||||
// saved env vars rather than merging.
|
||||
if len(explicit) == 0 {
|
||||
serviceEnvVars = nil
|
||||
return
|
||||
}
|
||||
|
||||
if len(params.ServiceEnvVars) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Merge saved values underneath explicit ones.
|
||||
merged := make(map[string]string, len(params.ServiceEnvVars)+len(explicit))
|
||||
maps.Copy(merged, params.ServiceEnvVars)
|
||||
maps.Copy(merged, explicit) // explicit wins on conflict
|
||||
|
||||
@@ -327,6 +327,41 @@ func TestApplyServiceEnvParams_NotChanged(t *testing.T) {
|
||||
assert.Equal(t, map[string]string{"FROM_SAVED": "val"}, result)
|
||||
}
|
||||
|
||||
func TestApplyServiceEnvParams_ExplicitEmptyClears(t *testing.T) {
|
||||
origServiceEnvVars := serviceEnvVars
|
||||
t.Cleanup(func() { serviceEnvVars = origServiceEnvVars })
|
||||
|
||||
// Simulate --service-env "" which produces [""] in the slice.
|
||||
serviceEnvVars = []string{""}
|
||||
|
||||
cmd := &cobra.Command{}
|
||||
cmd.Flags().StringSlice("service-env", nil, "")
|
||||
require.NoError(t, cmd.Flags().Set("service-env", ""))
|
||||
|
||||
saved := &serviceParams{
|
||||
ServiceEnvVars: map[string]string{"OLD_VAR": "should_be_cleared"},
|
||||
}
|
||||
|
||||
applyServiceEnvParams(cmd, saved)
|
||||
|
||||
assert.Nil(t, serviceEnvVars, "explicit empty --service-env should clear all saved env vars")
|
||||
}
|
||||
|
||||
func TestCurrentServiceParams_EmptyEnvVarsAfterParse(t *testing.T) {
|
||||
origServiceEnvVars := serviceEnvVars
|
||||
t.Cleanup(func() { serviceEnvVars = origServiceEnvVars })
|
||||
|
||||
// Simulate --service-env "" which produces [""] in the slice.
|
||||
serviceEnvVars = []string{""}
|
||||
|
||||
params := currentServiceParams()
|
||||
|
||||
// After parsing, the empty string is skipped, resulting in an empty map.
|
||||
// The map should still be set (not nil) so it overwrites saved values.
|
||||
assert.NotNil(t, params.ServiceEnvVars, "empty env vars should produce empty map, not nil")
|
||||
assert.Empty(t, params.ServiceEnvVars, "no valid env vars should be parsed from empty string")
|
||||
}
|
||||
|
||||
// TestServiceParams_FieldsCoveredInFunctions ensures that all serviceParams fields are
|
||||
// referenced in both currentServiceParams() and applyServiceParams(). If a new field is
|
||||
// added to serviceParams but not wired into these functions, this test fails.
|
||||
@@ -500,6 +535,7 @@ func fieldToGlobalVar(field string) string {
|
||||
"LogFiles": "logFiles",
|
||||
"DisableProfiles": "profilesDisabled",
|
||||
"DisableUpdateSettings": "updateSettingsDisabled",
|
||||
"DisableNetworks": "networksDisabled",
|
||||
"ServiceEnvVars": "serviceEnvVars",
|
||||
}
|
||||
if v, ok := m[field]; ok {
|
||||
|
||||
@@ -13,6 +13,8 @@ import (
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
|
||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||
@@ -100,9 +102,16 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
|
||||
|
||||
jobManager := job.NewJobManager(nil, store, peersmanager)
|
||||
|
||||
iv, _ := integrations.NewIntegratedValidator(context.Background(), peersmanager, settingsManagerMock, eventStore)
|
||||
ctx := context.Background()
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
cacheStore, err := nbcache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
iv, _ := integrations.NewIntegratedValidator(ctx, peersmanager, settingsManagerMock, eventStore, cacheStore)
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
@@ -113,12 +122,11 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
|
||||
Return(&types.Settings{}, nil).
|
||||
AnyTimes()
|
||||
|
||||
ctx := context.Background()
|
||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store)
|
||||
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersmanager), config)
|
||||
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
||||
accountManager, err := mgmt.BuildManager(ctx, config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false, cacheStore)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -152,7 +160,7 @@ func startClientDaemon(
|
||||
s := grpc.NewServer()
|
||||
|
||||
server := client.New(ctx,
|
||||
"", "", false, false)
|
||||
"", "", false, false, false)
|
||||
if err := server.Start(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
"github.com/google/nftables"
|
||||
@@ -35,20 +36,34 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
|
||||
type FWType int
|
||||
|
||||
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool, mtu uint16) (firewall.Manager, error) {
|
||||
// on the linux system we try to user nftables or iptables
|
||||
// in any case, because we need to allow netbird interface traffic
|
||||
// so we use AllowNetbird traffic from these firewall managers
|
||||
// for the userspace packet filtering firewall
|
||||
// We run in userspace mode and force userspace firewall was requested. We don't attempt native firewall.
|
||||
if iface.IsUserspaceBind() && forceUserspaceFirewall() {
|
||||
log.Info("forcing userspace firewall")
|
||||
return createUserspaceFirewall(iface, nil, disableServerRoutes, flowLogger, mtu)
|
||||
}
|
||||
|
||||
// Use native firewall for either kernel or userspace, the interface appears identical to netfilter
|
||||
fm, err := createNativeFirewall(iface, stateManager, disableServerRoutes, mtu)
|
||||
|
||||
// Kernel cannot fall back to anything else, need to return error
|
||||
if !iface.IsUserspaceBind() {
|
||||
return fm, err
|
||||
}
|
||||
|
||||
// Fall back to the userspace packet filter if native is unavailable
|
||||
if err != nil {
|
||||
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
|
||||
return createUserspaceFirewall(iface, nil, disableServerRoutes, flowLogger, mtu)
|
||||
}
|
||||
return createUserspaceFirewall(iface, fm, disableServerRoutes, flowLogger, mtu)
|
||||
|
||||
// Native firewall handles packet filtering, but the userspace WireGuard bind
|
||||
// needs a device filter for DNS interception hooks. Install a minimal
|
||||
// hooks-only filter that passes all traffic through to the kernel firewall.
|
||||
if err := iface.SetFilter(&uspfilter.HooksFilter{}); err != nil {
|
||||
log.Warnf("failed to set hooks filter, DNS via memory hooks will not work: %v", err)
|
||||
}
|
||||
|
||||
return fm, nil
|
||||
}
|
||||
|
||||
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool, mtu uint16) (firewall.Manager, error) {
|
||||
@@ -160,3 +175,17 @@ func isIptablesClientAvailable(client *iptables.IPTables) bool {
|
||||
_, err := client.ListChains("filter")
|
||||
return err == nil
|
||||
}
|
||||
|
||||
func forceUserspaceFirewall() bool {
|
||||
val := os.Getenv(EnvForceUserspaceFirewall)
|
||||
if val == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
force, err := strconv.ParseBool(val)
|
||||
if err != nil {
|
||||
log.Warnf("failed to parse %s: %v", EnvForceUserspaceFirewall, err)
|
||||
return false
|
||||
}
|
||||
return force
|
||||
}
|
||||
|
||||
@@ -7,6 +7,12 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
// EnvForceUserspaceFirewall forces the use of the userspace packet filter even when
|
||||
// native iptables/nftables is available. This only applies when the WireGuard interface
|
||||
// runs in userspace mode. When set, peer ACLs are handled by USPFilter instead of
|
||||
// kernel netfilter rules.
|
||||
const EnvForceUserspaceFirewall = "NB_FORCE_USERSPACE_FIREWALL"
|
||||
|
||||
// IFaceMapper defines subset methods of interface required for manager
|
||||
type IFaceMapper interface {
|
||||
Name() string
|
||||
|
||||
@@ -21,6 +21,10 @@ const (
|
||||
|
||||
// rules chains contains the effective ACL rules
|
||||
chainNameInputRules = "NETBIRD-ACL-INPUT"
|
||||
|
||||
// mangleFwdKey is the entries map key for mangle FORWARD guard rules that prevent
|
||||
// external DNAT from bypassing ACL rules.
|
||||
mangleFwdKey = "MANGLE-FORWARD"
|
||||
)
|
||||
|
||||
type aclEntries map[string][][]string
|
||||
@@ -274,6 +278,12 @@ func (m *aclManager) cleanChains() error {
|
||||
}
|
||||
}
|
||||
|
||||
for _, rule := range m.entries[mangleFwdKey] {
|
||||
if err := m.iptablesClient.DeleteIfExists(tableMangle, chainFORWARD, rule...); err != nil {
|
||||
log.Errorf("failed to delete mangle FORWARD guard rule: %v, %s", rule, err)
|
||||
}
|
||||
}
|
||||
|
||||
for _, ipsetName := range m.ipsetStore.ipsetNames() {
|
||||
if err := m.flushIPSet(ipsetName); err != nil {
|
||||
if errors.Is(err, ipset.ErrSetNotExist) {
|
||||
@@ -303,6 +313,10 @@ func (m *aclManager) createDefaultChains() error {
|
||||
}
|
||||
|
||||
for chainName, rules := range m.entries {
|
||||
// mangle FORWARD guard rules are handled separately below
|
||||
if chainName == mangleFwdKey {
|
||||
continue
|
||||
}
|
||||
for _, rule := range rules {
|
||||
if err := m.iptablesClient.InsertUnique(tableName, chainName, 1, rule...); err != nil {
|
||||
log.Debugf("failed to create input chain jump rule: %s", err)
|
||||
@@ -322,6 +336,13 @@ func (m *aclManager) createDefaultChains() error {
|
||||
}
|
||||
clear(m.optionalEntries)
|
||||
|
||||
// Insert mangle FORWARD guard rules to prevent external DNAT bypass.
|
||||
for _, rule := range m.entries[mangleFwdKey] {
|
||||
if err := m.iptablesClient.AppendUnique(tableMangle, chainFORWARD, rule...); err != nil {
|
||||
log.Errorf("failed to add mangle FORWARD guard rule: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -343,6 +364,22 @@ func (m *aclManager) seedInitialEntries() {
|
||||
|
||||
m.appendToEntries("FORWARD", []string{"-o", m.wgIface.Name(), "-j", chainRTFWDOUT})
|
||||
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", chainRTFWDIN})
|
||||
|
||||
// Mangle FORWARD guard: when external DNAT redirects traffic from the wg interface, it
|
||||
// traverses FORWARD instead of INPUT, bypassing ACL rules. ACCEPT rules in filter FORWARD
|
||||
// can be inserted above ours. Mangle runs before filter, so these guard rules enforce the
|
||||
// ACL mark check where it cannot be overridden.
|
||||
m.appendToEntries(mangleFwdKey, []string{
|
||||
"-i", m.wgIface.Name(),
|
||||
"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED",
|
||||
"-j", "ACCEPT",
|
||||
})
|
||||
m.appendToEntries(mangleFwdKey, []string{
|
||||
"-i", m.wgIface.Name(),
|
||||
"-m", "conntrack", "--ctstate", "DNAT",
|
||||
"-m", "mark", "!", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected),
|
||||
"-j", "DROP",
|
||||
})
|
||||
}
|
||||
|
||||
func (m *aclManager) seedInitialOptionalEntries() {
|
||||
|
||||
@@ -33,7 +33,6 @@ type Manager struct {
|
||||
type iFaceMapper interface {
|
||||
Name() string
|
||||
Address() wgaddr.Address
|
||||
IsUserspaceBind() bool
|
||||
}
|
||||
|
||||
// Create iptables firewall manager
|
||||
@@ -64,10 +63,9 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
|
||||
func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||
state := &ShutdownState{
|
||||
InterfaceState: &InterfaceState{
|
||||
NameStr: m.wgIface.Name(),
|
||||
WGAddress: m.wgIface.Address(),
|
||||
UserspaceBind: m.wgIface.IsUserspaceBind(),
|
||||
MTU: m.router.mtu,
|
||||
NameStr: m.wgIface.Name(),
|
||||
WGAddress: m.wgIface.Address(),
|
||||
MTU: m.router.mtu,
|
||||
},
|
||||
}
|
||||
stateManager.RegisterState(state)
|
||||
@@ -203,12 +201,10 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
// AllowNetbird allows netbird interface traffic
|
||||
// AllowNetbird allows netbird interface traffic.
|
||||
// This is called when USPFilter wraps the native firewall, adding blanket accept
|
||||
// rules so that packet filtering is handled in userspace instead of by netfilter.
|
||||
func (m *Manager) AllowNetbird() error {
|
||||
if !m.wgIface.IsUserspaceBind() {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err := m.AddPeerFiltering(
|
||||
nil,
|
||||
net.IP{0, 0, 0, 0},
|
||||
@@ -286,6 +282,22 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
|
||||
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
||||
func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.AddOutputDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
||||
func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.RemoveOutputDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
const (
|
||||
chainNameRaw = "NETBIRD-RAW"
|
||||
chainOUTPUT = "OUTPUT"
|
||||
|
||||
@@ -47,8 +47,6 @@ func (i *iFaceMock) Address() wgaddr.Address {
|
||||
panic("AddressFunc is not set")
|
||||
}
|
||||
|
||||
func (i *iFaceMock) IsUserspaceBind() bool { return false }
|
||||
|
||||
func TestIptablesManager(t *testing.T) {
|
||||
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -36,6 +36,7 @@ const (
|
||||
chainRTFWDOUT = "NETBIRD-RT-FWD-OUT"
|
||||
chainRTPRE = "NETBIRD-RT-PRE"
|
||||
chainRTRDR = "NETBIRD-RT-RDR"
|
||||
chainNATOutput = "NETBIRD-NAT-OUTPUT"
|
||||
chainRTMSSCLAMP = "NETBIRD-RT-MSSCLAMP"
|
||||
routingFinalForwardJump = "ACCEPT"
|
||||
routingFinalNatJump = "MASQUERADE"
|
||||
@@ -43,6 +44,7 @@ const (
|
||||
jumpManglePre = "jump-mangle-pre"
|
||||
jumpNatPre = "jump-nat-pre"
|
||||
jumpNatPost = "jump-nat-post"
|
||||
jumpNatOutput = "jump-nat-output"
|
||||
jumpMSSClamp = "jump-mss-clamp"
|
||||
markManglePre = "mark-mangle-pre"
|
||||
markManglePost = "mark-mangle-post"
|
||||
@@ -387,6 +389,14 @@ func (r *router) cleanUpDefaultForwardRules() error {
|
||||
}
|
||||
|
||||
log.Debug("flushing routing related tables")
|
||||
|
||||
// Remove jump rules from built-in chains before deleting custom chains,
|
||||
// otherwise the chain deletion fails with "device or resource busy".
|
||||
jumpRule := []string{"-j", chainNATOutput}
|
||||
if err := r.iptablesClient.Delete(tableNat, "OUTPUT", jumpRule...); err != nil {
|
||||
log.Debugf("clean OUTPUT jump rule: %v", err)
|
||||
}
|
||||
|
||||
for _, chainInfo := range []struct {
|
||||
chain string
|
||||
table string
|
||||
@@ -396,6 +406,7 @@ func (r *router) cleanUpDefaultForwardRules() error {
|
||||
{chainRTPRE, tableMangle},
|
||||
{chainRTNAT, tableNat},
|
||||
{chainRTRDR, tableNat},
|
||||
{chainNATOutput, tableNat},
|
||||
{chainRTMSSCLAMP, tableMangle},
|
||||
} {
|
||||
ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain)
|
||||
@@ -970,6 +981,81 @@ func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Proto
|
||||
return nil
|
||||
}
|
||||
|
||||
// ensureNATOutputChain lazily creates the OUTPUT NAT chain and jump rule on first use.
|
||||
func (r *router) ensureNATOutputChain() error {
|
||||
if _, exists := r.rules[jumpNatOutput]; exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
chainExists, err := r.iptablesClient.ChainExists(tableNat, chainNATOutput)
|
||||
if err != nil {
|
||||
return fmt.Errorf("check chain %s: %w", chainNATOutput, err)
|
||||
}
|
||||
if !chainExists {
|
||||
if err := r.iptablesClient.NewChain(tableNat, chainNATOutput); err != nil {
|
||||
return fmt.Errorf("create chain %s: %w", chainNATOutput, err)
|
||||
}
|
||||
}
|
||||
|
||||
jumpRule := []string{"-j", chainNATOutput}
|
||||
if err := r.iptablesClient.Insert(tableNat, "OUTPUT", 1, jumpRule...); err != nil {
|
||||
if !chainExists {
|
||||
if delErr := r.iptablesClient.ClearAndDeleteChain(tableNat, chainNATOutput); delErr != nil {
|
||||
log.Warnf("failed to rollback chain %s: %v", chainNATOutput, delErr)
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("add OUTPUT jump rule: %w", err)
|
||||
}
|
||||
r.rules[jumpNatOutput] = jumpRule
|
||||
|
||||
r.updateState()
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
||||
func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
||||
|
||||
if _, exists := r.rules[ruleID]; exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := r.ensureNATOutputChain(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dnatRule := []string{
|
||||
"-p", strings.ToLower(string(protocol)),
|
||||
"--dport", strconv.Itoa(int(sourcePort)),
|
||||
"-d", localAddr.String(),
|
||||
"-j", "DNAT",
|
||||
"--to-destination", ":" + strconv.Itoa(int(targetPort)),
|
||||
}
|
||||
|
||||
if err := r.iptablesClient.Append(tableNat, chainNATOutput, dnatRule...); err != nil {
|
||||
return fmt.Errorf("add output DNAT rule: %w", err)
|
||||
}
|
||||
r.rules[ruleID] = dnatRule
|
||||
|
||||
r.updateState()
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
||||
func (r *router) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
||||
|
||||
if dnatRule, exists := r.rules[ruleID]; exists {
|
||||
if err := r.iptablesClient.Delete(tableNat, chainNATOutput, dnatRule...); err != nil {
|
||||
return fmt.Errorf("delete output DNAT rule: %w", err)
|
||||
}
|
||||
delete(r.rules, ruleID)
|
||||
}
|
||||
|
||||
r.updateState()
|
||||
return nil
|
||||
}
|
||||
|
||||
func applyPort(flag string, port *firewall.Port) []string {
|
||||
if port == nil {
|
||||
return nil
|
||||
|
||||
@@ -9,10 +9,9 @@ import (
|
||||
)
|
||||
|
||||
type InterfaceState struct {
|
||||
NameStr string `json:"name"`
|
||||
WGAddress wgaddr.Address `json:"wg_address"`
|
||||
UserspaceBind bool `json:"userspace_bind"`
|
||||
MTU uint16 `json:"mtu"`
|
||||
NameStr string `json:"name"`
|
||||
WGAddress wgaddr.Address `json:"wg_address"`
|
||||
MTU uint16 `json:"mtu"`
|
||||
}
|
||||
|
||||
func (i *InterfaceState) Name() string {
|
||||
@@ -23,10 +22,6 @@ func (i *InterfaceState) Address() wgaddr.Address {
|
||||
return i.WGAddress
|
||||
}
|
||||
|
||||
func (i *InterfaceState) IsUserspaceBind() bool {
|
||||
return i.UserspaceBind
|
||||
}
|
||||
|
||||
type ShutdownState struct {
|
||||
sync.Mutex
|
||||
|
||||
|
||||
@@ -169,6 +169,14 @@ type Manager interface {
|
||||
// RemoveInboundDNAT removes inbound DNAT rule
|
||||
RemoveInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
|
||||
|
||||
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
||||
// localAddr must be IPv4; the underlying iptables/nftables backends are IPv4-only.
|
||||
AddOutputDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
|
||||
|
||||
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
||||
// localAddr must be IPv4; the underlying iptables/nftables backends are IPv4-only.
|
||||
RemoveOutputDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
|
||||
|
||||
// SetupEBPFProxyNoTrack creates static notrack rules for eBPF proxy loopback traffic.
|
||||
// This prevents conntrack from interfering with WireGuard proxy communication.
|
||||
SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error
|
||||
|
||||
@@ -40,7 +40,6 @@ func getTableName() string {
|
||||
type iFaceMapper interface {
|
||||
Name() string
|
||||
Address() wgaddr.Address
|
||||
IsUserspaceBind() bool
|
||||
}
|
||||
|
||||
// Manager of iptables firewall
|
||||
@@ -106,10 +105,9 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||
// cleanup using Close() without needing to store specific rules.
|
||||
if err := stateManager.UpdateState(&ShutdownState{
|
||||
InterfaceState: &InterfaceState{
|
||||
NameStr: m.wgIface.Name(),
|
||||
WGAddress: m.wgIface.Address(),
|
||||
UserspaceBind: m.wgIface.IsUserspaceBind(),
|
||||
MTU: m.router.mtu,
|
||||
NameStr: m.wgIface.Name(),
|
||||
WGAddress: m.wgIface.Address(),
|
||||
MTU: m.router.mtu,
|
||||
},
|
||||
}); err != nil {
|
||||
log.Errorf("failed to update state: %v", err)
|
||||
@@ -205,12 +203,10 @@ func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
return m.router.RemoveNatRule(pair)
|
||||
}
|
||||
|
||||
// AllowNetbird allows netbird interface traffic
|
||||
// AllowNetbird allows netbird interface traffic.
|
||||
// This is called when USPFilter wraps the native firewall, adding blanket accept
|
||||
// rules so that packet filtering is handled in userspace instead of by netfilter.
|
||||
func (m *Manager) AllowNetbird() error {
|
||||
if !m.wgIface.IsUserspaceBind() {
|
||||
return nil
|
||||
}
|
||||
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
@@ -346,6 +342,22 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
|
||||
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
||||
func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.AddOutputDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
||||
func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.RemoveOutputDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
const (
|
||||
chainNameRawOutput = "netbird-raw-out"
|
||||
chainNameRawPrerouting = "netbird-raw-pre"
|
||||
|
||||
@@ -52,8 +52,6 @@ func (i *iFaceMock) Address() wgaddr.Address {
|
||||
panic("AddressFunc is not set")
|
||||
}
|
||||
|
||||
func (i *iFaceMock) IsUserspaceBind() bool { return false }
|
||||
|
||||
func TestNftablesManager(t *testing.T) {
|
||||
|
||||
// just check on the local interface
|
||||
|
||||
@@ -36,6 +36,7 @@ const (
|
||||
chainNameRoutingFw = "netbird-rt-fwd"
|
||||
chainNameRoutingNat = "netbird-rt-postrouting"
|
||||
chainNameRoutingRdr = "netbird-rt-redirect"
|
||||
chainNameNATOutput = "netbird-nat-output"
|
||||
chainNameForward = "FORWARD"
|
||||
chainNameMangleForward = "netbird-mangle-forward"
|
||||
|
||||
@@ -1853,6 +1854,130 @@ func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Proto
|
||||
return nil
|
||||
}
|
||||
|
||||
// ensureNATOutputChain lazily creates the OUTPUT NAT chain on first use.
|
||||
func (r *router) ensureNATOutputChain() error {
|
||||
if _, exists := r.chains[chainNameNATOutput]; exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
r.chains[chainNameNATOutput] = r.conn.AddChain(&nftables.Chain{
|
||||
Name: chainNameNATOutput,
|
||||
Table: r.workTable,
|
||||
Hooknum: nftables.ChainHookOutput,
|
||||
Priority: nftables.ChainPriorityNATDest,
|
||||
Type: nftables.ChainTypeNAT,
|
||||
})
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
delete(r.chains, chainNameNATOutput)
|
||||
return fmt.Errorf("create NAT output chain: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
||||
func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
||||
|
||||
if _, exists := r.rules[ruleID]; exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := r.ensureNATOutputChain(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
protoNum, err := protoToInt(protocol)
|
||||
if err != nil {
|
||||
return fmt.Errorf("convert protocol to number: %w", err)
|
||||
}
|
||||
|
||||
exprs := []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: []byte{protoNum},
|
||||
},
|
||||
&expr.Payload{
|
||||
DestRegister: 2,
|
||||
Base: expr.PayloadBaseTransportHeader,
|
||||
Offset: 2,
|
||||
Len: 2,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 2,
|
||||
Data: binaryutil.BigEndian.PutUint16(sourcePort),
|
||||
},
|
||||
}
|
||||
|
||||
exprs = append(exprs, applyPrefix(netip.PrefixFrom(localAddr, 32), false)...)
|
||||
|
||||
exprs = append(exprs,
|
||||
&expr.Immediate{
|
||||
Register: 1,
|
||||
Data: localAddr.AsSlice(),
|
||||
},
|
||||
&expr.Immediate{
|
||||
Register: 2,
|
||||
Data: binaryutil.BigEndian.PutUint16(targetPort),
|
||||
},
|
||||
&expr.NAT{
|
||||
Type: expr.NATTypeDestNAT,
|
||||
Family: uint32(nftables.TableFamilyIPv4),
|
||||
RegAddrMin: 1,
|
||||
RegProtoMin: 2,
|
||||
},
|
||||
)
|
||||
|
||||
dnatRule := &nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameNATOutput],
|
||||
Exprs: exprs,
|
||||
UserData: []byte(ruleID),
|
||||
}
|
||||
r.conn.AddRule(dnatRule)
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("add output DNAT rule: %w", err)
|
||||
}
|
||||
|
||||
r.rules[ruleID] = dnatRule
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
||||
func (r *router) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
return fmt.Errorf(refreshRulesMapError, err)
|
||||
}
|
||||
|
||||
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
||||
|
||||
rule, exists := r.rules[ruleID]
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
if rule.Handle == 0 {
|
||||
log.Warnf("output DNAT rule %s has no handle, removing stale entry", ruleID)
|
||||
delete(r.rules, ruleID)
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
return fmt.Errorf("delete output DNAT rule %s: %w", ruleID, err)
|
||||
}
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("flush delete output DNAT rule: %w", err)
|
||||
}
|
||||
delete(r.rules, ruleID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// applyNetwork generates nftables expressions for networks (CIDR) or sets
|
||||
func (r *router) applyNetwork(
|
||||
network firewall.Network,
|
||||
|
||||
@@ -8,10 +8,9 @@ import (
|
||||
)
|
||||
|
||||
type InterfaceState struct {
|
||||
NameStr string `json:"name"`
|
||||
WGAddress wgaddr.Address `json:"wg_address"`
|
||||
UserspaceBind bool `json:"userspace_bind"`
|
||||
MTU uint16 `json:"mtu"`
|
||||
NameStr string `json:"name"`
|
||||
WGAddress wgaddr.Address `json:"wg_address"`
|
||||
MTU uint16 `json:"mtu"`
|
||||
}
|
||||
|
||||
func (i *InterfaceState) Name() string {
|
||||
@@ -22,10 +21,6 @@ func (i *InterfaceState) Address() wgaddr.Address {
|
||||
return i.WGAddress
|
||||
}
|
||||
|
||||
func (i *InterfaceState) IsUserspaceBind() bool {
|
||||
return i.UserspaceBind
|
||||
}
|
||||
|
||||
type ShutdownState struct {
|
||||
InterfaceState *InterfaceState `json:"interface_state,omitempty"`
|
||||
}
|
||||
|
||||
37
client/firewall/uspfilter/common/hooks.go
Normal file
37
client/firewall/uspfilter/common/hooks.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
// PacketHook stores a registered hook for a specific IP:port.
|
||||
type PacketHook struct {
|
||||
IP netip.Addr
|
||||
Port uint16
|
||||
Fn func([]byte) bool
|
||||
}
|
||||
|
||||
// HookMatches checks if a packet's destination matches the hook and invokes it.
|
||||
func HookMatches(h *PacketHook, dstIP netip.Addr, dport uint16, packetData []byte) bool {
|
||||
if h == nil {
|
||||
return false
|
||||
}
|
||||
if h.IP == dstIP && h.Port == dport {
|
||||
return h.Fn(packetData)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// SetHook atomically stores a hook, handling nil removal.
|
||||
func SetHook(ptr *atomic.Pointer[PacketHook], ip netip.Addr, dPort uint16, hook func([]byte) bool) {
|
||||
if hook == nil {
|
||||
ptr.Store(nil)
|
||||
return
|
||||
}
|
||||
ptr.Store(&PacketHook{
|
||||
IP: ip,
|
||||
Port: dPort,
|
||||
Fn: hook,
|
||||
})
|
||||
}
|
||||
@@ -140,6 +140,10 @@ type Manager struct {
|
||||
mtu uint16
|
||||
mssClampValue uint16
|
||||
mssClampEnabled bool
|
||||
|
||||
// Only one hook per protocol is supported. Outbound direction only.
|
||||
udpHookOut atomic.Pointer[common.PacketHook]
|
||||
tcpHookOut atomic.Pointer[common.PacketHook]
|
||||
}
|
||||
|
||||
// decoder for packages
|
||||
@@ -594,6 +598,8 @@ func (m *Manager) resetState() {
|
||||
maps.Clear(m.incomingRules)
|
||||
maps.Clear(m.routeRulesMap)
|
||||
m.routeRules = m.routeRules[:0]
|
||||
m.udpHookOut.Store(nil)
|
||||
m.tcpHookOut.Store(nil)
|
||||
|
||||
if m.udpTracker != nil {
|
||||
m.udpTracker.Close()
|
||||
@@ -713,6 +719,9 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool {
|
||||
return true
|
||||
}
|
||||
case layers.LayerTypeTCP:
|
||||
if m.tcpHooksDrop(uint16(d.tcp.DstPort), dstIP, packetData) {
|
||||
return true
|
||||
}
|
||||
// Clamp MSS on all TCP SYN packets, including those from local IPs.
|
||||
// SNATed routed traffic may appear as local IP but still requires clamping.
|
||||
if m.mssClampEnabled {
|
||||
@@ -895,39 +904,12 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byt
|
||||
d.dnatOrigPort = 0
|
||||
}
|
||||
|
||||
// udpHooksDrop checks if any UDP hooks should drop the packet
|
||||
func (m *Manager) udpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte) bool {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
return common.HookMatches(m.udpHookOut.Load(), dstIP, dport, packetData)
|
||||
}
|
||||
|
||||
// Check specific destination IP first
|
||||
if rules, exists := m.outgoingRules[dstIP]; exists {
|
||||
for _, rule := range rules {
|
||||
if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
|
||||
return rule.udpHook(packetData)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check IPv4 unspecified address
|
||||
if rules, exists := m.outgoingRules[netip.IPv4Unspecified()]; exists {
|
||||
for _, rule := range rules {
|
||||
if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
|
||||
return rule.udpHook(packetData)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check IPv6 unspecified address
|
||||
if rules, exists := m.outgoingRules[netip.IPv6Unspecified()]; exists {
|
||||
for _, rule := range rules {
|
||||
if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
|
||||
return rule.udpHook(packetData)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
func (m *Manager) tcpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte) bool {
|
||||
return common.HookMatches(m.tcpHookOut.Load(), dstIP, dport, packetData)
|
||||
}
|
||||
|
||||
// filterInbound implements filtering logic for incoming packets.
|
||||
@@ -1278,12 +1260,6 @@ func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d
|
||||
return rule.mgmtId, rule.drop, true
|
||||
}
|
||||
case layers.LayerTypeUDP:
|
||||
// if rule has UDP hook (and if we are here we match this rule)
|
||||
// we ignore rule.drop and call this hook
|
||||
if rule.udpHook != nil {
|
||||
return rule.mgmtId, rule.udpHook(packetData), true
|
||||
}
|
||||
|
||||
if portsMatch(rule.sPort, uint16(d.udp.SrcPort)) && portsMatch(rule.dPort, uint16(d.udp.DstPort)) {
|
||||
return rule.mgmtId, rule.drop, true
|
||||
}
|
||||
@@ -1342,65 +1318,14 @@ func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, prot
|
||||
return sourceMatched
|
||||
}
|
||||
|
||||
// AddUDPPacketHook calls hook when UDP packet from given direction matched
|
||||
//
|
||||
// Hook function returns flag which indicates should be the matched package dropped or not
|
||||
func (m *Manager) AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook func(packet []byte) bool) string {
|
||||
r := PeerRule{
|
||||
id: uuid.New().String(),
|
||||
ip: ip,
|
||||
protoLayer: layers.LayerTypeUDP,
|
||||
dPort: &firewall.Port{Values: []uint16{dPort}},
|
||||
ipLayer: layers.LayerTypeIPv6,
|
||||
udpHook: hook,
|
||||
}
|
||||
|
||||
if ip.Is4() {
|
||||
r.ipLayer = layers.LayerTypeIPv4
|
||||
}
|
||||
|
||||
m.mutex.Lock()
|
||||
if in {
|
||||
// Incoming UDP hooks are stored in allow rules map
|
||||
if _, ok := m.incomingRules[r.ip]; !ok {
|
||||
m.incomingRules[r.ip] = make(map[string]PeerRule)
|
||||
}
|
||||
m.incomingRules[r.ip][r.id] = r
|
||||
} else {
|
||||
if _, ok := m.outgoingRules[r.ip]; !ok {
|
||||
m.outgoingRules[r.ip] = make(map[string]PeerRule)
|
||||
}
|
||||
m.outgoingRules[r.ip][r.id] = r
|
||||
}
|
||||
m.mutex.Unlock()
|
||||
|
||||
return r.id
|
||||
// SetUDPPacketHook sets the outbound UDP packet hook. Pass nil hook to remove.
|
||||
func (m *Manager) SetUDPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool) {
|
||||
common.SetHook(&m.udpHookOut, ip, dPort, hook)
|
||||
}
|
||||
|
||||
// RemovePacketHook removes packet hook by given ID
|
||||
func (m *Manager) RemovePacketHook(hookID string) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
// Check incoming hooks (stored in allow rules)
|
||||
for _, arr := range m.incomingRules {
|
||||
for _, r := range arr {
|
||||
if r.id == hookID {
|
||||
delete(arr, r.id)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
// Check outgoing hooks
|
||||
for _, arr := range m.outgoingRules {
|
||||
for _, r := range arr {
|
||||
if r.id == hookID {
|
||||
delete(arr, r.id)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("hook with given id not found")
|
||||
// SetTCPPacketHook sets the outbound TCP packet hook. Pass nil hook to remove.
|
||||
func (m *Manager) SetTCPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool) {
|
||||
common.SetHook(&m.tcpHookOut, ip, dPort, hook)
|
||||
}
|
||||
|
||||
// SetLogLevel sets the log level for the firewall manager
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
|
||||
@@ -186,81 +187,52 @@ func TestManagerDeleteRule(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddUDPPacketHook(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
in bool
|
||||
expDir fw.RuleDirection
|
||||
ip netip.Addr
|
||||
dPort uint16
|
||||
hook func([]byte) bool
|
||||
expectedID string
|
||||
}{
|
||||
{
|
||||
name: "Test Outgoing UDP Packet Hook",
|
||||
in: false,
|
||||
expDir: fw.RuleDirectionOUT,
|
||||
ip: netip.MustParseAddr("10.168.0.1"),
|
||||
dPort: 8000,
|
||||
hook: func([]byte) bool { return true },
|
||||
},
|
||||
{
|
||||
name: "Test Incoming UDP Packet Hook",
|
||||
in: true,
|
||||
expDir: fw.RuleDirectionIN,
|
||||
ip: netip.MustParseAddr("::1"),
|
||||
dPort: 9000,
|
||||
hook: func([]byte) bool { return false },
|
||||
},
|
||||
}
|
||||
func TestSetUDPPacketHook(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { require.NoError(t, manager.Close(nil)) })
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
var called bool
|
||||
manager.SetUDPPacketHook(netip.MustParseAddr("10.168.0.1"), 8000, func([]byte) bool {
|
||||
called = true
|
||||
return true
|
||||
})
|
||||
|
||||
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
|
||||
h := manager.udpHookOut.Load()
|
||||
require.NotNil(t, h)
|
||||
assert.Equal(t, netip.MustParseAddr("10.168.0.1"), h.IP)
|
||||
assert.Equal(t, uint16(8000), h.Port)
|
||||
assert.True(t, h.Fn(nil))
|
||||
assert.True(t, called)
|
||||
|
||||
var addedRule PeerRule
|
||||
if tt.in {
|
||||
// Incoming UDP hooks are stored in allow rules map
|
||||
if len(manager.incomingRules[tt.ip]) != 1 {
|
||||
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules[tt.ip]))
|
||||
return
|
||||
}
|
||||
for _, rule := range manager.incomingRules[tt.ip] {
|
||||
addedRule = rule
|
||||
}
|
||||
} else {
|
||||
if len(manager.outgoingRules[tt.ip]) != 1 {
|
||||
t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules[tt.ip]))
|
||||
return
|
||||
}
|
||||
for _, rule := range manager.outgoingRules[tt.ip] {
|
||||
addedRule = rule
|
||||
}
|
||||
}
|
||||
manager.SetUDPPacketHook(netip.MustParseAddr("10.168.0.1"), 8000, nil)
|
||||
assert.Nil(t, manager.udpHookOut.Load())
|
||||
}
|
||||
|
||||
if tt.ip.Compare(addedRule.ip) != 0 {
|
||||
t.Errorf("expected ip %s, got %s", tt.ip, addedRule.ip)
|
||||
return
|
||||
}
|
||||
if tt.dPort != addedRule.dPort.Values[0] {
|
||||
t.Errorf("expected dPort %d, got %d", tt.dPort, addedRule.dPort.Values[0])
|
||||
return
|
||||
}
|
||||
if layers.LayerTypeUDP != addedRule.protoLayer {
|
||||
t.Errorf("expected protoLayer %s, got %s", layers.LayerTypeUDP, addedRule.protoLayer)
|
||||
return
|
||||
}
|
||||
if addedRule.udpHook == nil {
|
||||
t.Errorf("expected udpHook to be set")
|
||||
return
|
||||
}
|
||||
})
|
||||
}
|
||||
func TestSetTCPPacketHook(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { require.NoError(t, manager.Close(nil)) })
|
||||
|
||||
var called bool
|
||||
manager.SetTCPPacketHook(netip.MustParseAddr("10.168.0.1"), 53, func([]byte) bool {
|
||||
called = true
|
||||
return true
|
||||
})
|
||||
|
||||
h := manager.tcpHookOut.Load()
|
||||
require.NotNil(t, h)
|
||||
assert.Equal(t, netip.MustParseAddr("10.168.0.1"), h.IP)
|
||||
assert.Equal(t, uint16(53), h.Port)
|
||||
assert.True(t, h.Fn(nil))
|
||||
assert.True(t, called)
|
||||
|
||||
manager.SetTCPPacketHook(netip.MustParseAddr("10.168.0.1"), 53, nil)
|
||||
assert.Nil(t, manager.tcpHookOut.Load())
|
||||
}
|
||||
|
||||
// TestPeerRuleLifecycleDenyRules verifies that deny rules are correctly added
|
||||
@@ -530,39 +502,12 @@ func TestRemovePacketHook(t *testing.T) {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
}()
|
||||
|
||||
// Add a UDP packet hook
|
||||
hookFunc := func(data []byte) bool { return true }
|
||||
hookID := manager.AddUDPPacketHook(false, netip.MustParseAddr("192.168.0.1"), 8080, hookFunc)
|
||||
manager.SetUDPPacketHook(netip.MustParseAddr("192.168.0.1"), 8080, func([]byte) bool { return true })
|
||||
|
||||
// Assert the hook is added by finding it in the manager's outgoing rules
|
||||
found := false
|
||||
for _, arr := range manager.outgoingRules {
|
||||
for _, rule := range arr {
|
||||
if rule.id == hookID {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
require.NotNil(t, manager.udpHookOut.Load(), "hook should be registered")
|
||||
|
||||
if !found {
|
||||
t.Fatalf("The hook was not added properly.")
|
||||
}
|
||||
|
||||
// Now remove the packet hook
|
||||
err = manager.RemovePacketHook(hookID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to remove hook: %s", err)
|
||||
}
|
||||
|
||||
// Assert the hook is removed by checking it in the manager's outgoing rules
|
||||
for _, arr := range manager.outgoingRules {
|
||||
for _, rule := range arr {
|
||||
if rule.id == hookID {
|
||||
t.Fatalf("The hook was not removed properly.")
|
||||
}
|
||||
}
|
||||
}
|
||||
manager.SetUDPPacketHook(netip.MustParseAddr("192.168.0.1"), 8080, nil)
|
||||
assert.Nil(t, manager.udpHookOut.Load(), "hook should be removed")
|
||||
}
|
||||
|
||||
func TestProcessOutgoingHooks(t *testing.T) {
|
||||
@@ -592,8 +537,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
||||
}
|
||||
|
||||
hookCalled := false
|
||||
hookID := manager.AddUDPPacketHook(
|
||||
false,
|
||||
manager.SetUDPPacketHook(
|
||||
netip.MustParseAddr("100.10.0.100"),
|
||||
53,
|
||||
func([]byte) bool {
|
||||
@@ -601,7 +545,6 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
||||
return true
|
||||
},
|
||||
)
|
||||
require.NotEmpty(t, hookID)
|
||||
|
||||
// Create test UDP packet
|
||||
ipv4 := &layers.IPv4{
|
||||
|
||||
90
client/firewall/uspfilter/hooks_filter.go
Normal file
90
client/firewall/uspfilter/hooks_filter.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"net/netip"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
)
|
||||
|
||||
const (
|
||||
ipv4HeaderMinLen = 20
|
||||
ipv4ProtoOffset = 9
|
||||
ipv4FlagsOffset = 6
|
||||
ipv4DstOffset = 16
|
||||
ipProtoUDP = 17
|
||||
ipProtoTCP = 6
|
||||
ipv4FragOffMask = 0x1fff
|
||||
// dstPortOffset is the offset of the destination port within a UDP or TCP header.
|
||||
dstPortOffset = 2
|
||||
)
|
||||
|
||||
// HooksFilter is a minimal packet filter that only handles outbound DNS hooks.
|
||||
// It is installed on the WireGuard interface when the userspace bind is active
|
||||
// but a full firewall filter (Manager) is not needed because a native kernel
|
||||
// firewall (nftables/iptables) handles packet filtering.
|
||||
type HooksFilter struct {
|
||||
udpHook atomic.Pointer[common.PacketHook]
|
||||
tcpHook atomic.Pointer[common.PacketHook]
|
||||
}
|
||||
|
||||
var _ device.PacketFilter = (*HooksFilter)(nil)
|
||||
|
||||
// FilterOutbound checks outbound packets for DNS hook matches.
|
||||
// Only IPv4 packets matching the registered hook IP:port are intercepted.
|
||||
// IPv6 and non-IP packets pass through unconditionally.
|
||||
func (f *HooksFilter) FilterOutbound(packetData []byte, _ int) bool {
|
||||
if len(packetData) < ipv4HeaderMinLen {
|
||||
return false
|
||||
}
|
||||
|
||||
// Only process IPv4 packets, let everything else pass through.
|
||||
if packetData[0]>>4 != 4 {
|
||||
return false
|
||||
}
|
||||
|
||||
ihl := int(packetData[0]&0x0f) * 4
|
||||
if ihl < ipv4HeaderMinLen || len(packetData) < ihl+4 {
|
||||
return false
|
||||
}
|
||||
|
||||
// Skip non-first fragments: they don't carry L4 headers.
|
||||
flagsAndOffset := binary.BigEndian.Uint16(packetData[ipv4FlagsOffset : ipv4FlagsOffset+2])
|
||||
if flagsAndOffset&ipv4FragOffMask != 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
dstIP, ok := netip.AddrFromSlice(packetData[ipv4DstOffset : ipv4DstOffset+4])
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
proto := packetData[ipv4ProtoOffset]
|
||||
dstPort := binary.BigEndian.Uint16(packetData[ihl+dstPortOffset : ihl+dstPortOffset+2])
|
||||
|
||||
switch proto {
|
||||
case ipProtoUDP:
|
||||
return common.HookMatches(f.udpHook.Load(), dstIP, dstPort, packetData)
|
||||
case ipProtoTCP:
|
||||
return common.HookMatches(f.tcpHook.Load(), dstIP, dstPort, packetData)
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// FilterInbound allows all inbound packets (native firewall handles filtering).
|
||||
func (f *HooksFilter) FilterInbound([]byte, int) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// SetUDPPacketHook registers the UDP packet hook.
|
||||
func (f *HooksFilter) SetUDPPacketHook(ip netip.Addr, dPort uint16, hook func([]byte) bool) {
|
||||
common.SetHook(&f.udpHook, ip, dPort, hook)
|
||||
}
|
||||
|
||||
// SetTCPPacketHook registers the TCP packet hook.
|
||||
func (f *HooksFilter) SetTCPPacketHook(ip netip.Addr, dPort uint16, hook func([]byte) bool) {
|
||||
common.SetHook(&f.tcpHook, ip, dPort, hook)
|
||||
}
|
||||
@@ -144,6 +144,8 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
|
||||
if err != nil {
|
||||
log.Warnf("failed to get interfaces: %v", err)
|
||||
} else {
|
||||
// TODO: filter out down interfaces (net.FlagUp). Also handle the reverse
|
||||
// case where an interface comes up between refreshes.
|
||||
for _, intf := range interfaces {
|
||||
m.processInterface(intf, &newIPv4Bitmap, ipv4Set, &ipv4Addresses)
|
||||
}
|
||||
|
||||
@@ -421,6 +421,7 @@ func (m *Manager) addPortRedirection(targetIP netip.Addr, protocol gopacket.Laye
|
||||
}
|
||||
|
||||
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
|
||||
// TODO: also delegate to nativeFirewall when available for kernel WG mode
|
||||
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
var layerType gopacket.LayerType
|
||||
switch protocol {
|
||||
@@ -466,6 +467,22 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
|
||||
return m.removePortRedirection(localAddr, layerType, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
// AddOutputDNAT delegates to the native firewall if available.
|
||||
func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
if m.nativeFirewall == nil {
|
||||
return fmt.Errorf("output DNAT not supported without native firewall")
|
||||
}
|
||||
return m.nativeFirewall.AddOutputDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
// RemoveOutputDNAT delegates to the native firewall if available.
|
||||
func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
if m.nativeFirewall == nil {
|
||||
return nil
|
||||
}
|
||||
return m.nativeFirewall.RemoveOutputDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
// translateInboundPortDNAT applies port-specific DNAT translation to inbound packets.
|
||||
func (m *Manager) translateInboundPortDNAT(packetData []byte, d *decoder, srcIP, dstIP netip.Addr) bool {
|
||||
if !m.portDNATEnabled.Load() {
|
||||
|
||||
@@ -18,9 +18,7 @@ type PeerRule struct {
|
||||
protoLayer gopacket.LayerType
|
||||
sPort *firewall.Port
|
||||
dPort *firewall.Port
|
||||
drop bool
|
||||
|
||||
udpHook func([]byte) bool
|
||||
drop bool
|
||||
}
|
||||
|
||||
// ID returns the rule id
|
||||
|
||||
@@ -399,21 +399,17 @@ func TestTracePacket(t *testing.T) {
|
||||
{
|
||||
name: "UDPTraffic_WithHook",
|
||||
setup: func(m *Manager) {
|
||||
hookFunc := func([]byte) bool {
|
||||
return true
|
||||
}
|
||||
m.AddUDPPacketHook(true, netip.MustParseAddr("1.1.1.1"), 53, hookFunc)
|
||||
m.SetUDPPacketHook(netip.MustParseAddr("100.10.255.254"), 53, func([]byte) bool {
|
||||
return true // drop (intercepted by hook)
|
||||
})
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
return createPacketBuilder("1.1.1.1", "100.10.0.100", "udp", 12345, 53, fw.RuleDirectionIN)
|
||||
return createPacketBuilder("100.10.0.100", "100.10.255.254", "udp", 12345, 53, fw.RuleDirectionOUT)
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageInboundPortDNAT,
|
||||
StageInbound1to1NAT,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StagePeerACL,
|
||||
StageOutbound1to1NAT,
|
||||
StageOutboundPortReverse,
|
||||
StageCompleted,
|
||||
},
|
||||
expectedAllow: false,
|
||||
|
||||
@@ -15,14 +15,17 @@ type PacketFilter interface {
|
||||
// FilterInbound filter incoming packets from external sources to host
|
||||
FilterInbound(packetData []byte, size int) bool
|
||||
|
||||
// AddUDPPacketHook calls hook when UDP packet from given direction matched
|
||||
//
|
||||
// Hook function returns flag which indicates should be the matched package dropped or not.
|
||||
// Hook function receives raw network packet data as argument.
|
||||
AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook func(packet []byte) bool) string
|
||||
// SetUDPPacketHook registers a hook for outbound UDP packets matching the given IP and port.
|
||||
// Hook function returns true if the packet should be dropped.
|
||||
// Only one UDP hook is supported; calling again replaces the previous hook.
|
||||
// Pass nil hook to remove.
|
||||
SetUDPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool)
|
||||
|
||||
// RemovePacketHook removes hook by ID
|
||||
RemovePacketHook(hookID string) error
|
||||
// SetTCPPacketHook registers a hook for outbound TCP packets matching the given IP and port.
|
||||
// Hook function returns true if the packet should be dropped.
|
||||
// Only one TCP hook is supported; calling again replaces the previous hook.
|
||||
// Pass nil hook to remove.
|
||||
SetTCPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool)
|
||||
}
|
||||
|
||||
// FilteredDevice to override Read or Write of packets
|
||||
|
||||
@@ -34,18 +34,28 @@ func (m *MockPacketFilter) EXPECT() *MockPacketFilterMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// AddUDPPacketHook mocks base method.
|
||||
func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 netip.Addr, arg2 uint16, arg3 func([]byte) bool) string {
|
||||
// SetUDPPacketHook mocks base method.
|
||||
func (m *MockPacketFilter) SetUDPPacketHook(arg0 netip.Addr, arg1 uint16, arg2 func([]byte) bool) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AddUDPPacketHook", arg0, arg1, arg2, arg3)
|
||||
ret0, _ := ret[0].(string)
|
||||
return ret0
|
||||
m.ctrl.Call(m, "SetUDPPacketHook", arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// AddUDPPacketHook indicates an expected call of AddUDPPacketHook.
|
||||
func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
|
||||
// SetUDPPacketHook indicates an expected call of SetUDPPacketHook.
|
||||
func (mr *MockPacketFilterMockRecorder) SetUDPPacketHook(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).SetUDPPacketHook), arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// SetTCPPacketHook mocks base method.
|
||||
func (m *MockPacketFilter) SetTCPPacketHook(arg0 netip.Addr, arg1 uint16, arg2 func([]byte) bool) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "SetTCPPacketHook", arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// SetTCPPacketHook indicates an expected call of SetTCPPacketHook.
|
||||
func (mr *MockPacketFilterMockRecorder) SetTCPPacketHook(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetTCPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).SetTCPPacketHook), arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// FilterInbound mocks base method.
|
||||
@@ -75,17 +85,3 @@ func (mr *MockPacketFilterMockRecorder) FilterOutbound(arg0 interface{}, arg1 an
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterOutbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterOutbound), arg0, arg1)
|
||||
}
|
||||
|
||||
// RemovePacketHook mocks base method.
|
||||
func (m *MockPacketFilter) RemovePacketHook(arg0 string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "RemovePacketHook", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// RemovePacketHook indicates an expected call of RemovePacketHook.
|
||||
func (mr *MockPacketFilterMockRecorder) RemovePacketHook(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemovePacketHook", reflect.TypeOf((*MockPacketFilter)(nil).RemovePacketHook), arg0)
|
||||
}
|
||||
|
||||
@@ -1,87 +0,0 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/netbirdio/netbird/client/iface (interfaces: PacketFilter)
|
||||
|
||||
// Package mocks is a generated GoMock package.
|
||||
package mocks
|
||||
|
||||
import (
|
||||
net "net"
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
)
|
||||
|
||||
// MockPacketFilter is a mock of PacketFilter interface.
|
||||
type MockPacketFilter struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockPacketFilterMockRecorder
|
||||
}
|
||||
|
||||
// MockPacketFilterMockRecorder is the mock recorder for MockPacketFilter.
|
||||
type MockPacketFilterMockRecorder struct {
|
||||
mock *MockPacketFilter
|
||||
}
|
||||
|
||||
// NewMockPacketFilter creates a new mock instance.
|
||||
func NewMockPacketFilter(ctrl *gomock.Controller) *MockPacketFilter {
|
||||
mock := &MockPacketFilter{ctrl: ctrl}
|
||||
mock.recorder = &MockPacketFilterMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockPacketFilter) EXPECT() *MockPacketFilterMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// AddUDPPacketHook mocks base method.
|
||||
func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 net.IP, arg2 uint16, arg3 func(*net.UDPAddr, []byte) bool) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "AddUDPPacketHook", arg0, arg1, arg2, arg3)
|
||||
}
|
||||
|
||||
// AddUDPPacketHook indicates an expected call of AddUDPPacketHook.
|
||||
func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3)
|
||||
}
|
||||
|
||||
// FilterInbound mocks base method.
|
||||
func (m *MockPacketFilter) FilterInbound(arg0 []byte) bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "FilterInbound", arg0)
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// FilterInbound indicates an expected call of FilterInbound.
|
||||
func (mr *MockPacketFilterMockRecorder) FilterInbound(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterInbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterInbound), arg0)
|
||||
}
|
||||
|
||||
// FilterOutbound mocks base method.
|
||||
func (m *MockPacketFilter) FilterOutbound(arg0 []byte) bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "FilterOutbound", arg0)
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// FilterOutbound indicates an expected call of FilterOutbound.
|
||||
func (mr *MockPacketFilterMockRecorder) FilterOutbound(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterOutbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterOutbound), arg0)
|
||||
}
|
||||
|
||||
// SetNetwork mocks base method.
|
||||
func (m *MockPacketFilter) SetNetwork(arg0 *net.IPNet) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "SetNetwork", arg0)
|
||||
}
|
||||
|
||||
// SetNetwork indicates an expected call of SetNetwork.
|
||||
func (mr *MockPacketFilterMockRecorder) SetNetwork(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetwork", reflect.TypeOf((*MockPacketFilter)(nil).SetNetwork), arg0)
|
||||
}
|
||||
@@ -19,6 +19,9 @@ import (
|
||||
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
|
||||
|
||||
func TestDefaultManager(t *testing.T) {
|
||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||
t.Setenv(firewall.EnvForceUserspaceFirewall, "true")
|
||||
|
||||
networkMap := &mgmProto.NetworkMap{
|
||||
FirewallRules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
@@ -135,6 +138,7 @@ func TestDefaultManager(t *testing.T) {
|
||||
func TestDefaultManagerStateless(t *testing.T) {
|
||||
// stateless currently only in userspace, so we have to disable kernel
|
||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||
t.Setenv(firewall.EnvForceUserspaceFirewall, "true")
|
||||
t.Setenv("NB_DISABLE_CONNTRACK", "true")
|
||||
|
||||
networkMap := &mgmProto.NetworkMap{
|
||||
@@ -194,6 +198,7 @@ func TestDefaultManagerStateless(t *testing.T) {
|
||||
// This tests the full ACL manager -> uspfilter integration.
|
||||
func TestDenyRulesNotAccumulatedOnRepeatedApply(t *testing.T) {
|
||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||
t.Setenv(firewall.EnvForceUserspaceFirewall, "true")
|
||||
|
||||
networkMap := &mgmProto.NetworkMap{
|
||||
FirewallRules: []*mgmProto.FirewallRule{
|
||||
@@ -258,6 +263,7 @@ func TestDenyRulesNotAccumulatedOnRepeatedApply(t *testing.T) {
|
||||
// up when they're removed from the network map in a subsequent update.
|
||||
func TestDenyRulesCleanedUpOnRemoval(t *testing.T) {
|
||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||
t.Setenv(firewall.EnvForceUserspaceFirewall, "true")
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
@@ -339,6 +345,7 @@ func TestDenyRulesCleanedUpOnRemoval(t *testing.T) {
|
||||
// one added without leaking.
|
||||
func TestRuleUpdateChangingAction(t *testing.T) {
|
||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||
t.Setenv(firewall.EnvForceUserspaceFirewall, "true")
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
@@ -111,6 +111,7 @@ func (c *ConnectClient) RunOniOS(
|
||||
fileDescriptor int32,
|
||||
networkChangeListener listener.NetworkChangeListener,
|
||||
dnsManager dns.IosDnsManager,
|
||||
dnsAddresses []netip.AddrPort,
|
||||
stateFilePath string,
|
||||
) error {
|
||||
// Set GC percent to 5% to reduce memory usage as iOS only allows 50MB of memory for the extension.
|
||||
@@ -120,6 +121,7 @@ func (c *ConnectClient) RunOniOS(
|
||||
FileDescriptor: fileDescriptor,
|
||||
NetworkChangeListener: networkChangeListener,
|
||||
DnsManager: dnsManager,
|
||||
HostDNSAddresses: dnsAddresses,
|
||||
StateFilePath: stateFilePath,
|
||||
}
|
||||
return c.run(mobileDependency, nil, "")
|
||||
|
||||
@@ -25,6 +25,7 @@ import (
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
|
||||
"github.com/netbirdio/netbird/client/anonymize"
|
||||
"github.com/netbirdio/netbird/client/configs"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/updater/installer"
|
||||
@@ -52,6 +53,7 @@ resolved_domains.txt: Anonymized resolved domain IP addresses from the status re
|
||||
config.txt: Anonymized configuration information of the NetBird client.
|
||||
network_map.json: Anonymized sync response containing peer configurations, routes, DNS settings, and firewall rules.
|
||||
state.json: Anonymized client state dump containing netbird states for the active profile.
|
||||
service_params.json: Sanitized service install parameters (service.json). Sensitive environment variable values are masked. Only present when service.json exists.
|
||||
metrics.txt: Buffered client metrics in InfluxDB line protocol format. Only present when metrics collection is enabled. Peer identifiers are anonymized.
|
||||
mutex.prof: Mutex profiling information.
|
||||
goroutine.prof: Goroutine profiling information.
|
||||
@@ -359,6 +361,10 @@ func (g *BundleGenerator) createArchive() error {
|
||||
log.Errorf("failed to add corrupted state files to debug bundle: %v", err)
|
||||
}
|
||||
|
||||
if err := g.addServiceParams(); err != nil {
|
||||
log.Errorf("failed to add service params to debug bundle: %v", err)
|
||||
}
|
||||
|
||||
if err := g.addMetrics(); err != nil {
|
||||
log.Errorf("failed to add metrics to debug bundle: %v", err)
|
||||
}
|
||||
@@ -488,6 +494,90 @@ func (g *BundleGenerator) addConfig() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
const (
|
||||
serviceParamsFile = "service.json"
|
||||
serviceParamsBundle = "service_params.json"
|
||||
maskedValue = "***"
|
||||
envVarPrefix = "NB_"
|
||||
jsonKeyManagementURL = "management_url"
|
||||
jsonKeyServiceEnv = "service_env_vars"
|
||||
)
|
||||
|
||||
var sensitiveEnvSubstrings = []string{"key", "token", "secret", "password", "credential"}
|
||||
|
||||
// addServiceParams reads the service.json file and adds a sanitized version to the bundle.
|
||||
// Non-NB_ env vars and vars with sensitive names are masked. Other NB_ values are anonymized.
|
||||
func (g *BundleGenerator) addServiceParams() error {
|
||||
path := filepath.Join(configs.StateDir, serviceParamsFile)
|
||||
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("read service params: %w", err)
|
||||
}
|
||||
|
||||
var params map[string]any
|
||||
if err := json.Unmarshal(data, ¶ms); err != nil {
|
||||
return fmt.Errorf("parse service params: %w", err)
|
||||
}
|
||||
|
||||
if g.anonymize {
|
||||
if mgmtURL, ok := params[jsonKeyManagementURL].(string); ok && mgmtURL != "" {
|
||||
params[jsonKeyManagementURL] = g.anonymizer.AnonymizeURI(mgmtURL)
|
||||
}
|
||||
}
|
||||
|
||||
g.sanitizeServiceEnvVars(params)
|
||||
|
||||
sanitizedData, err := json.MarshalIndent(params, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal sanitized service params: %w", err)
|
||||
}
|
||||
|
||||
if err := g.addFileToZip(bytes.NewReader(sanitizedData), serviceParamsBundle); err != nil {
|
||||
return fmt.Errorf("add service params to zip: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// sanitizeServiceEnvVars masks or anonymizes env var values in service params.
|
||||
// Non-NB_ vars and vars with sensitive names (key, token, etc.) are fully masked.
|
||||
// Other NB_ var values are passed through the anonymizer when anonymization is enabled.
|
||||
func (g *BundleGenerator) sanitizeServiceEnvVars(params map[string]any) {
|
||||
envVars, ok := params[jsonKeyServiceEnv].(map[string]any)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
sanitized := make(map[string]any, len(envVars))
|
||||
for k, v := range envVars {
|
||||
val, _ := v.(string)
|
||||
switch {
|
||||
case !strings.HasPrefix(k, envVarPrefix) || isSensitiveEnvVar(k):
|
||||
sanitized[k] = maskedValue
|
||||
case g.anonymize:
|
||||
sanitized[k] = g.anonymizer.AnonymizeString(val)
|
||||
default:
|
||||
sanitized[k] = val
|
||||
}
|
||||
}
|
||||
params[jsonKeyServiceEnv] = sanitized
|
||||
}
|
||||
|
||||
// isSensitiveEnvVar returns true for env var names that may contain secrets.
|
||||
func isSensitiveEnvVar(key string) bool {
|
||||
lower := strings.ToLower(key)
|
||||
for _, s := range sensitiveEnvSubstrings {
|
||||
if strings.Contains(lower, s) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder) {
|
||||
configContent.WriteString("NetBird Client Configuration:\n\n")
|
||||
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
package debug
|
||||
|
||||
import (
|
||||
"archive/zip"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
@@ -10,6 +14,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/anonymize"
|
||||
"github.com/netbirdio/netbird/client/configs"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
@@ -420,6 +425,226 @@ func TestAnonymizeNetworkMap(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsSensitiveEnvVar(t *testing.T) {
|
||||
tests := []struct {
|
||||
key string
|
||||
sensitive bool
|
||||
}{
|
||||
{"NB_SETUP_KEY", true},
|
||||
{"NB_API_TOKEN", true},
|
||||
{"NB_CLIENT_SECRET", true},
|
||||
{"NB_PASSWORD", true},
|
||||
{"NB_CREDENTIAL", true},
|
||||
{"NB_LOG_LEVEL", false},
|
||||
{"NB_MANAGEMENT_URL", false},
|
||||
{"NB_HOSTNAME", false},
|
||||
{"HOME", false},
|
||||
{"PATH", false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.key, func(t *testing.T) {
|
||||
assert.Equal(t, tt.sensitive, isSensitiveEnvVar(tt.key))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeServiceEnvVars(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
anonymize bool
|
||||
input map[string]any
|
||||
check func(t *testing.T, params map[string]any)
|
||||
}{
|
||||
{
|
||||
name: "no env vars key",
|
||||
anonymize: false,
|
||||
input: map[string]any{"management_url": "https://mgmt.example.com"},
|
||||
check: func(t *testing.T, params map[string]any) {
|
||||
t.Helper()
|
||||
assert.Equal(t, "https://mgmt.example.com", params["management_url"], "non-env fields should be untouched")
|
||||
_, ok := params[jsonKeyServiceEnv]
|
||||
assert.False(t, ok, "service_env_vars should not be added")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "non-NB vars are masked",
|
||||
anonymize: false,
|
||||
input: map[string]any{
|
||||
jsonKeyServiceEnv: map[string]any{
|
||||
"HOME": "/root",
|
||||
"PATH": "/usr/bin",
|
||||
"NB_LOG_LEVEL": "debug",
|
||||
},
|
||||
},
|
||||
check: func(t *testing.T, params map[string]any) {
|
||||
t.Helper()
|
||||
env := params[jsonKeyServiceEnv].(map[string]any)
|
||||
assert.Equal(t, maskedValue, env["HOME"], "non-NB_ var should be masked")
|
||||
assert.Equal(t, maskedValue, env["PATH"], "non-NB_ var should be masked")
|
||||
assert.Equal(t, "debug", env["NB_LOG_LEVEL"], "safe NB_ var should pass through")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "sensitive NB vars are masked",
|
||||
anonymize: false,
|
||||
input: map[string]any{
|
||||
jsonKeyServiceEnv: map[string]any{
|
||||
"NB_SETUP_KEY": "abc123",
|
||||
"NB_API_TOKEN": "tok_xyz",
|
||||
"NB_LOG_LEVEL": "info",
|
||||
},
|
||||
},
|
||||
check: func(t *testing.T, params map[string]any) {
|
||||
t.Helper()
|
||||
env := params[jsonKeyServiceEnv].(map[string]any)
|
||||
assert.Equal(t, maskedValue, env["NB_SETUP_KEY"], "sensitive NB_ var should be masked")
|
||||
assert.Equal(t, maskedValue, env["NB_API_TOKEN"], "sensitive NB_ var should be masked")
|
||||
assert.Equal(t, "info", env["NB_LOG_LEVEL"], "safe NB_ var should pass through")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "safe NB vars anonymized when anonymize is true",
|
||||
anonymize: true,
|
||||
input: map[string]any{
|
||||
jsonKeyServiceEnv: map[string]any{
|
||||
"NB_MANAGEMENT_URL": "https://mgmt.example.com:443",
|
||||
"NB_LOG_LEVEL": "debug",
|
||||
"NB_SETUP_KEY": "secret",
|
||||
"SOME_OTHER": "val",
|
||||
},
|
||||
},
|
||||
check: func(t *testing.T, params map[string]any) {
|
||||
t.Helper()
|
||||
env := params[jsonKeyServiceEnv].(map[string]any)
|
||||
// Safe NB_ values should be anonymized (not the original, not masked)
|
||||
mgmtVal := env["NB_MANAGEMENT_URL"].(string)
|
||||
assert.NotEqual(t, "https://mgmt.example.com:443", mgmtVal, "should be anonymized")
|
||||
assert.NotEqual(t, maskedValue, mgmtVal, "should not be masked")
|
||||
|
||||
logVal := env["NB_LOG_LEVEL"].(string)
|
||||
assert.NotEqual(t, maskedValue, logVal, "safe NB_ var should not be masked")
|
||||
|
||||
// Sensitive and non-NB_ still masked
|
||||
assert.Equal(t, maskedValue, env["NB_SETUP_KEY"])
|
||||
assert.Equal(t, maskedValue, env["SOME_OTHER"])
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
|
||||
g := &BundleGenerator{
|
||||
anonymize: tt.anonymize,
|
||||
anonymizer: anonymizer,
|
||||
}
|
||||
g.sanitizeServiceEnvVars(tt.input)
|
||||
tt.check(t, tt.input)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddServiceParams(t *testing.T) {
|
||||
t.Run("missing service.json returns nil", func(t *testing.T) {
|
||||
g := &BundleGenerator{
|
||||
anonymizer: anonymize.NewAnonymizer(anonymize.DefaultAddresses()),
|
||||
}
|
||||
|
||||
origStateDir := configs.StateDir
|
||||
configs.StateDir = t.TempDir()
|
||||
t.Cleanup(func() { configs.StateDir = origStateDir })
|
||||
|
||||
err := g.addServiceParams()
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("management_url anonymized when anonymize is true", func(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
origStateDir := configs.StateDir
|
||||
configs.StateDir = dir
|
||||
t.Cleanup(func() { configs.StateDir = origStateDir })
|
||||
|
||||
input := map[string]any{
|
||||
jsonKeyManagementURL: "https://api.example.com:443",
|
||||
jsonKeyServiceEnv: map[string]any{
|
||||
"NB_LOG_LEVEL": "trace",
|
||||
},
|
||||
}
|
||||
data, err := json.Marshal(input)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, os.WriteFile(filepath.Join(dir, serviceParamsFile), data, 0600))
|
||||
|
||||
var buf bytes.Buffer
|
||||
zw := zip.NewWriter(&buf)
|
||||
|
||||
g := &BundleGenerator{
|
||||
anonymize: true,
|
||||
anonymizer: anonymize.NewAnonymizer(anonymize.DefaultAddresses()),
|
||||
archive: zw,
|
||||
}
|
||||
|
||||
require.NoError(t, g.addServiceParams())
|
||||
require.NoError(t, zw.Close())
|
||||
|
||||
zr, err := zip.NewReader(bytes.NewReader(buf.Bytes()), int64(buf.Len()))
|
||||
require.NoError(t, err)
|
||||
require.Len(t, zr.File, 1)
|
||||
assert.Equal(t, serviceParamsBundle, zr.File[0].Name)
|
||||
|
||||
rc, err := zr.File[0].Open()
|
||||
require.NoError(t, err)
|
||||
defer rc.Close()
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.NewDecoder(rc).Decode(&result))
|
||||
|
||||
mgmt := result[jsonKeyManagementURL].(string)
|
||||
assert.NotEqual(t, "https://api.example.com:443", mgmt, "management_url should be anonymized")
|
||||
assert.NotEmpty(t, mgmt)
|
||||
|
||||
env := result[jsonKeyServiceEnv].(map[string]any)
|
||||
assert.NotEqual(t, maskedValue, env["NB_LOG_LEVEL"], "safe NB_ var should not be masked")
|
||||
})
|
||||
|
||||
t.Run("management_url preserved when anonymize is false", func(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
origStateDir := configs.StateDir
|
||||
configs.StateDir = dir
|
||||
t.Cleanup(func() { configs.StateDir = origStateDir })
|
||||
|
||||
input := map[string]any{
|
||||
jsonKeyManagementURL: "https://api.example.com:443",
|
||||
}
|
||||
data, err := json.Marshal(input)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, os.WriteFile(filepath.Join(dir, serviceParamsFile), data, 0600))
|
||||
|
||||
var buf bytes.Buffer
|
||||
zw := zip.NewWriter(&buf)
|
||||
|
||||
g := &BundleGenerator{
|
||||
anonymize: false,
|
||||
anonymizer: anonymize.NewAnonymizer(anonymize.DefaultAddresses()),
|
||||
archive: zw,
|
||||
}
|
||||
|
||||
require.NoError(t, g.addServiceParams())
|
||||
require.NoError(t, zw.Close())
|
||||
|
||||
zr, err := zip.NewReader(bytes.NewReader(buf.Bytes()), int64(buf.Len()))
|
||||
require.NoError(t, err)
|
||||
|
||||
rc, err := zr.File[0].Open()
|
||||
require.NoError(t, err)
|
||||
defer rc.Close()
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.NewDecoder(rc).Decode(&result))
|
||||
|
||||
assert.Equal(t, "https://api.example.com:443", result[jsonKeyManagementURL], "management_url should be preserved")
|
||||
})
|
||||
}
|
||||
|
||||
// Helper function to check if IP is in CGNAT range
|
||||
func isInCGNATRange(ip net.IP) bool {
|
||||
cgnat := net.IPNet{
|
||||
|
||||
@@ -73,6 +73,9 @@ func (w *ResponseWriterChain) WriteMsg(m *dns.Msg) error {
|
||||
return nil
|
||||
}
|
||||
w.response = m
|
||||
if m.MsgHdr.Truncated {
|
||||
w.SetMeta("truncated", "true")
|
||||
}
|
||||
return w.ResponseWriter.WriteMsg(m)
|
||||
}
|
||||
|
||||
@@ -195,10 +198,14 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
|
||||
startTime := time.Now()
|
||||
requestID := resutil.GenerateRequestID()
|
||||
logger := log.WithFields(log.Fields{
|
||||
fields := log.Fields{
|
||||
"request_id": requestID,
|
||||
"dns_id": fmt.Sprintf("%04x", r.Id),
|
||||
})
|
||||
}
|
||||
if addr := w.RemoteAddr(); addr != nil {
|
||||
fields["client"] = addr.String()
|
||||
}
|
||||
logger := log.WithFields(fields)
|
||||
|
||||
question := r.Question[0]
|
||||
qname := strings.ToLower(question.Name)
|
||||
@@ -261,9 +268,9 @@ func (c *HandlerChain) logResponse(logger *log.Entry, cw *ResponseWriterChain, q
|
||||
meta += " " + k + "=" + v
|
||||
}
|
||||
|
||||
logger.Tracef("response: domain=%s rcode=%s answers=%s%s took=%s",
|
||||
logger.Tracef("response: domain=%s rcode=%s answers=%s size=%dB%s took=%s",
|
||||
qname, dns.RcodeToString[cw.response.Rcode], resutil.FormatAnswers(cw.response.Answer),
|
||||
meta, time.Since(startTime))
|
||||
cw.response.Len(), meta, time.Since(startTime))
|
||||
}
|
||||
|
||||
func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
|
||||
|
||||
@@ -1263,9 +1263,9 @@ func TestLocalResolver_AuthoritativeFlag(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
// TestLocalResolver_Stop tests cleanup on Stop
|
||||
// TestLocalResolver_Stop tests cleanup on GracefullyStop
|
||||
func TestLocalResolver_Stop(t *testing.T) {
|
||||
t.Run("Stop clears all state", func(t *testing.T) {
|
||||
t.Run("GracefullyStop clears all state", func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
resolver.Update([]nbdns.CustomZone{{
|
||||
Domain: "example.com.",
|
||||
@@ -1285,7 +1285,7 @@ func TestLocalResolver_Stop(t *testing.T) {
|
||||
assert.False(t, resolver.isInManagedZone("host.example.com."))
|
||||
})
|
||||
|
||||
t.Run("Stop is safe to call multiple times", func(t *testing.T) {
|
||||
t.Run("GracefullyStop is safe to call multiple times", func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
resolver.Update([]nbdns.CustomZone{{
|
||||
Domain: "example.com.",
|
||||
@@ -1299,7 +1299,7 @@ func TestLocalResolver_Stop(t *testing.T) {
|
||||
resolver.Stop()
|
||||
})
|
||||
|
||||
t.Run("Stop cancels in-flight external resolution", func(t *testing.T) {
|
||||
t.Run("GracefullyStop cancels in-flight external resolution", func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
|
||||
lookupStarted := make(chan struct{})
|
||||
|
||||
@@ -90,6 +90,11 @@ func (m *MockServer) SetRouteChecker(func(netip.Addr) bool) {
|
||||
// Mock implementation - no-op
|
||||
}
|
||||
|
||||
// SetFirewall mock implementation of SetFirewall from Server interface
|
||||
func (m *MockServer) SetFirewall(Firewall) {
|
||||
// Mock implementation - no-op
|
||||
}
|
||||
|
||||
// BeginBatch mock implementation of BeginBatch from Server interface
|
||||
func (m *MockServer) BeginBatch() {
|
||||
// Mock implementation - no-op
|
||||
|
||||
@@ -104,3 +104,23 @@ func (r *responseWriter) TsigTimersOnly(bool) {
|
||||
// After a call to Hijack(), the DNS package will not do anything with the connection.
|
||||
func (r *responseWriter) Hijack() {
|
||||
}
|
||||
|
||||
// remoteAddrFromPacket extracts the source IP:port from a decoded packet for logging.
|
||||
func remoteAddrFromPacket(packet gopacket.Packet) *net.UDPAddr {
|
||||
var srcIP net.IP
|
||||
if ipv4 := packet.Layer(layers.LayerTypeIPv4); ipv4 != nil {
|
||||
srcIP = ipv4.(*layers.IPv4).SrcIP
|
||||
} else if ipv6 := packet.Layer(layers.LayerTypeIPv6); ipv6 != nil {
|
||||
srcIP = ipv6.(*layers.IPv6).SrcIP
|
||||
}
|
||||
|
||||
var srcPort int
|
||||
if udp := packet.Layer(layers.LayerTypeUDP); udp != nil {
|
||||
srcPort = int(udp.(*layers.UDP).SrcPort)
|
||||
}
|
||||
|
||||
if srcIP == nil {
|
||||
return nil
|
||||
}
|
||||
return &net.UDPAddr{IP: srcIP, Port: srcPort}
|
||||
}
|
||||
|
||||
@@ -58,6 +58,7 @@ type Server interface {
|
||||
UpdateServerConfig(domains dnsconfig.ServerDomains) error
|
||||
PopulateManagementDomain(mgmtURL *url.URL) error
|
||||
SetRouteChecker(func(netip.Addr) bool)
|
||||
SetFirewall(Firewall)
|
||||
}
|
||||
|
||||
type nsGroupsByDomain struct {
|
||||
@@ -151,7 +152,7 @@ func NewDefaultServer(ctx context.Context, config DefaultServerConfig) (*Default
|
||||
if config.WgInterface.IsUserspaceBind() {
|
||||
dnsService = NewServiceViaMemory(config.WgInterface)
|
||||
} else {
|
||||
dnsService = newServiceViaListener(config.WgInterface, addrPort)
|
||||
dnsService = newServiceViaListener(config.WgInterface, addrPort, nil)
|
||||
}
|
||||
|
||||
server := newDefaultServer(ctx, config.WgInterface, dnsService, config.StatusRecorder, config.StateManager, config.DisableSys)
|
||||
@@ -186,11 +187,16 @@ func NewDefaultServerIos(
|
||||
ctx context.Context,
|
||||
wgInterface WGIface,
|
||||
iosDnsManager IosDnsManager,
|
||||
hostsDnsList []netip.AddrPort,
|
||||
statusRecorder *peer.Status,
|
||||
disableSys bool,
|
||||
) *DefaultServer {
|
||||
log.Debugf("iOS host dns address list is: %v", hostsDnsList)
|
||||
ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil, disableSys)
|
||||
ds.iosDnsManager = iosDnsManager
|
||||
ds.hostsDNSHolder.set(hostsDnsList)
|
||||
ds.permanent = true
|
||||
ds.addHostRootZone()
|
||||
return ds
|
||||
}
|
||||
|
||||
@@ -374,6 +380,17 @@ func (s *DefaultServer) DnsIP() netip.Addr {
|
||||
return s.service.RuntimeIP()
|
||||
}
|
||||
|
||||
// SetFirewall sets the firewall used for DNS port DNAT rules.
|
||||
// This must be called before Initialize when using the listener-based service,
|
||||
// because the firewall is typically not available at construction time.
|
||||
func (s *DefaultServer) SetFirewall(fw Firewall) {
|
||||
if svc, ok := s.service.(*serviceViaListener); ok {
|
||||
svc.listenerFlagLock.Lock()
|
||||
svc.firewall = fw
|
||||
svc.listenerFlagLock.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// Stop stops the server
|
||||
func (s *DefaultServer) Stop() {
|
||||
s.probeMu.Lock()
|
||||
@@ -395,8 +412,12 @@ func (s *DefaultServer) Stop() {
|
||||
maps.Clear(s.extraDomains)
|
||||
}
|
||||
|
||||
func (s *DefaultServer) disableDNS() error {
|
||||
defer s.service.Stop()
|
||||
func (s *DefaultServer) disableDNS() (retErr error) {
|
||||
defer func() {
|
||||
if err := s.service.Stop(); err != nil {
|
||||
retErr = errors.Join(retErr, fmt.Errorf("stop DNS service: %w", err))
|
||||
}
|
||||
}()
|
||||
|
||||
if s.isUsingNoopHostManager() {
|
||||
return nil
|
||||
|
||||
@@ -476,8 +476,8 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
||||
|
||||
packetfilter := pfmock.NewMockPacketFilter(ctrl)
|
||||
packetfilter.EXPECT().FilterOutbound(gomock.Any(), gomock.Any()).AnyTimes()
|
||||
packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
||||
packetfilter.EXPECT().RemovePacketHook(gomock.Any())
|
||||
packetfilter.EXPECT().SetUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
||||
packetfilter.EXPECT().SetTCPPacketHook(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
||||
|
||||
if err := wgIface.SetFilter(packetfilter); err != nil {
|
||||
t.Errorf("set packet filter: %v", err)
|
||||
@@ -1071,7 +1071,7 @@ func (m *mockHandler) ID() types.HandlerID { return types.Hand
|
||||
type mockService struct{}
|
||||
|
||||
func (m *mockService) Listen() error { return nil }
|
||||
func (m *mockService) Stop() {}
|
||||
func (m *mockService) Stop() error { return nil }
|
||||
func (m *mockService) RuntimeIP() netip.Addr { return netip.MustParseAddr("127.0.0.1") }
|
||||
func (m *mockService) RuntimePort() int { return 53 }
|
||||
func (m *mockService) RegisterMux(string, dns.Handler) {}
|
||||
|
||||
@@ -4,15 +4,25 @@ import (
|
||||
"net/netip"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultPort = 53
|
||||
)
|
||||
|
||||
// Firewall provides DNAT capabilities for DNS port redirection.
|
||||
// This is used when the DNS server cannot bind port 53 directly
|
||||
// and needs firewall rules to redirect traffic.
|
||||
type Firewall interface {
|
||||
AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error
|
||||
RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error
|
||||
}
|
||||
|
||||
type service interface {
|
||||
Listen() error
|
||||
Stop()
|
||||
Stop() error
|
||||
RegisterMux(domain string, handler dns.Handler)
|
||||
DeregisterMux(key string)
|
||||
RuntimePort() int
|
||||
|
||||
@@ -10,9 +10,13 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/ebpf"
|
||||
ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager"
|
||||
)
|
||||
@@ -31,25 +35,33 @@ type serviceViaListener struct {
|
||||
dnsMux *dns.ServeMux
|
||||
customAddr *netip.AddrPort
|
||||
server *dns.Server
|
||||
tcpServer *dns.Server
|
||||
listenIP netip.Addr
|
||||
listenPort uint16
|
||||
listenerIsRunning bool
|
||||
listenerFlagLock sync.Mutex
|
||||
ebpfService ebpfMgr.Manager
|
||||
firewall Firewall
|
||||
tcpDNATConfigured bool
|
||||
}
|
||||
|
||||
func newServiceViaListener(wgIface WGIface, customAddr *netip.AddrPort) *serviceViaListener {
|
||||
func newServiceViaListener(wgIface WGIface, customAddr *netip.AddrPort, fw Firewall) *serviceViaListener {
|
||||
mux := dns.NewServeMux()
|
||||
|
||||
s := &serviceViaListener{
|
||||
wgInterface: wgIface,
|
||||
dnsMux: mux,
|
||||
customAddr: customAddr,
|
||||
firewall: fw,
|
||||
server: &dns.Server{
|
||||
Net: "udp",
|
||||
Handler: mux,
|
||||
UDPSize: 65535,
|
||||
},
|
||||
tcpServer: &dns.Server{
|
||||
Net: "tcp",
|
||||
Handler: mux,
|
||||
},
|
||||
}
|
||||
|
||||
return s
|
||||
@@ -70,43 +82,86 @@ func (s *serviceViaListener) Listen() error {
|
||||
return fmt.Errorf("eval listen address: %w", err)
|
||||
}
|
||||
s.listenIP = s.listenIP.Unmap()
|
||||
s.server.Addr = net.JoinHostPort(s.listenIP.String(), strconv.Itoa(int(s.listenPort)))
|
||||
log.Debugf("starting dns on %s", s.server.Addr)
|
||||
go func() {
|
||||
s.setListenerStatus(true)
|
||||
defer s.setListenerStatus(false)
|
||||
addr := net.JoinHostPort(s.listenIP.String(), strconv.Itoa(int(s.listenPort)))
|
||||
s.server.Addr = addr
|
||||
s.tcpServer.Addr = addr
|
||||
|
||||
err := s.server.ListenAndServe()
|
||||
if err != nil {
|
||||
log.Errorf("dns server running with %d port returned an error: %v. Will not retry", s.listenPort, err)
|
||||
log.Debugf("starting dns on %s (UDP + TCP)", addr)
|
||||
s.listenerIsRunning = true
|
||||
|
||||
go func() {
|
||||
if err := s.server.ListenAndServe(); err != nil {
|
||||
log.Errorf("failed to run DNS UDP server on port %d: %v", s.listenPort, err)
|
||||
}
|
||||
|
||||
s.listenerFlagLock.Lock()
|
||||
unexpected := s.listenerIsRunning
|
||||
s.listenerIsRunning = false
|
||||
s.listenerFlagLock.Unlock()
|
||||
|
||||
if unexpected {
|
||||
if err := s.tcpServer.Shutdown(); err != nil {
|
||||
log.Debugf("failed to shutdown DNS TCP server: %v", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
if err := s.tcpServer.ListenAndServe(); err != nil {
|
||||
log.Errorf("failed to run DNS TCP server on port %d: %v", s.listenPort, err)
|
||||
}
|
||||
}()
|
||||
|
||||
// When eBPF redirects UDP port 53 to our listen port, TCP still needs
|
||||
// a DNAT rule because eBPF only handles UDP.
|
||||
if s.ebpfService != nil && s.firewall != nil && s.listenPort != DefaultPort {
|
||||
if err := s.firewall.AddOutputDNAT(s.listenIP, firewall.ProtocolTCP, DefaultPort, s.listenPort); err != nil {
|
||||
log.Warnf("failed to add DNS TCP DNAT rule, TCP DNS on port 53 will not work: %v", err)
|
||||
} else {
|
||||
s.tcpDNATConfigured = true
|
||||
log.Infof("added DNS TCP DNAT rule: %s:%d -> %s:%d", s.listenIP, DefaultPort, s.listenIP, s.listenPort)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *serviceViaListener) Stop() {
|
||||
func (s *serviceViaListener) Stop() error {
|
||||
s.listenerFlagLock.Lock()
|
||||
defer s.listenerFlagLock.Unlock()
|
||||
|
||||
if !s.listenerIsRunning {
|
||||
return
|
||||
return nil
|
||||
}
|
||||
s.listenerIsRunning = false
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := s.server.ShutdownContext(ctx)
|
||||
if err != nil {
|
||||
log.Errorf("stopping dns server listener returned an error: %v", err)
|
||||
var merr *multierror.Error
|
||||
|
||||
if err := s.server.ShutdownContext(ctx); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("stop DNS UDP server: %w", err))
|
||||
}
|
||||
|
||||
if err := s.tcpServer.ShutdownContext(ctx); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("stop DNS TCP server: %w", err))
|
||||
}
|
||||
|
||||
if s.tcpDNATConfigured && s.firewall != nil {
|
||||
if err := s.firewall.RemoveOutputDNAT(s.listenIP, firewall.ProtocolTCP, DefaultPort, s.listenPort); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove DNS TCP DNAT rule: %w", err))
|
||||
}
|
||||
s.tcpDNATConfigured = false
|
||||
}
|
||||
|
||||
if s.ebpfService != nil {
|
||||
err = s.ebpfService.FreeDNSFwd()
|
||||
if err != nil {
|
||||
log.Errorf("stopping traffic forwarder returned an error: %v", err)
|
||||
if err := s.ebpfService.FreeDNSFwd(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("stop traffic forwarder: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (s *serviceViaListener) RegisterMux(pattern string, handler dns.Handler) {
|
||||
@@ -133,12 +188,6 @@ func (s *serviceViaListener) RuntimeIP() netip.Addr {
|
||||
return s.listenIP
|
||||
}
|
||||
|
||||
func (s *serviceViaListener) setListenerStatus(running bool) {
|
||||
s.listenerFlagLock.Lock()
|
||||
defer s.listenerFlagLock.Unlock()
|
||||
|
||||
s.listenerIsRunning = running
|
||||
}
|
||||
|
||||
// evalListenAddress figure out the listen address for the DNS server
|
||||
// first check the 53 port availability on WG interface or lo, if not success
|
||||
@@ -187,18 +236,28 @@ func (s *serviceViaListener) testFreePort(port int) (netip.Addr, bool) {
|
||||
}
|
||||
|
||||
func (s *serviceViaListener) tryToBind(ip netip.Addr, port int) bool {
|
||||
addrString := net.JoinHostPort(ip.String(), strconv.Itoa(port))
|
||||
udpAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort(addrString))
|
||||
probeListener, err := net.ListenUDP("udp", udpAddr)
|
||||
addrPort := netip.AddrPortFrom(ip, uint16(port))
|
||||
|
||||
udpAddr := net.UDPAddrFromAddrPort(addrPort)
|
||||
udpLn, err := net.ListenUDP("udp", udpAddr)
|
||||
if err != nil {
|
||||
log.Warnf("binding dns on %s is not available, error: %s", addrString, err)
|
||||
log.Warnf("binding dns UDP on %s is not available: %s", addrPort, err)
|
||||
return false
|
||||
}
|
||||
|
||||
err = probeListener.Close()
|
||||
if err != nil {
|
||||
log.Errorf("got an error closing the probe listener, error: %s", err)
|
||||
if err := udpLn.Close(); err != nil {
|
||||
log.Debugf("close UDP probe listener: %s", err)
|
||||
}
|
||||
|
||||
tcpAddr := net.TCPAddrFromAddrPort(addrPort)
|
||||
tcpLn, err := net.ListenTCP("tcp", tcpAddr)
|
||||
if err != nil {
|
||||
log.Warnf("binding dns TCP on %s is not available: %s", addrPort, err)
|
||||
return false
|
||||
}
|
||||
if err := tcpLn.Close(); err != nil {
|
||||
log.Debugf("close TCP probe listener: %s", err)
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
|
||||
86
client/internal/dns/service_listener_test.go
Normal file
86
client/internal/dns/service_listener_test.go
Normal file
@@ -0,0 +1,86 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestServiceViaListener_TCPAndUDP(t *testing.T) {
|
||||
handler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
m.Answer = append(m.Answer, &dns.A{
|
||||
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP("192.0.2.1"),
|
||||
})
|
||||
if err := w.WriteMsg(m); err != nil {
|
||||
t.Logf("write msg: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
// Create a service using a custom address to avoid needing root
|
||||
svc := newServiceViaListener(nil, nil, nil)
|
||||
svc.dnsMux.Handle(".", handler)
|
||||
|
||||
// Bind both transports up front to avoid TOCTOU races.
|
||||
udpAddr := net.UDPAddrFromAddrPort(netip.AddrPortFrom(customIP, 0))
|
||||
udpConn, err := net.ListenUDP("udp", udpAddr)
|
||||
if err != nil {
|
||||
t.Skip("cannot bind to 127.0.0.153, skipping")
|
||||
}
|
||||
port := uint16(udpConn.LocalAddr().(*net.UDPAddr).Port)
|
||||
|
||||
tcpAddr := net.TCPAddrFromAddrPort(netip.AddrPortFrom(customIP, port))
|
||||
tcpLn, err := net.ListenTCP("tcp", tcpAddr)
|
||||
if err != nil {
|
||||
udpConn.Close()
|
||||
t.Skip("cannot bind TCP on same port, skipping")
|
||||
}
|
||||
|
||||
addr := fmt.Sprintf("%s:%d", customIP, port)
|
||||
svc.server.PacketConn = udpConn
|
||||
svc.tcpServer.Listener = tcpLn
|
||||
svc.listenIP = customIP
|
||||
svc.listenPort = port
|
||||
|
||||
go func() {
|
||||
if err := svc.server.ActivateAndServe(); err != nil {
|
||||
t.Logf("udp server: %v", err)
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
if err := svc.tcpServer.ActivateAndServe(); err != nil {
|
||||
t.Logf("tcp server: %v", err)
|
||||
}
|
||||
}()
|
||||
svc.listenerIsRunning = true
|
||||
|
||||
defer func() {
|
||||
require.NoError(t, svc.Stop())
|
||||
}()
|
||||
|
||||
q := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
|
||||
|
||||
// Test UDP query
|
||||
udpClient := &dns.Client{Net: "udp", Timeout: 2 * time.Second}
|
||||
udpResp, _, err := udpClient.Exchange(q, addr)
|
||||
require.NoError(t, err, "UDP query should succeed")
|
||||
require.NotNil(t, udpResp)
|
||||
require.NotEmpty(t, udpResp.Answer)
|
||||
assert.Contains(t, udpResp.Answer[0].String(), "192.0.2.1", "UDP response should contain expected IP")
|
||||
|
||||
// Test TCP query
|
||||
tcpClient := &dns.Client{Net: "tcp", Timeout: 2 * time.Second}
|
||||
tcpResp, _, err := tcpClient.Exchange(q, addr)
|
||||
require.NoError(t, err, "TCP query should succeed")
|
||||
require.NotNil(t, tcpResp)
|
||||
require.NotEmpty(t, tcpResp.Answer)
|
||||
assert.Contains(t, tcpResp.Answer[0].String(), "192.0.2.1", "TCP response should contain expected IP")
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"sync"
|
||||
@@ -10,6 +11,7 @@ import (
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
@@ -18,7 +20,8 @@ type ServiceViaMemory struct {
|
||||
dnsMux *dns.ServeMux
|
||||
runtimeIP netip.Addr
|
||||
runtimePort int
|
||||
udpFilterHookID string
|
||||
tcpDNS *tcpDNSServer
|
||||
tcpHookSet bool
|
||||
listenerIsRunning bool
|
||||
listenerFlagLock sync.Mutex
|
||||
}
|
||||
@@ -28,14 +31,13 @@ func NewServiceViaMemory(wgIface WGIface) *ServiceViaMemory {
|
||||
if err != nil {
|
||||
log.Errorf("get last ip from network: %v", err)
|
||||
}
|
||||
s := &ServiceViaMemory{
|
||||
|
||||
return &ServiceViaMemory{
|
||||
wgInterface: wgIface,
|
||||
dnsMux: dns.NewServeMux(),
|
||||
|
||||
runtimeIP: lastIP,
|
||||
runtimePort: DefaultPort,
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *ServiceViaMemory) Listen() error {
|
||||
@@ -46,10 +48,8 @@ func (s *ServiceViaMemory) Listen() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
var err error
|
||||
s.udpFilterHookID, err = s.filterDNSTraffic()
|
||||
if err != nil {
|
||||
return fmt.Errorf("filter dns traffice: %w", err)
|
||||
if err := s.filterDNSTraffic(); err != nil {
|
||||
return fmt.Errorf("filter dns traffic: %w", err)
|
||||
}
|
||||
s.listenerIsRunning = true
|
||||
|
||||
@@ -57,19 +57,29 @@ func (s *ServiceViaMemory) Listen() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ServiceViaMemory) Stop() {
|
||||
func (s *ServiceViaMemory) Stop() error {
|
||||
s.listenerFlagLock.Lock()
|
||||
defer s.listenerFlagLock.Unlock()
|
||||
|
||||
if !s.listenerIsRunning {
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := s.wgInterface.GetFilter().RemovePacketHook(s.udpFilterHookID); err != nil {
|
||||
log.Errorf("unable to remove DNS packet hook: %s", err)
|
||||
filter := s.wgInterface.GetFilter()
|
||||
if filter != nil {
|
||||
filter.SetUDPPacketHook(s.runtimeIP, uint16(s.runtimePort), nil)
|
||||
if s.tcpHookSet {
|
||||
filter.SetTCPPacketHook(s.runtimeIP, uint16(s.runtimePort), nil)
|
||||
}
|
||||
}
|
||||
|
||||
if s.tcpDNS != nil {
|
||||
s.tcpDNS.Stop()
|
||||
}
|
||||
|
||||
s.listenerIsRunning = false
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ServiceViaMemory) RegisterMux(pattern string, handler dns.Handler) {
|
||||
@@ -88,10 +98,18 @@ func (s *ServiceViaMemory) RuntimeIP() netip.Addr {
|
||||
return s.runtimeIP
|
||||
}
|
||||
|
||||
func (s *ServiceViaMemory) filterDNSTraffic() (string, error) {
|
||||
func (s *ServiceViaMemory) filterDNSTraffic() error {
|
||||
filter := s.wgInterface.GetFilter()
|
||||
if filter == nil {
|
||||
return "", fmt.Errorf("can't set DNS filter, filter not initialized")
|
||||
return errors.New("DNS filter not initialized")
|
||||
}
|
||||
|
||||
// Create TCP DNS server lazily here since the device may not exist at construction time.
|
||||
if s.tcpDNS == nil {
|
||||
if dev := s.wgInterface.GetDevice(); dev != nil {
|
||||
// MTU only affects TCP segment sizing; DNS messages are small so this has no practical impact.
|
||||
s.tcpDNS = newTCPDNSServer(s.dnsMux, dev.Device, s.runtimeIP, uint16(s.runtimePort), iface.DefaultMTU)
|
||||
}
|
||||
}
|
||||
|
||||
firstLayerDecoder := layers.LayerTypeIPv4
|
||||
@@ -100,12 +118,16 @@ func (s *ServiceViaMemory) filterDNSTraffic() (string, error) {
|
||||
}
|
||||
|
||||
hook := func(packetData []byte) bool {
|
||||
// Decode the packet
|
||||
packet := gopacket.NewPacket(packetData, firstLayerDecoder, gopacket.Default)
|
||||
|
||||
// Get the UDP layer
|
||||
udpLayer := packet.Layer(layers.LayerTypeUDP)
|
||||
udp := udpLayer.(*layers.UDP)
|
||||
if udpLayer == nil {
|
||||
return true
|
||||
}
|
||||
udp, ok := udpLayer.(*layers.UDP)
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
|
||||
msg := new(dns.Msg)
|
||||
if err := msg.Unpack(udp.Payload); err != nil {
|
||||
@@ -113,13 +135,30 @@ func (s *ServiceViaMemory) filterDNSTraffic() (string, error) {
|
||||
return true
|
||||
}
|
||||
|
||||
writer := responseWriter{
|
||||
packet: packet,
|
||||
device: s.wgInterface.GetDevice().Device,
|
||||
dev := s.wgInterface.GetDevice()
|
||||
if dev == nil {
|
||||
return true
|
||||
}
|
||||
go s.dnsMux.ServeDNS(&writer, msg)
|
||||
|
||||
writer := &responseWriter{
|
||||
remote: remoteAddrFromPacket(packet),
|
||||
packet: packet,
|
||||
device: dev.Device,
|
||||
}
|
||||
go s.dnsMux.ServeDNS(writer, msg)
|
||||
return true
|
||||
}
|
||||
|
||||
return filter.AddUDPPacketHook(false, s.runtimeIP, uint16(s.runtimePort), hook), nil
|
||||
filter.SetUDPPacketHook(s.runtimeIP, uint16(s.runtimePort), hook)
|
||||
|
||||
if s.tcpDNS != nil {
|
||||
tcpHook := func(packetData []byte) bool {
|
||||
s.tcpDNS.InjectPacket(packetData)
|
||||
return true
|
||||
}
|
||||
filter.SetTCPPacketHook(s.runtimeIP, uint16(s.runtimePort), tcpHook)
|
||||
s.tcpHookSet = true
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
444
client/internal/dns/tcpstack.go
Normal file
444
client/internal/dns/tcpstack.go
Normal file
@@ -0,0 +1,444 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
"gvisor.dev/gvisor/pkg/buffer"
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
||||
"gvisor.dev/gvisor/pkg/waiter"
|
||||
)
|
||||
|
||||
const (
|
||||
dnsTCPReceiveWindow = 8192
|
||||
dnsTCPMaxInFlight = 16
|
||||
dnsTCPIdleTimeout = 30 * time.Second
|
||||
dnsTCPReadTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
// tcpDNSServer is an on-demand TCP DNS server backed by a minimal gvisor stack.
|
||||
// It is started lazily when a truncated DNS response is detected and shuts down
|
||||
// after a period of inactivity to conserve resources.
|
||||
type tcpDNSServer struct {
|
||||
mu sync.Mutex
|
||||
s *stack.Stack
|
||||
ep *dnsEndpoint
|
||||
mux *dns.ServeMux
|
||||
tunDev tun.Device
|
||||
ip netip.Addr
|
||||
port uint16
|
||||
mtu uint16
|
||||
|
||||
running bool
|
||||
closed bool
|
||||
timerID uint64
|
||||
timer *time.Timer
|
||||
}
|
||||
|
||||
func newTCPDNSServer(mux *dns.ServeMux, tunDev tun.Device, ip netip.Addr, port uint16, mtu uint16) *tcpDNSServer {
|
||||
return &tcpDNSServer{
|
||||
mux: mux,
|
||||
tunDev: tunDev,
|
||||
ip: ip,
|
||||
port: port,
|
||||
mtu: mtu,
|
||||
}
|
||||
}
|
||||
|
||||
// InjectPacket ensures the stack is running and delivers a raw IP packet into
|
||||
// the gvisor stack for TCP processing. Combining both operations under a single
|
||||
// lock prevents a race where the idle timer could stop the stack between
|
||||
// start and delivery.
|
||||
func (t *tcpDNSServer) InjectPacket(payload []byte) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
if t.closed {
|
||||
return
|
||||
}
|
||||
|
||||
if !t.running {
|
||||
if err := t.startLocked(); err != nil {
|
||||
log.Errorf("failed to start TCP DNS stack: %v", err)
|
||||
return
|
||||
}
|
||||
t.running = true
|
||||
log.Debugf("TCP DNS stack started on %s:%d (triggered by %s)", t.ip, t.port, srcAddrFromPacket(payload))
|
||||
}
|
||||
t.resetTimerLocked()
|
||||
|
||||
ep := t.ep
|
||||
if ep == nil || ep.dispatcher == nil {
|
||||
return
|
||||
}
|
||||
|
||||
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
||||
Payload: buffer.MakeWithData(payload),
|
||||
})
|
||||
// DeliverNetworkPacket takes ownership of the packet buffer; do not DecRef.
|
||||
ep.dispatcher.DeliverNetworkPacket(ipv4.ProtocolNumber, pkt)
|
||||
}
|
||||
|
||||
// Stop tears down the gvisor stack and releases resources permanently.
|
||||
// After Stop, InjectPacket becomes a no-op.
|
||||
func (t *tcpDNSServer) Stop() {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
t.stopLocked()
|
||||
t.closed = true
|
||||
}
|
||||
|
||||
func (t *tcpDNSServer) startLocked() error {
|
||||
// TODO: add ipv6.NewProtocol when IPv6 overlay support lands.
|
||||
s := stack.New(stack.Options{
|
||||
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
|
||||
TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
|
||||
HandleLocal: false,
|
||||
})
|
||||
|
||||
nicID := tcpip.NICID(1)
|
||||
ep := &dnsEndpoint{
|
||||
tunDev: t.tunDev,
|
||||
}
|
||||
ep.mtu.Store(uint32(t.mtu))
|
||||
|
||||
if err := s.CreateNIC(nicID, ep); err != nil {
|
||||
s.Close()
|
||||
s.Wait()
|
||||
return fmt.Errorf("create NIC: %v", err)
|
||||
}
|
||||
|
||||
protoAddr := tcpip.ProtocolAddress{
|
||||
Protocol: ipv4.ProtocolNumber,
|
||||
AddressWithPrefix: tcpip.AddressWithPrefix{
|
||||
Address: tcpip.AddrFromSlice(t.ip.AsSlice()),
|
||||
PrefixLen: 32,
|
||||
},
|
||||
}
|
||||
if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil {
|
||||
s.Close()
|
||||
s.Wait()
|
||||
return fmt.Errorf("add protocol address: %s", err)
|
||||
}
|
||||
|
||||
if err := s.SetPromiscuousMode(nicID, true); err != nil {
|
||||
s.Close()
|
||||
s.Wait()
|
||||
return fmt.Errorf("set promiscuous mode: %s", err)
|
||||
}
|
||||
if err := s.SetSpoofing(nicID, true); err != nil {
|
||||
s.Close()
|
||||
s.Wait()
|
||||
return fmt.Errorf("set spoofing: %s", err)
|
||||
}
|
||||
|
||||
defaultSubnet, err := tcpip.NewSubnet(
|
||||
tcpip.AddrFrom4([4]byte{0, 0, 0, 0}),
|
||||
tcpip.MaskFromBytes([]byte{0, 0, 0, 0}),
|
||||
)
|
||||
if err != nil {
|
||||
s.Close()
|
||||
s.Wait()
|
||||
return fmt.Errorf("create default subnet: %w", err)
|
||||
}
|
||||
|
||||
s.SetRouteTable([]tcpip.Route{
|
||||
{Destination: defaultSubnet, NIC: nicID},
|
||||
})
|
||||
|
||||
tcpFwd := tcp.NewForwarder(s, dnsTCPReceiveWindow, dnsTCPMaxInFlight, func(r *tcp.ForwarderRequest) {
|
||||
t.handleTCPDNS(r)
|
||||
})
|
||||
s.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpFwd.HandlePacket)
|
||||
|
||||
t.s = s
|
||||
t.ep = ep
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tcpDNSServer) stopLocked() {
|
||||
if !t.running {
|
||||
return
|
||||
}
|
||||
|
||||
if t.timer != nil {
|
||||
t.timer.Stop()
|
||||
t.timer = nil
|
||||
}
|
||||
|
||||
if t.s != nil {
|
||||
t.s.Close()
|
||||
t.s.Wait()
|
||||
t.s = nil
|
||||
}
|
||||
t.ep = nil
|
||||
t.running = false
|
||||
|
||||
log.Debugf("TCP DNS stack stopped")
|
||||
}
|
||||
|
||||
func (t *tcpDNSServer) resetTimerLocked() {
|
||||
if t.timer != nil {
|
||||
t.timer.Stop()
|
||||
}
|
||||
t.timerID++
|
||||
id := t.timerID
|
||||
t.timer = time.AfterFunc(dnsTCPIdleTimeout, func() {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
// Only stop if this timer is still the active one.
|
||||
// A racing InjectPacket may have replaced it.
|
||||
if t.timerID != id {
|
||||
return
|
||||
}
|
||||
t.stopLocked()
|
||||
})
|
||||
}
|
||||
|
||||
func (t *tcpDNSServer) handleTCPDNS(r *tcp.ForwarderRequest) {
|
||||
id := r.ID()
|
||||
|
||||
wq := waiter.Queue{}
|
||||
ep, epErr := r.CreateEndpoint(&wq)
|
||||
if epErr != nil {
|
||||
log.Debugf("TCP DNS: failed to create endpoint: %v", epErr)
|
||||
r.Complete(true)
|
||||
return
|
||||
}
|
||||
r.Complete(false)
|
||||
|
||||
conn := gonet.NewTCPConn(&wq, ep)
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Tracef("TCP DNS: close conn: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Reset idle timer on activity
|
||||
t.mu.Lock()
|
||||
t.resetTimerLocked()
|
||||
t.mu.Unlock()
|
||||
|
||||
localAddr := &net.TCPAddr{
|
||||
IP: id.LocalAddress.AsSlice(),
|
||||
Port: int(id.LocalPort),
|
||||
}
|
||||
remoteAddr := &net.TCPAddr{
|
||||
IP: id.RemoteAddress.AsSlice(),
|
||||
Port: int(id.RemotePort),
|
||||
}
|
||||
|
||||
for {
|
||||
if err := conn.SetReadDeadline(time.Now().Add(dnsTCPReadTimeout)); err != nil {
|
||||
log.Debugf("TCP DNS: set deadline for %s: %v", remoteAddr, err)
|
||||
break
|
||||
}
|
||||
|
||||
msg, err := readTCPDNSMessage(conn)
|
||||
if err != nil {
|
||||
if !errors.Is(err, io.EOF) && !errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
log.Debugf("TCP DNS: read from %s: %v", remoteAddr, err)
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
writer := &tcpResponseWriter{
|
||||
conn: conn,
|
||||
localAddr: localAddr,
|
||||
remoteAddr: remoteAddr,
|
||||
}
|
||||
t.mux.ServeDNS(writer, msg)
|
||||
}
|
||||
}
|
||||
|
||||
// dnsEndpoint implements stack.LinkEndpoint for writing packets back via the tun device.
|
||||
type dnsEndpoint struct {
|
||||
dispatcher stack.NetworkDispatcher
|
||||
tunDev tun.Device
|
||||
mtu atomic.Uint32
|
||||
}
|
||||
|
||||
func (e *dnsEndpoint) Attach(dispatcher stack.NetworkDispatcher) { e.dispatcher = dispatcher }
|
||||
func (e *dnsEndpoint) IsAttached() bool { return e.dispatcher != nil }
|
||||
func (e *dnsEndpoint) MTU() uint32 { return e.mtu.Load() }
|
||||
func (e *dnsEndpoint) Capabilities() stack.LinkEndpointCapabilities { return stack.CapabilityNone }
|
||||
func (e *dnsEndpoint) MaxHeaderLength() uint16 { return 0 }
|
||||
func (e *dnsEndpoint) LinkAddress() tcpip.LinkAddress { return "" }
|
||||
func (e *dnsEndpoint) Wait() { /* no async work */ }
|
||||
func (e *dnsEndpoint) ARPHardwareType() header.ARPHardwareType { return header.ARPHardwareNone }
|
||||
func (e *dnsEndpoint) AddHeader(*stack.PacketBuffer) { /* IP-level endpoint, no link header */ }
|
||||
func (e *dnsEndpoint) ParseHeader(*stack.PacketBuffer) bool { return true }
|
||||
func (e *dnsEndpoint) Close() { /* lifecycle managed by tcpDNSServer */ }
|
||||
func (e *dnsEndpoint) SetLinkAddress(tcpip.LinkAddress) { /* no link address for tun */ }
|
||||
func (e *dnsEndpoint) SetMTU(mtu uint32) { e.mtu.Store(mtu) }
|
||||
func (e *dnsEndpoint) SetOnCloseAction(func()) { /* not needed */ }
|
||||
|
||||
const tunPacketOffset = 40
|
||||
|
||||
func (e *dnsEndpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) {
|
||||
var written int
|
||||
for _, pkt := range pkts.AsSlice() {
|
||||
data := stack.PayloadSince(pkt.NetworkHeader())
|
||||
if data == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
raw := data.AsSlice()
|
||||
buf := make([]byte, tunPacketOffset, tunPacketOffset+len(raw))
|
||||
buf = append(buf, raw...)
|
||||
data.Release()
|
||||
|
||||
if _, err := e.tunDev.Write([][]byte{buf}, tunPacketOffset); err != nil {
|
||||
log.Tracef("TCP DNS endpoint: failed to write packet: %v", err)
|
||||
continue
|
||||
}
|
||||
written++
|
||||
}
|
||||
return written, nil
|
||||
}
|
||||
|
||||
// tcpResponseWriter implements dns.ResponseWriter for TCP DNS connections.
|
||||
type tcpResponseWriter struct {
|
||||
conn *gonet.TCPConn
|
||||
localAddr net.Addr
|
||||
remoteAddr net.Addr
|
||||
}
|
||||
|
||||
func (w *tcpResponseWriter) LocalAddr() net.Addr {
|
||||
return w.localAddr
|
||||
}
|
||||
|
||||
func (w *tcpResponseWriter) RemoteAddr() net.Addr {
|
||||
return w.remoteAddr
|
||||
}
|
||||
|
||||
func (w *tcpResponseWriter) WriteMsg(msg *dns.Msg) error {
|
||||
data, err := msg.Pack()
|
||||
if err != nil {
|
||||
return fmt.Errorf("pack: %w", err)
|
||||
}
|
||||
|
||||
// DNS TCP: 2-byte length prefix + message
|
||||
buf := make([]byte, 2+len(data))
|
||||
buf[0] = byte(len(data) >> 8)
|
||||
buf[1] = byte(len(data))
|
||||
copy(buf[2:], data)
|
||||
|
||||
if _, err = w.conn.Write(buf); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *tcpResponseWriter) Write(data []byte) (int, error) {
|
||||
buf := make([]byte, 2+len(data))
|
||||
buf[0] = byte(len(data) >> 8)
|
||||
buf[1] = byte(len(data))
|
||||
copy(buf[2:], data)
|
||||
if _, err := w.conn.Write(buf); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
func (w *tcpResponseWriter) Close() error {
|
||||
return w.conn.Close()
|
||||
}
|
||||
|
||||
func (w *tcpResponseWriter) TsigStatus() error { return nil }
|
||||
func (w *tcpResponseWriter) TsigTimersOnly(bool) { /* TSIG not supported */ }
|
||||
func (w *tcpResponseWriter) Hijack() { /* not supported */ }
|
||||
|
||||
// readTCPDNSMessage reads a single DNS message from a TCP connection (length-prefixed).
|
||||
func readTCPDNSMessage(conn *gonet.TCPConn) (*dns.Msg, error) {
|
||||
// DNS over TCP uses a 2-byte length prefix
|
||||
lenBuf := make([]byte, 2)
|
||||
if _, err := io.ReadFull(conn, lenBuf); err != nil {
|
||||
return nil, fmt.Errorf("read length: %w", err)
|
||||
}
|
||||
|
||||
msgLen := int(lenBuf[0])<<8 | int(lenBuf[1])
|
||||
if msgLen == 0 || msgLen > 65535 {
|
||||
return nil, fmt.Errorf("invalid message length: %d", msgLen)
|
||||
}
|
||||
|
||||
msgBuf := make([]byte, msgLen)
|
||||
if _, err := io.ReadFull(conn, msgBuf); err != nil {
|
||||
return nil, fmt.Errorf("read message: %w", err)
|
||||
}
|
||||
|
||||
msg := new(dns.Msg)
|
||||
if err := msg.Unpack(msgBuf); err != nil {
|
||||
return nil, fmt.Errorf("unpack: %w", err)
|
||||
}
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// srcAddrFromPacket extracts the source IP:port from a raw IP+TCP packet for logging.
|
||||
// Supports both IPv4 and IPv6.
|
||||
func srcAddrFromPacket(pkt []byte) netip.AddrPort {
|
||||
if len(pkt) == 0 {
|
||||
return netip.AddrPort{}
|
||||
}
|
||||
|
||||
srcIP, transportOffset := srcIPFromPacket(pkt)
|
||||
if !srcIP.IsValid() || len(pkt) < transportOffset+2 {
|
||||
return netip.AddrPort{}
|
||||
}
|
||||
|
||||
srcPort := uint16(pkt[transportOffset])<<8 | uint16(pkt[transportOffset+1])
|
||||
return netip.AddrPortFrom(srcIP.Unmap(), srcPort)
|
||||
}
|
||||
|
||||
func srcIPFromPacket(pkt []byte) (netip.Addr, int) {
|
||||
switch header.IPVersion(pkt) {
|
||||
case 4:
|
||||
return srcIPv4(pkt)
|
||||
case 6:
|
||||
return srcIPv6(pkt)
|
||||
default:
|
||||
return netip.Addr{}, 0
|
||||
}
|
||||
}
|
||||
|
||||
func srcIPv4(pkt []byte) (netip.Addr, int) {
|
||||
if len(pkt) < header.IPv4MinimumSize {
|
||||
return netip.Addr{}, 0
|
||||
}
|
||||
hdr := header.IPv4(pkt)
|
||||
src := hdr.SourceAddress()
|
||||
ip, ok := netip.AddrFromSlice(src.AsSlice())
|
||||
if !ok {
|
||||
return netip.Addr{}, 0
|
||||
}
|
||||
return ip, int(hdr.HeaderLength())
|
||||
}
|
||||
|
||||
func srcIPv6(pkt []byte) (netip.Addr, int) {
|
||||
if len(pkt) < header.IPv6MinimumSize {
|
||||
return netip.Addr{}, 0
|
||||
}
|
||||
hdr := header.IPv6(pkt)
|
||||
src := hdr.SourceAddress()
|
||||
ip, ok := netip.AddrFromSlice(src.AsSlice())
|
||||
if !ok {
|
||||
return netip.Addr{}, 0
|
||||
}
|
||||
return ip, header.IPv6MinimumSize
|
||||
}
|
||||
@@ -41,10 +41,61 @@ const (
|
||||
|
||||
reactivatePeriod = 30 * time.Second
|
||||
probeTimeout = 2 * time.Second
|
||||
|
||||
// ipv6HeaderSize + udpHeaderSize, used to derive the maximum DNS UDP
|
||||
// payload from the tunnel MTU.
|
||||
ipUDPHeaderSize = 60 + 8
|
||||
)
|
||||
|
||||
const testRecord = "com."
|
||||
|
||||
const (
|
||||
protoUDP = "udp"
|
||||
protoTCP = "tcp"
|
||||
)
|
||||
|
||||
type dnsProtocolKey struct{}
|
||||
|
||||
// contextWithDNSProtocol stores the inbound DNS protocol ("udp" or "tcp") in context.
|
||||
func contextWithDNSProtocol(ctx context.Context, network string) context.Context {
|
||||
return context.WithValue(ctx, dnsProtocolKey{}, network)
|
||||
}
|
||||
|
||||
// dnsProtocolFromContext retrieves the inbound DNS protocol from context.
|
||||
func dnsProtocolFromContext(ctx context.Context) string {
|
||||
if ctx == nil {
|
||||
return ""
|
||||
}
|
||||
if v, ok := ctx.Value(dnsProtocolKey{}).(string); ok {
|
||||
return v
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
type upstreamProtocolKey struct{}
|
||||
|
||||
// upstreamProtocolResult holds the protocol used for the upstream exchange.
|
||||
// Stored as a pointer in context so the exchange function can set it.
|
||||
type upstreamProtocolResult struct {
|
||||
protocol string
|
||||
}
|
||||
|
||||
// contextWithupstreamProtocolResult stores a mutable result holder in the context.
|
||||
func contextWithupstreamProtocolResult(ctx context.Context) (context.Context, *upstreamProtocolResult) {
|
||||
r := &upstreamProtocolResult{}
|
||||
return context.WithValue(ctx, upstreamProtocolKey{}, r), r
|
||||
}
|
||||
|
||||
// setUpstreamProtocol sets the upstream protocol on the result holder in context, if present.
|
||||
func setUpstreamProtocol(ctx context.Context, protocol string) {
|
||||
if ctx == nil {
|
||||
return
|
||||
}
|
||||
if r, ok := ctx.Value(upstreamProtocolKey{}).(*upstreamProtocolResult); ok && r != nil {
|
||||
r.protocol = protocol
|
||||
}
|
||||
}
|
||||
|
||||
type upstreamClient interface {
|
||||
exchange(ctx context.Context, upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error)
|
||||
}
|
||||
@@ -138,7 +189,16 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
return
|
||||
}
|
||||
|
||||
ok, failures := u.tryUpstreamServers(w, r, logger)
|
||||
// Propagate inbound protocol so upstream exchange can use TCP directly
|
||||
// when the request came in over TCP.
|
||||
ctx := u.ctx
|
||||
if addr := w.RemoteAddr(); addr != nil {
|
||||
network := addr.Network()
|
||||
ctx = contextWithDNSProtocol(ctx, network)
|
||||
resutil.SetMeta(w, "protocol", network)
|
||||
}
|
||||
|
||||
ok, failures := u.tryUpstreamServers(ctx, w, r, logger)
|
||||
if len(failures) > 0 {
|
||||
u.logUpstreamFailures(r.Question[0].Name, failures, ok, logger)
|
||||
}
|
||||
@@ -153,7 +213,7 @@ func (u *upstreamResolverBase) prepareRequest(r *dns.Msg) {
|
||||
}
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) (bool, []upstreamFailure) {
|
||||
func (u *upstreamResolverBase) tryUpstreamServers(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) (bool, []upstreamFailure) {
|
||||
timeout := u.upstreamTimeout
|
||||
if len(u.upstreamServers) > 1 {
|
||||
maxTotal := 5 * time.Second
|
||||
@@ -168,7 +228,7 @@ func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.M
|
||||
|
||||
var failures []upstreamFailure
|
||||
for _, upstream := range u.upstreamServers {
|
||||
if failure := u.queryUpstream(w, r, upstream, timeout, logger); failure != nil {
|
||||
if failure := u.queryUpstream(ctx, w, r, upstream, timeout, logger); failure != nil {
|
||||
failures = append(failures, *failure)
|
||||
} else {
|
||||
return true, failures
|
||||
@@ -178,15 +238,17 @@ func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.M
|
||||
}
|
||||
|
||||
// queryUpstream queries a single upstream server. Returns nil on success, or failure info to try next upstream.
|
||||
func (u *upstreamResolverBase) queryUpstream(w dns.ResponseWriter, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration, logger *log.Entry) *upstreamFailure {
|
||||
func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, w dns.ResponseWriter, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration, logger *log.Entry) *upstreamFailure {
|
||||
var rm *dns.Msg
|
||||
var t time.Duration
|
||||
var err error
|
||||
|
||||
var startTime time.Time
|
||||
var upstreamProto *upstreamProtocolResult
|
||||
func() {
|
||||
ctx, cancel := context.WithTimeout(u.ctx, timeout)
|
||||
ctx, cancel := context.WithTimeout(parentCtx, timeout)
|
||||
defer cancel()
|
||||
ctx, upstreamProto = contextWithupstreamProtocolResult(ctx)
|
||||
startTime = time.Now()
|
||||
rm, t, err = u.upstreamClient.exchange(ctx, upstream.String(), r)
|
||||
}()
|
||||
@@ -203,7 +265,7 @@ func (u *upstreamResolverBase) queryUpstream(w dns.ResponseWriter, r *dns.Msg, u
|
||||
return &upstreamFailure{upstream: upstream, reason: dns.RcodeToString[rm.Rcode]}
|
||||
}
|
||||
|
||||
u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, logger)
|
||||
u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, upstreamProto, logger)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -220,10 +282,13 @@ func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.Add
|
||||
return &upstreamFailure{upstream: upstream, reason: reason}
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, logger *log.Entry) bool {
|
||||
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, upstreamProto *upstreamProtocolResult, logger *log.Entry) bool {
|
||||
u.successCount.Add(1)
|
||||
|
||||
resutil.SetMeta(w, "upstream", upstream.String())
|
||||
if upstreamProto != nil && upstreamProto.protocol != "" {
|
||||
resutil.SetMeta(w, "upstream_protocol", upstreamProto.protocol)
|
||||
}
|
||||
|
||||
// Clear Zero bit from external responses to prevent upstream servers from
|
||||
// manipulating our internal fallthrough signaling mechanism
|
||||
@@ -428,13 +493,42 @@ func (u *upstreamResolverBase) testNameserver(baseCtx context.Context, externalC
|
||||
return err
|
||||
}
|
||||
|
||||
// clientUDPMaxSize returns the maximum UDP response size the client accepts.
|
||||
func clientUDPMaxSize(r *dns.Msg) int {
|
||||
if opt := r.IsEdns0(); opt != nil {
|
||||
return int(opt.UDPSize())
|
||||
}
|
||||
return dns.MinMsgSize
|
||||
}
|
||||
|
||||
// ExchangeWithFallback exchanges a DNS message with the upstream server.
|
||||
// It first tries to use UDP, and if it is truncated, it falls back to TCP.
|
||||
// If the inbound request came over TCP (via context), it skips the UDP attempt.
|
||||
// If the passed context is nil, this will use Exchange instead of ExchangeContext.
|
||||
func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, upstream string) (*dns.Msg, time.Duration, error) {
|
||||
// MTU - ip + udp headers
|
||||
// Note: this could be sent out on an interface that is not ours, but higher MTU settings could break truncation handling.
|
||||
client.UDPSize = uint16(currentMTU - (60 + 8))
|
||||
// If the request came in over TCP, go straight to TCP upstream.
|
||||
if dnsProtocolFromContext(ctx) == protoTCP {
|
||||
tcpClient := *client
|
||||
tcpClient.Net = protoTCP
|
||||
rm, t, err := tcpClient.ExchangeContext(ctx, r, upstream)
|
||||
if err != nil {
|
||||
return nil, t, fmt.Errorf("with tcp: %w", err)
|
||||
}
|
||||
setUpstreamProtocol(ctx, protoTCP)
|
||||
return rm, t, nil
|
||||
}
|
||||
|
||||
clientMaxSize := clientUDPMaxSize(r)
|
||||
|
||||
// Cap EDNS0 to our tunnel MTU so the upstream doesn't send a
|
||||
// response larger than our read buffer.
|
||||
// Note: the query could be sent out on an interface that is not ours,
|
||||
// but higher MTU settings could break truncation handling.
|
||||
maxUDPPayload := uint16(currentMTU - ipUDPHeaderSize)
|
||||
client.UDPSize = maxUDPPayload
|
||||
if opt := r.IsEdns0(); opt != nil && opt.UDPSize() > maxUDPPayload {
|
||||
opt.SetUDPSize(maxUDPPayload)
|
||||
}
|
||||
|
||||
var (
|
||||
rm *dns.Msg
|
||||
@@ -453,25 +547,32 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
|
||||
}
|
||||
|
||||
if rm == nil || !rm.MsgHdr.Truncated {
|
||||
setUpstreamProtocol(ctx, protoUDP)
|
||||
return rm, t, nil
|
||||
}
|
||||
|
||||
log.Tracef("udp response for domain=%s type=%v class=%v is truncated, trying TCP.",
|
||||
r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
||||
// TODO: if the upstream's truncated UDP response already contains more
|
||||
// data than the client's buffer, we could truncate locally and skip
|
||||
// the TCP retry.
|
||||
|
||||
client.Net = "tcp"
|
||||
tcpClient := *client
|
||||
tcpClient.Net = protoTCP
|
||||
|
||||
if ctx == nil {
|
||||
rm, t, err = client.Exchange(r, upstream)
|
||||
rm, t, err = tcpClient.Exchange(r, upstream)
|
||||
} else {
|
||||
rm, t, err = client.ExchangeContext(ctx, r, upstream)
|
||||
rm, t, err = tcpClient.ExchangeContext(ctx, r, upstream)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, t, fmt.Errorf("with tcp: %w", err)
|
||||
}
|
||||
|
||||
// TODO: once TCP is implemented, rm.Truncate() if the request came in over UDP
|
||||
setUpstreamProtocol(ctx, protoTCP)
|
||||
|
||||
if rm.Len() > clientMaxSize {
|
||||
rm.Truncate(clientMaxSize)
|
||||
}
|
||||
|
||||
return rm, t, nil
|
||||
}
|
||||
@@ -479,18 +580,46 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
|
||||
// ExchangeWithNetstack performs a DNS exchange using netstack for dialing.
|
||||
// This is needed when netstack is enabled to reach peer IPs through the tunnel.
|
||||
func ExchangeWithNetstack(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upstream string) (*dns.Msg, error) {
|
||||
reply, err := netstackExchange(ctx, nsNet, r, upstream, "udp")
|
||||
// If request came in over TCP, go straight to TCP upstream
|
||||
if dnsProtocolFromContext(ctx) == protoTCP {
|
||||
rm, err := netstackExchange(ctx, nsNet, r, upstream, protoTCP)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
setUpstreamProtocol(ctx, protoTCP)
|
||||
return rm, nil
|
||||
}
|
||||
|
||||
clientMaxSize := clientUDPMaxSize(r)
|
||||
|
||||
// Cap EDNS0 to our tunnel MTU so the upstream doesn't send a
|
||||
// response larger than what we can read over UDP.
|
||||
maxUDPPayload := uint16(currentMTU - ipUDPHeaderSize)
|
||||
if opt := r.IsEdns0(); opt != nil && opt.UDPSize() > maxUDPPayload {
|
||||
opt.SetUDPSize(maxUDPPayload)
|
||||
}
|
||||
|
||||
reply, err := netstackExchange(ctx, nsNet, r, upstream, protoUDP)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If response is truncated, retry with TCP
|
||||
if reply != nil && reply.MsgHdr.Truncated {
|
||||
log.Tracef("udp response for domain=%s type=%v class=%v is truncated, trying TCP",
|
||||
r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
||||
return netstackExchange(ctx, nsNet, r, upstream, "tcp")
|
||||
rm, err := netstackExchange(ctx, nsNet, r, upstream, protoTCP)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
setUpstreamProtocol(ctx, protoTCP)
|
||||
if rm.Len() > clientMaxSize {
|
||||
rm.Truncate(clientMaxSize)
|
||||
}
|
||||
|
||||
return rm, nil
|
||||
}
|
||||
|
||||
setUpstreamProtocol(ctx, protoUDP)
|
||||
|
||||
return reply, nil
|
||||
}
|
||||
|
||||
@@ -511,7 +640,7 @@ func netstackExchange(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upst
|
||||
}
|
||||
}
|
||||
|
||||
dnsConn := &dns.Conn{Conn: conn}
|
||||
dnsConn := &dns.Conn{Conn: conn, UDPSize: uint16(currentMTU - ipUDPHeaderSize)}
|
||||
|
||||
if err := dnsConn.WriteMsg(r); err != nil {
|
||||
return nil, fmt.Errorf("write %s message: %w", network, err)
|
||||
|
||||
@@ -51,7 +51,7 @@ func (u *upstreamResolver) exchangeWithinVPN(ctx context.Context, upstream strin
|
||||
upstreamExchangeClient := &dns.Client{
|
||||
Timeout: ClientTimeout,
|
||||
}
|
||||
return upstreamExchangeClient.ExchangeContext(ctx, r, upstream)
|
||||
return ExchangeWithFallback(ctx, upstreamExchangeClient, r, upstream)
|
||||
}
|
||||
|
||||
// exchangeWithoutVPN protect the UDP socket by Android SDK to avoid to goes through the VPN
|
||||
@@ -76,7 +76,7 @@ func (u *upstreamResolver) exchangeWithoutVPN(ctx context.Context, upstream stri
|
||||
Timeout: timeout,
|
||||
}
|
||||
|
||||
return upstreamExchangeClient.ExchangeContext(ctx, r, upstream)
|
||||
return ExchangeWithFallback(ctx, upstreamExchangeClient, r, upstream)
|
||||
}
|
||||
|
||||
func (u *upstreamResolver) isLocalResolver(upstream string) bool {
|
||||
|
||||
@@ -475,3 +475,298 @@ func TestFormatFailures(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSProtocolContext(t *testing.T) {
|
||||
t.Run("roundtrip udp", func(t *testing.T) {
|
||||
ctx := contextWithDNSProtocol(context.Background(), protoUDP)
|
||||
assert.Equal(t, protoUDP, dnsProtocolFromContext(ctx))
|
||||
})
|
||||
|
||||
t.Run("roundtrip tcp", func(t *testing.T) {
|
||||
ctx := contextWithDNSProtocol(context.Background(), protoTCP)
|
||||
assert.Equal(t, protoTCP, dnsProtocolFromContext(ctx))
|
||||
})
|
||||
|
||||
t.Run("missing returns empty", func(t *testing.T) {
|
||||
assert.Equal(t, "", dnsProtocolFromContext(context.Background()))
|
||||
})
|
||||
}
|
||||
|
||||
func TestExchangeWithFallback_TCPContext(t *testing.T) {
|
||||
// Start a local DNS server that responds on TCP only
|
||||
tcpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
m.Answer = append(m.Answer, &dns.A{
|
||||
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP("10.0.0.1"),
|
||||
})
|
||||
if err := w.WriteMsg(m); err != nil {
|
||||
t.Logf("write msg: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
tcpServer := &dns.Server{
|
||||
Addr: "127.0.0.1:0",
|
||||
Net: "tcp",
|
||||
Handler: tcpHandler,
|
||||
}
|
||||
|
||||
tcpLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
tcpServer.Listener = tcpLn
|
||||
|
||||
go func() {
|
||||
if err := tcpServer.ActivateAndServe(); err != nil {
|
||||
t.Logf("tcp server: %v", err)
|
||||
}
|
||||
}()
|
||||
defer func() {
|
||||
_ = tcpServer.Shutdown()
|
||||
}()
|
||||
|
||||
upstream := tcpLn.Addr().String()
|
||||
|
||||
// With TCP context, should connect directly via TCP without trying UDP
|
||||
ctx := contextWithDNSProtocol(context.Background(), protoTCP)
|
||||
client := &dns.Client{Timeout: 2 * time.Second}
|
||||
r := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
|
||||
|
||||
rm, _, err := ExchangeWithFallback(ctx, client, r, upstream)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rm)
|
||||
require.NotEmpty(t, rm.Answer)
|
||||
assert.Contains(t, rm.Answer[0].String(), "10.0.0.1")
|
||||
}
|
||||
|
||||
func TestExchangeWithFallback_UDPFallbackToTCP(t *testing.T) {
|
||||
// UDP handler returns a truncated response to trigger TCP retry.
|
||||
udpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
m.Truncated = true
|
||||
if err := w.WriteMsg(m); err != nil {
|
||||
t.Logf("write msg: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
// TCP handler returns the full answer.
|
||||
tcpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
m.Answer = append(m.Answer, &dns.A{
|
||||
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP("10.0.0.3"),
|
||||
})
|
||||
if err := w.WriteMsg(m); err != nil {
|
||||
t.Logf("write msg: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
udpPC, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
addr := udpPC.LocalAddr().String()
|
||||
|
||||
udpServer := &dns.Server{
|
||||
PacketConn: udpPC,
|
||||
Net: "udp",
|
||||
Handler: udpHandler,
|
||||
}
|
||||
|
||||
tcpLn, err := net.Listen("tcp", addr)
|
||||
require.NoError(t, err)
|
||||
|
||||
tcpServer := &dns.Server{
|
||||
Listener: tcpLn,
|
||||
Net: "tcp",
|
||||
Handler: tcpHandler,
|
||||
}
|
||||
|
||||
go func() {
|
||||
if err := udpServer.ActivateAndServe(); err != nil {
|
||||
t.Logf("udp server: %v", err)
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
if err := tcpServer.ActivateAndServe(); err != nil {
|
||||
t.Logf("tcp server: %v", err)
|
||||
}
|
||||
}()
|
||||
defer func() {
|
||||
_ = udpServer.Shutdown()
|
||||
_ = tcpServer.Shutdown()
|
||||
}()
|
||||
|
||||
ctx := context.Background()
|
||||
client := &dns.Client{Timeout: 2 * time.Second}
|
||||
r := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
|
||||
|
||||
rm, _, err := ExchangeWithFallback(ctx, client, r, addr)
|
||||
require.NoError(t, err, "should fall back to TCP after truncated UDP response")
|
||||
require.NotNil(t, rm)
|
||||
require.NotEmpty(t, rm.Answer, "TCP response should contain the full answer")
|
||||
assert.Contains(t, rm.Answer[0].String(), "10.0.0.3")
|
||||
assert.False(t, rm.Truncated, "TCP response should not be truncated")
|
||||
}
|
||||
|
||||
func TestExchangeWithFallback_TCPContextSkipsUDP(t *testing.T) {
|
||||
// Start only a TCP server (no UDP). With TCP context it should succeed.
|
||||
tcpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
m.Answer = append(m.Answer, &dns.A{
|
||||
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP("10.0.0.2"),
|
||||
})
|
||||
if err := w.WriteMsg(m); err != nil {
|
||||
t.Logf("write msg: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
tcpLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
tcpServer := &dns.Server{
|
||||
Listener: tcpLn,
|
||||
Net: "tcp",
|
||||
Handler: tcpHandler,
|
||||
}
|
||||
|
||||
go func() {
|
||||
if err := tcpServer.ActivateAndServe(); err != nil {
|
||||
t.Logf("tcp server: %v", err)
|
||||
}
|
||||
}()
|
||||
defer func() {
|
||||
_ = tcpServer.Shutdown()
|
||||
}()
|
||||
|
||||
upstream := tcpLn.Addr().String()
|
||||
|
||||
// TCP context: should skip UDP entirely and go directly to TCP
|
||||
ctx := contextWithDNSProtocol(context.Background(), protoTCP)
|
||||
client := &dns.Client{Timeout: 2 * time.Second}
|
||||
r := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
|
||||
|
||||
rm, _, err := ExchangeWithFallback(ctx, client, r, upstream)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rm)
|
||||
require.NotEmpty(t, rm.Answer)
|
||||
assert.Contains(t, rm.Answer[0].String(), "10.0.0.2")
|
||||
|
||||
// Without TCP context, trying to reach a TCP-only server via UDP should fail
|
||||
ctx2 := context.Background()
|
||||
client2 := &dns.Client{Timeout: 500 * time.Millisecond}
|
||||
_, _, err = ExchangeWithFallback(ctx2, client2, r, upstream)
|
||||
assert.Error(t, err, "should fail when no UDP server and no TCP context")
|
||||
}
|
||||
|
||||
func TestExchangeWithFallback_EDNS0Capped(t *testing.T) {
|
||||
// Verify that a client EDNS0 larger than our MTU-derived limit gets
|
||||
// capped in the outgoing request so the upstream doesn't send a
|
||||
// response larger than our read buffer.
|
||||
var receivedUDPSize uint16
|
||||
udpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
if opt := r.IsEdns0(); opt != nil {
|
||||
receivedUDPSize = opt.UDPSize()
|
||||
}
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
m.Answer = append(m.Answer, &dns.A{
|
||||
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP("10.0.0.1"),
|
||||
})
|
||||
if err := w.WriteMsg(m); err != nil {
|
||||
t.Logf("write msg: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
udpPC, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
addr := udpPC.LocalAddr().String()
|
||||
|
||||
udpServer := &dns.Server{PacketConn: udpPC, Net: "udp", Handler: udpHandler}
|
||||
go func() { _ = udpServer.ActivateAndServe() }()
|
||||
t.Cleanup(func() { _ = udpServer.Shutdown() })
|
||||
|
||||
ctx := context.Background()
|
||||
client := &dns.Client{Timeout: 2 * time.Second}
|
||||
r := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
|
||||
r.SetEdns0(4096, false)
|
||||
|
||||
rm, _, err := ExchangeWithFallback(ctx, client, r, addr)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rm)
|
||||
|
||||
expectedMax := uint16(currentMTU - ipUDPHeaderSize)
|
||||
assert.Equal(t, expectedMax, receivedUDPSize,
|
||||
"upstream should see capped EDNS0, not the client's 4096")
|
||||
}
|
||||
|
||||
func TestExchangeWithFallback_TCPTruncatesToClientSize(t *testing.T) {
|
||||
// When the client advertises a large EDNS0 (4096) and the upstream
|
||||
// truncates, the TCP response should NOT be truncated since the full
|
||||
// answer fits within the client's original buffer.
|
||||
udpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
m.Truncated = true
|
||||
if err := w.WriteMsg(m); err != nil {
|
||||
t.Logf("write msg: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
tcpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
// Add enough records to exceed MTU but fit within 4096
|
||||
for i := range 20 {
|
||||
m.Answer = append(m.Answer, &dns.TXT{
|
||||
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 60},
|
||||
Txt: []string{fmt.Sprintf("record-%d-padding-data-to-make-it-longer", i)},
|
||||
})
|
||||
}
|
||||
if err := w.WriteMsg(m); err != nil {
|
||||
t.Logf("write msg: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
udpPC, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
addr := udpPC.LocalAddr().String()
|
||||
|
||||
udpServer := &dns.Server{PacketConn: udpPC, Net: "udp", Handler: udpHandler}
|
||||
tcpLn, err := net.Listen("tcp", addr)
|
||||
require.NoError(t, err)
|
||||
tcpServer := &dns.Server{Listener: tcpLn, Net: "tcp", Handler: tcpHandler}
|
||||
|
||||
go func() { _ = udpServer.ActivateAndServe() }()
|
||||
go func() { _ = tcpServer.ActivateAndServe() }()
|
||||
t.Cleanup(func() {
|
||||
_ = udpServer.Shutdown()
|
||||
_ = tcpServer.Shutdown()
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
client := &dns.Client{Timeout: 2 * time.Second}
|
||||
|
||||
// Client with large buffer: should get all records without truncation
|
||||
r := new(dns.Msg).SetQuestion("example.com.", dns.TypeTXT)
|
||||
r.SetEdns0(4096, false)
|
||||
|
||||
rm, _, err := ExchangeWithFallback(ctx, client, r, addr)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rm)
|
||||
assert.Len(t, rm.Answer, 20, "large EDNS0 client should get all records")
|
||||
assert.False(t, rm.Truncated, "response should not be truncated for large buffer client")
|
||||
|
||||
// Client with small buffer: should get truncated response
|
||||
r2 := new(dns.Msg).SetQuestion("example.com.", dns.TypeTXT)
|
||||
r2.SetEdns0(512, false)
|
||||
|
||||
rm2, _, err := ExchangeWithFallback(ctx, &dns.Client{Timeout: 2 * time.Second}, r2, addr)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rm2)
|
||||
assert.Less(t, len(rm2.Answer), 20, "small EDNS0 client should get fewer records")
|
||||
assert.True(t, rm2.Truncated, "response should be truncated for small buffer client")
|
||||
}
|
||||
|
||||
@@ -237,8 +237,8 @@ func (f *DNSForwarder) writeResponse(logger *log.Entry, w dns.ResponseWriter, re
|
||||
return
|
||||
}
|
||||
|
||||
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
|
||||
qname, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
|
||||
logger.Tracef("response: domain=%s rcode=%s answers=%s size=%dB took=%s",
|
||||
qname, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), resp.Len(), time.Since(startTime))
|
||||
}
|
||||
|
||||
// udpResponseWriter wraps a dns.ResponseWriter to handle UDP-specific truncation.
|
||||
@@ -263,20 +263,28 @@ func (u *udpResponseWriter) WriteMsg(resp *dns.Msg) error {
|
||||
|
||||
func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
|
||||
startTime := time.Now()
|
||||
logger := log.WithFields(log.Fields{
|
||||
fields := log.Fields{
|
||||
"request_id": resutil.GenerateRequestID(),
|
||||
"dns_id": fmt.Sprintf("%04x", query.Id),
|
||||
})
|
||||
}
|
||||
if addr := w.RemoteAddr(); addr != nil {
|
||||
fields["client"] = addr.String()
|
||||
}
|
||||
logger := log.WithFields(fields)
|
||||
|
||||
f.handleDNSQuery(logger, &udpResponseWriter{ResponseWriter: w, query: query}, query, startTime)
|
||||
}
|
||||
|
||||
func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
|
||||
startTime := time.Now()
|
||||
logger := log.WithFields(log.Fields{
|
||||
fields := log.Fields{
|
||||
"request_id": resutil.GenerateRequestID(),
|
||||
"dns_id": fmt.Sprintf("%04x", query.Id),
|
||||
})
|
||||
}
|
||||
if addr := w.RemoteAddr(); addr != nil {
|
||||
fields["client"] = addr.String()
|
||||
}
|
||||
logger := log.WithFields(fields)
|
||||
|
||||
f.handleDNSQuery(logger, w, query, startTime)
|
||||
}
|
||||
|
||||
@@ -46,6 +46,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
||||
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||
"github.com/netbirdio/netbird/client/internal/portforward"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/relay"
|
||||
"github.com/netbirdio/netbird/client/internal/rosenpass"
|
||||
@@ -210,9 +211,10 @@ type Engine struct {
|
||||
// checks are the client-applied posture checks that need to be evaluated on the client
|
||||
checks []*mgmProto.Checks
|
||||
|
||||
relayManager *relayClient.Manager
|
||||
stateManager *statemanager.Manager
|
||||
srWatcher *guard.SRWatcher
|
||||
relayManager *relayClient.Manager
|
||||
stateManager *statemanager.Manager
|
||||
portForwardManager *portforward.Manager
|
||||
srWatcher *guard.SRWatcher
|
||||
|
||||
// Sync response persistence (protected by syncRespMux)
|
||||
syncRespMux sync.RWMutex
|
||||
@@ -259,26 +261,27 @@ func NewEngine(
|
||||
mobileDep MobileDependency,
|
||||
) *Engine {
|
||||
engine := &Engine{
|
||||
clientCtx: clientCtx,
|
||||
clientCancel: clientCancel,
|
||||
signal: services.SignalClient,
|
||||
signaler: peer.NewSignaler(services.SignalClient, config.WgPrivateKey),
|
||||
mgmClient: services.MgmClient,
|
||||
relayManager: services.RelayManager,
|
||||
peerStore: peerstore.NewConnStore(),
|
||||
syncMsgMux: &sync.Mutex{},
|
||||
config: config,
|
||||
mobileDep: mobileDep,
|
||||
STUNs: []*stun.URI{},
|
||||
TURNs: []*stun.URI{},
|
||||
networkSerial: 0,
|
||||
statusRecorder: services.StatusRecorder,
|
||||
stateManager: services.StateManager,
|
||||
checks: services.Checks,
|
||||
probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL),
|
||||
jobExecutor: jobexec.NewExecutor(),
|
||||
clientMetrics: services.ClientMetrics,
|
||||
updateManager: services.UpdateManager,
|
||||
clientCtx: clientCtx,
|
||||
clientCancel: clientCancel,
|
||||
signal: services.SignalClient,
|
||||
signaler: peer.NewSignaler(services.SignalClient, config.WgPrivateKey),
|
||||
mgmClient: services.MgmClient,
|
||||
relayManager: services.RelayManager,
|
||||
peerStore: peerstore.NewConnStore(),
|
||||
syncMsgMux: &sync.Mutex{},
|
||||
config: config,
|
||||
mobileDep: mobileDep,
|
||||
STUNs: []*stun.URI{},
|
||||
TURNs: []*stun.URI{},
|
||||
networkSerial: 0,
|
||||
statusRecorder: services.StatusRecorder,
|
||||
stateManager: services.StateManager,
|
||||
portForwardManager: portforward.NewManager(),
|
||||
checks: services.Checks,
|
||||
probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL),
|
||||
jobExecutor: jobexec.NewExecutor(),
|
||||
clientMetrics: services.ClientMetrics,
|
||||
updateManager: services.UpdateManager,
|
||||
}
|
||||
|
||||
log.Infof("I am: %s", config.WgPrivateKey.PublicKey().String())
|
||||
@@ -521,6 +524,11 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
return err
|
||||
}
|
||||
|
||||
// Inject firewall into DNS server now that it's available.
|
||||
// The DNS server is created before the firewall because the route manager
|
||||
// depends on the DNS server, and the firewall depends on the wg interface.
|
||||
e.dnsServer.SetFirewall(e.firewall)
|
||||
|
||||
e.udpMux, err = e.wgInterface.Up()
|
||||
if err != nil {
|
||||
log.Errorf("failed to pull up wgInterface [%s]: %s", e.wgInterface.Name(), err.Error())
|
||||
@@ -532,6 +540,13 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
// conntrack entries from being created before the rules are in place
|
||||
e.setupWGProxyNoTrack()
|
||||
|
||||
// Start after interface is up since port may have been resolved from 0 or changed if occupied
|
||||
e.shutdownWg.Add(1)
|
||||
go func() {
|
||||
defer e.shutdownWg.Done()
|
||||
e.portForwardManager.Start(e.ctx, uint16(e.config.WgPort))
|
||||
}()
|
||||
|
||||
// Set the WireGuard interface for rosenpass after interface is up
|
||||
if e.rpManager != nil {
|
||||
e.rpManager.SetInterface(e.wgInterface)
|
||||
@@ -554,7 +569,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
e.connMgr.Start(e.ctx)
|
||||
|
||||
e.srWatcher = guard.NewSRWatcher(e.signal, e.relayManager, e.mobileDep.IFaceDiscover, iceCfg)
|
||||
e.srWatcher.Start(peer.IsForceRelayed())
|
||||
e.srWatcher.Start()
|
||||
|
||||
e.receiveSignalEvents()
|
||||
e.receiveManagementEvents()
|
||||
@@ -1535,12 +1550,13 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
|
||||
}
|
||||
|
||||
serviceDependencies := peer.ServiceDependencies{
|
||||
StatusRecorder: e.statusRecorder,
|
||||
Signaler: e.signaler,
|
||||
IFaceDiscover: e.mobileDep.IFaceDiscover,
|
||||
RelayManager: e.relayManager,
|
||||
SrWatcher: e.srWatcher,
|
||||
MetricsRecorder: e.clientMetrics,
|
||||
StatusRecorder: e.statusRecorder,
|
||||
Signaler: e.signaler,
|
||||
IFaceDiscover: e.mobileDep.IFaceDiscover,
|
||||
RelayManager: e.relayManager,
|
||||
SrWatcher: e.srWatcher,
|
||||
PortForwardManager: e.portForwardManager,
|
||||
MetricsRecorder: e.clientMetrics,
|
||||
}
|
||||
peerConn, err := peer.NewConn(config, serviceDependencies)
|
||||
if err != nil {
|
||||
@@ -1697,6 +1713,12 @@ func (e *Engine) close() {
|
||||
if e.rpManager != nil {
|
||||
_ = e.rpManager.Close()
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := e.portForwardManager.GracefullyStop(ctx); err != nil {
|
||||
log.Warnf("failed to gracefully stop port forwarding manager: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, error) {
|
||||
@@ -1800,7 +1822,7 @@ func (e *Engine) newDnsServer(dnsConfig *nbdns.Config) (dns.Server, error) {
|
||||
return dnsServer, nil
|
||||
|
||||
case "ios":
|
||||
dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.statusRecorder, e.config.DisableDNS)
|
||||
dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.mobileDep.HostDNSAddresses, e.statusRecorder, e.config.DisableDNS)
|
||||
return dnsServer, nil
|
||||
|
||||
default:
|
||||
@@ -1837,6 +1859,11 @@ func (e *Engine) GetExposeManager() *expose.Manager {
|
||||
return e.exposeManager
|
||||
}
|
||||
|
||||
// IsBlockInbound returns whether inbound connections are blocked.
|
||||
func (e *Engine) IsBlockInbound() bool {
|
||||
return e.config.BlockInbound
|
||||
}
|
||||
|
||||
// GetClientMetrics returns the client metrics
|
||||
func (e *Engine) GetClientMetrics() *metrics.ClientMetrics {
|
||||
return e.clientMetrics
|
||||
|
||||
@@ -55,6 +55,7 @@ import (
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
@@ -1634,7 +1635,12 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
||||
peersManager := peers.NewManager(store, permissionsManager)
|
||||
jobManager := job.NewJobManager(nil, store, peersManager)
|
||||
|
||||
ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore)
|
||||
cacheStore, err := nbcache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore, cacheStore)
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
require.NoError(t, err)
|
||||
@@ -1656,7 +1662,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
|
||||
networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
|
||||
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
@@ -7,7 +7,9 @@ import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
"github.com/google/uuid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
nfct "github.com/ti-mo/conntrack"
|
||||
@@ -17,31 +19,64 @@ import (
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
const defaultChannelSize = 100
|
||||
const (
|
||||
defaultChannelSize = 100
|
||||
reconnectInitInterval = 5 * time.Second
|
||||
reconnectMaxInterval = 5 * time.Minute
|
||||
reconnectRandomization = 0.5
|
||||
)
|
||||
|
||||
// listener abstracts a netlink conntrack connection for testability.
|
||||
type listener interface {
|
||||
Listen(evChan chan<- nfct.Event, numWorkers uint8, groups []netfilter.NetlinkGroup) (chan error, error)
|
||||
Close() error
|
||||
}
|
||||
|
||||
// ConnTrack manages kernel-based conntrack events
|
||||
type ConnTrack struct {
|
||||
flowLogger nftypes.FlowLogger
|
||||
iface nftypes.IFaceMapper
|
||||
|
||||
conn *nfct.Conn
|
||||
conn listener
|
||||
mux sync.Mutex
|
||||
|
||||
dial func() (listener, error)
|
||||
instanceID uuid.UUID
|
||||
started bool
|
||||
done chan struct{}
|
||||
sysctlModified bool
|
||||
}
|
||||
|
||||
// DialFunc is a constructor for netlink conntrack connections.
|
||||
type DialFunc func() (listener, error)
|
||||
|
||||
// Option configures a ConnTrack instance.
|
||||
type Option func(*ConnTrack)
|
||||
|
||||
// WithDialer overrides the default netlink dialer, primarily for testing.
|
||||
func WithDialer(dial DialFunc) Option {
|
||||
return func(c *ConnTrack) {
|
||||
c.dial = dial
|
||||
}
|
||||
}
|
||||
|
||||
func defaultDial() (listener, error) {
|
||||
return nfct.Dial(nil)
|
||||
}
|
||||
|
||||
// New creates a new connection tracker that interfaces with the kernel's conntrack system
|
||||
func New(flowLogger nftypes.FlowLogger, iface nftypes.IFaceMapper) *ConnTrack {
|
||||
return &ConnTrack{
|
||||
func New(flowLogger nftypes.FlowLogger, iface nftypes.IFaceMapper, opts ...Option) *ConnTrack {
|
||||
ct := &ConnTrack{
|
||||
flowLogger: flowLogger,
|
||||
iface: iface,
|
||||
instanceID: uuid.New(),
|
||||
started: false,
|
||||
dial: defaultDial,
|
||||
done: make(chan struct{}, 1),
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(ct)
|
||||
}
|
||||
return ct
|
||||
}
|
||||
|
||||
// Start begins tracking connections by listening for conntrack events. This method is idempotent.
|
||||
@@ -59,8 +94,9 @@ func (c *ConnTrack) Start(enableCounters bool) error {
|
||||
c.EnableAccounting()
|
||||
}
|
||||
|
||||
conn, err := nfct.Dial(nil)
|
||||
conn, err := c.dial()
|
||||
if err != nil {
|
||||
c.RestoreAccounting()
|
||||
return fmt.Errorf("dial conntrack: %w", err)
|
||||
}
|
||||
c.conn = conn
|
||||
@@ -76,9 +112,16 @@ func (c *ConnTrack) Start(enableCounters bool) error {
|
||||
log.Errorf("Error closing conntrack connection: %v", err)
|
||||
}
|
||||
c.conn = nil
|
||||
c.RestoreAccounting()
|
||||
return fmt.Errorf("start conntrack listener: %w", err)
|
||||
}
|
||||
|
||||
// Drain any stale stop signal from a previous cycle.
|
||||
select {
|
||||
case <-c.done:
|
||||
default:
|
||||
}
|
||||
|
||||
c.started = true
|
||||
|
||||
go c.receiverRoutine(events, errChan)
|
||||
@@ -92,17 +135,98 @@ func (c *ConnTrack) receiverRoutine(events chan nfct.Event, errChan chan error)
|
||||
case event := <-events:
|
||||
c.handleEvent(event)
|
||||
case err := <-errChan:
|
||||
log.Errorf("Error from conntrack event listener: %v", err)
|
||||
if err := c.conn.Close(); err != nil {
|
||||
log.Errorf("Error closing conntrack connection: %v", err)
|
||||
if events, errChan = c.handleListenerError(err); events == nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
case <-c.done:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleListenerError closes the failed connection and attempts to reconnect.
|
||||
// Returns new channels on success, or nil if shutdown was requested.
|
||||
func (c *ConnTrack) handleListenerError(err error) (chan nfct.Event, chan error) {
|
||||
log.Warnf("conntrack event listener failed: %v", err)
|
||||
c.closeConn()
|
||||
return c.reconnect()
|
||||
}
|
||||
|
||||
func (c *ConnTrack) closeConn() {
|
||||
c.mux.Lock()
|
||||
defer c.mux.Unlock()
|
||||
|
||||
if c.conn != nil {
|
||||
if err := c.conn.Close(); err != nil {
|
||||
log.Debugf("close conntrack connection: %v", err)
|
||||
}
|
||||
c.conn = nil
|
||||
}
|
||||
}
|
||||
|
||||
// reconnect attempts to re-establish the conntrack netlink listener with exponential backoff.
|
||||
// Returns new channels on success, or nil if shutdown was requested.
|
||||
func (c *ConnTrack) reconnect() (chan nfct.Event, chan error) {
|
||||
bo := &backoff.ExponentialBackOff{
|
||||
InitialInterval: reconnectInitInterval,
|
||||
RandomizationFactor: reconnectRandomization,
|
||||
Multiplier: backoff.DefaultMultiplier,
|
||||
MaxInterval: reconnectMaxInterval,
|
||||
MaxElapsedTime: 0, // retry indefinitely
|
||||
Clock: backoff.SystemClock,
|
||||
}
|
||||
bo.Reset()
|
||||
|
||||
for {
|
||||
delay := bo.NextBackOff()
|
||||
log.Infof("reconnecting conntrack listener in %s", delay)
|
||||
|
||||
select {
|
||||
case <-c.done:
|
||||
c.mux.Lock()
|
||||
c.started = false
|
||||
c.mux.Unlock()
|
||||
return nil, nil
|
||||
case <-time.After(delay):
|
||||
}
|
||||
|
||||
conn, err := c.dial()
|
||||
if err != nil {
|
||||
log.Warnf("reconnect conntrack dial: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
events := make(chan nfct.Event, defaultChannelSize)
|
||||
errChan, err := conn.Listen(events, 1, []netfilter.NetlinkGroup{
|
||||
netfilter.GroupCTNew,
|
||||
netfilter.GroupCTDestroy,
|
||||
})
|
||||
if err != nil {
|
||||
log.Warnf("reconnect conntrack listen: %v", err)
|
||||
if closeErr := conn.Close(); closeErr != nil {
|
||||
log.Debugf("close conntrack connection: %v", closeErr)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
c.mux.Lock()
|
||||
if !c.started {
|
||||
// Stop() ran while we were reconnecting.
|
||||
c.mux.Unlock()
|
||||
if closeErr := conn.Close(); closeErr != nil {
|
||||
log.Debugf("close conntrack connection: %v", closeErr)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
c.conn = conn
|
||||
c.mux.Unlock()
|
||||
|
||||
log.Infof("conntrack listener reconnected successfully")
|
||||
|
||||
return events, errChan
|
||||
}
|
||||
}
|
||||
|
||||
// Stop stops the connection tracking. This method is idempotent.
|
||||
func (c *ConnTrack) Stop() {
|
||||
c.mux.Lock()
|
||||
@@ -136,23 +260,27 @@ func (c *ConnTrack) Close() error {
|
||||
c.mux.Lock()
|
||||
defer c.mux.Unlock()
|
||||
|
||||
if c.started {
|
||||
select {
|
||||
case c.done <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
if !c.started {
|
||||
return nil
|
||||
}
|
||||
|
||||
select {
|
||||
case c.done <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
|
||||
c.started = false
|
||||
|
||||
var closeErr error
|
||||
if c.conn != nil {
|
||||
err := c.conn.Close()
|
||||
closeErr = c.conn.Close()
|
||||
c.conn = nil
|
||||
c.started = false
|
||||
}
|
||||
|
||||
c.RestoreAccounting()
|
||||
c.RestoreAccounting()
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("close conntrack: %w", err)
|
||||
}
|
||||
if closeErr != nil {
|
||||
return fmt.Errorf("close conntrack: %w", closeErr)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
224
client/internal/netflow/conntrack/conntrack_test.go
Normal file
224
client/internal/netflow/conntrack/conntrack_test.go
Normal file
@@ -0,0 +1,224 @@
|
||||
//go:build linux && !android
|
||||
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
nfct "github.com/ti-mo/conntrack"
|
||||
"github.com/ti-mo/netfilter"
|
||||
)
|
||||
|
||||
type mockListener struct {
|
||||
errChan chan error
|
||||
closed atomic.Bool
|
||||
closedCh chan struct{}
|
||||
}
|
||||
|
||||
func newMockListener() *mockListener {
|
||||
return &mockListener{
|
||||
errChan: make(chan error, 1),
|
||||
closedCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockListener) Listen(evChan chan<- nfct.Event, _ uint8, _ []netfilter.NetlinkGroup) (chan error, error) {
|
||||
return m.errChan, nil
|
||||
}
|
||||
|
||||
func (m *mockListener) Close() error {
|
||||
if m.closed.CompareAndSwap(false, true) {
|
||||
close(m.closedCh)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestReconnectAfterError(t *testing.T) {
|
||||
first := newMockListener()
|
||||
second := newMockListener()
|
||||
third := newMockListener()
|
||||
listeners := []*mockListener{first, second, third}
|
||||
callCount := atomic.Int32{}
|
||||
|
||||
ct := New(nil, nil, WithDialer(func() (listener, error) {
|
||||
n := int(callCount.Add(1)) - 1
|
||||
return listeners[n], nil
|
||||
}))
|
||||
|
||||
err := ct.Start(false)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Inject an error on the first listener.
|
||||
first.errChan <- assert.AnError
|
||||
|
||||
// Wait for reconnect to complete.
|
||||
require.Eventually(t, func() bool {
|
||||
return callCount.Load() >= 2
|
||||
}, 15*time.Second, 100*time.Millisecond, "reconnect should dial a new connection")
|
||||
|
||||
// The first connection must have been closed.
|
||||
select {
|
||||
case <-first.closedCh:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("first connection was not closed")
|
||||
}
|
||||
|
||||
// Verify the receiver is still running by injecting and handling a second error.
|
||||
second.errChan <- assert.AnError
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
return callCount.Load() >= 3
|
||||
}, 15*time.Second, 100*time.Millisecond, "second reconnect should succeed")
|
||||
|
||||
ct.Stop()
|
||||
}
|
||||
|
||||
func TestStopDuringReconnectBackoff(t *testing.T) {
|
||||
mock := newMockListener()
|
||||
|
||||
ct := New(nil, nil, WithDialer(func() (listener, error) {
|
||||
return mock, nil
|
||||
}))
|
||||
|
||||
err := ct.Start(false)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Trigger an error so the receiver enters reconnect.
|
||||
mock.errChan <- assert.AnError
|
||||
|
||||
// Wait for the error handler to close the old listener before calling Stop.
|
||||
select {
|
||||
case <-mock.closedCh:
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("timed out waiting for reconnect to start")
|
||||
}
|
||||
|
||||
// Stop while reconnecting.
|
||||
ct.Stop()
|
||||
|
||||
ct.mux.Lock()
|
||||
assert.False(t, ct.started, "started should be false after Stop")
|
||||
assert.Nil(t, ct.conn, "conn should be nil after Stop")
|
||||
ct.mux.Unlock()
|
||||
}
|
||||
|
||||
func TestStopRaceWithReconnectDial(t *testing.T) {
|
||||
first := newMockListener()
|
||||
dialStarted := make(chan struct{})
|
||||
dialProceed := make(chan struct{})
|
||||
second := newMockListener()
|
||||
callCount := atomic.Int32{}
|
||||
|
||||
ct := New(nil, nil, WithDialer(func() (listener, error) {
|
||||
n := callCount.Add(1)
|
||||
if n == 1 {
|
||||
return first, nil
|
||||
}
|
||||
// Second dial: signal that we're in progress, wait for test to call Stop.
|
||||
close(dialStarted)
|
||||
<-dialProceed
|
||||
return second, nil
|
||||
}))
|
||||
|
||||
err := ct.Start(false)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Trigger error to enter reconnect.
|
||||
first.errChan <- assert.AnError
|
||||
|
||||
// Wait for reconnect's second dial to begin.
|
||||
select {
|
||||
case <-dialStarted:
|
||||
case <-time.After(15 * time.Second):
|
||||
t.Fatal("timed out waiting for reconnect dial")
|
||||
}
|
||||
|
||||
// Stop while dial is in progress (conn is nil at this point).
|
||||
ct.Stop()
|
||||
|
||||
// Let the dial complete. reconnect should detect started==false and close the new conn.
|
||||
close(dialProceed)
|
||||
|
||||
// The second connection should be closed (not leaked).
|
||||
select {
|
||||
case <-second.closedCh:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("second connection was leaked after Stop")
|
||||
}
|
||||
|
||||
ct.mux.Lock()
|
||||
assert.False(t, ct.started)
|
||||
assert.Nil(t, ct.conn)
|
||||
ct.mux.Unlock()
|
||||
}
|
||||
|
||||
func TestCloseRaceWithReconnectDial(t *testing.T) {
|
||||
first := newMockListener()
|
||||
dialStarted := make(chan struct{})
|
||||
dialProceed := make(chan struct{})
|
||||
second := newMockListener()
|
||||
callCount := atomic.Int32{}
|
||||
|
||||
ct := New(nil, nil, WithDialer(func() (listener, error) {
|
||||
n := callCount.Add(1)
|
||||
if n == 1 {
|
||||
return first, nil
|
||||
}
|
||||
close(dialStarted)
|
||||
<-dialProceed
|
||||
return second, nil
|
||||
}))
|
||||
|
||||
err := ct.Start(false)
|
||||
require.NoError(t, err)
|
||||
|
||||
first.errChan <- assert.AnError
|
||||
|
||||
select {
|
||||
case <-dialStarted:
|
||||
case <-time.After(15 * time.Second):
|
||||
t.Fatal("timed out waiting for reconnect dial")
|
||||
}
|
||||
|
||||
// Close while dial is in progress (conn is nil).
|
||||
require.NoError(t, ct.Close())
|
||||
|
||||
close(dialProceed)
|
||||
|
||||
// The second connection should be closed (not leaked).
|
||||
select {
|
||||
case <-second.closedCh:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("second connection was leaked after Close")
|
||||
}
|
||||
|
||||
ct.mux.Lock()
|
||||
assert.False(t, ct.started)
|
||||
assert.Nil(t, ct.conn)
|
||||
ct.mux.Unlock()
|
||||
}
|
||||
|
||||
func TestStartIsIdempotent(t *testing.T) {
|
||||
mock := newMockListener()
|
||||
callCount := atomic.Int32{}
|
||||
|
||||
ct := New(nil, nil, WithDialer(func() (listener, error) {
|
||||
callCount.Add(1)
|
||||
return mock, nil
|
||||
}))
|
||||
|
||||
err := ct.Start(false)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Second Start should be a no-op.
|
||||
err = ct.Start(false)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, int32(1), callCount.Load(), "dial should only be called once")
|
||||
|
||||
ct.Stop()
|
||||
}
|
||||
@@ -22,6 +22,7 @@ import (
|
||||
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||
"github.com/netbirdio/netbird/client/internal/peer/id"
|
||||
"github.com/netbirdio/netbird/client/internal/peer/worker"
|
||||
"github.com/netbirdio/netbird/client/internal/portforward"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||
@@ -45,6 +46,7 @@ type ServiceDependencies struct {
|
||||
RelayManager *relayClient.Manager
|
||||
SrWatcher *guard.SRWatcher
|
||||
PeerConnDispatcher *dispatcher.ConnectionDispatcher
|
||||
PortForwardManager *portforward.Manager
|
||||
MetricsRecorder MetricsRecorder
|
||||
}
|
||||
|
||||
@@ -87,16 +89,17 @@ type ConnConfig struct {
|
||||
}
|
||||
|
||||
type Conn struct {
|
||||
Log *log.Entry
|
||||
mu sync.Mutex
|
||||
ctx context.Context
|
||||
ctxCancel context.CancelFunc
|
||||
config ConnConfig
|
||||
statusRecorder *Status
|
||||
signaler *Signaler
|
||||
iFaceDiscover stdnet.ExternalIFaceDiscover
|
||||
relayManager *relayClient.Manager
|
||||
srWatcher *guard.SRWatcher
|
||||
Log *log.Entry
|
||||
mu sync.Mutex
|
||||
ctx context.Context
|
||||
ctxCancel context.CancelFunc
|
||||
config ConnConfig
|
||||
statusRecorder *Status
|
||||
signaler *Signaler
|
||||
iFaceDiscover stdnet.ExternalIFaceDiscover
|
||||
relayManager *relayClient.Manager
|
||||
srWatcher *guard.SRWatcher
|
||||
portForwardManager *portforward.Manager
|
||||
|
||||
onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)
|
||||
onDisconnected func(remotePeer string)
|
||||
@@ -145,19 +148,20 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
|
||||
|
||||
dumpState := newStateDump(config.Key, connLog, services.StatusRecorder)
|
||||
var conn = &Conn{
|
||||
Log: connLog,
|
||||
config: config,
|
||||
statusRecorder: services.StatusRecorder,
|
||||
signaler: services.Signaler,
|
||||
iFaceDiscover: services.IFaceDiscover,
|
||||
relayManager: services.RelayManager,
|
||||
srWatcher: services.SrWatcher,
|
||||
statusRelay: worker.NewAtomicStatus(),
|
||||
statusICE: worker.NewAtomicStatus(),
|
||||
dumpState: dumpState,
|
||||
endpointUpdater: NewEndpointUpdater(connLog, config.WgConfig, isController(config)),
|
||||
wgWatcher: NewWGWatcher(connLog, config.WgConfig.WgInterface, config.Key, dumpState),
|
||||
metricsRecorder: services.MetricsRecorder,
|
||||
Log: connLog,
|
||||
config: config,
|
||||
statusRecorder: services.StatusRecorder,
|
||||
signaler: services.Signaler,
|
||||
iFaceDiscover: services.IFaceDiscover,
|
||||
relayManager: services.RelayManager,
|
||||
srWatcher: services.SrWatcher,
|
||||
portForwardManager: services.PortForwardManager,
|
||||
statusRelay: worker.NewAtomicStatus(),
|
||||
statusICE: worker.NewAtomicStatus(),
|
||||
dumpState: dumpState,
|
||||
endpointUpdater: NewEndpointUpdater(connLog, config.WgConfig, isController(config)),
|
||||
wgWatcher: NewWGWatcher(connLog, config.WgConfig.WgInterface, config.Key, dumpState),
|
||||
metricsRecorder: services.MetricsRecorder,
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
@@ -181,20 +185,17 @@ func (conn *Conn) Open(engineCtx context.Context) error {
|
||||
|
||||
conn.workerRelay = NewWorkerRelay(conn.ctx, conn.Log, isController(conn.config), conn.config, conn, conn.relayManager)
|
||||
|
||||
forceRelay := IsForceRelayed()
|
||||
if !forceRelay {
|
||||
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
|
||||
workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
conn.workerICE = workerICE
|
||||
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
|
||||
workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
conn.workerICE = workerICE
|
||||
|
||||
conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay, conn.metricsStages)
|
||||
|
||||
conn.handshaker.AddRelayListener(conn.workerRelay.OnNewOffer)
|
||||
if !forceRelay {
|
||||
if !isForceRelayed() {
|
||||
conn.handshaker.AddICEListener(conn.workerICE.OnNewOffer)
|
||||
}
|
||||
|
||||
@@ -250,9 +251,7 @@ func (conn *Conn) Close(signalToRemote bool) {
|
||||
conn.wgWatcherCancel()
|
||||
}
|
||||
conn.workerRelay.CloseConn()
|
||||
if conn.workerICE != nil {
|
||||
conn.workerICE.Close()
|
||||
}
|
||||
conn.workerICE.Close()
|
||||
|
||||
if conn.wgProxyRelay != nil {
|
||||
err := conn.wgProxyRelay.CloseConn()
|
||||
@@ -295,9 +294,7 @@ func (conn *Conn) OnRemoteAnswer(answer OfferAnswer) {
|
||||
// OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer.
|
||||
func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMap) {
|
||||
conn.dumpState.RemoteCandidate()
|
||||
if conn.workerICE != nil {
|
||||
conn.workerICE.OnRemoteCandidate(candidate, haRoutes)
|
||||
}
|
||||
conn.workerICE.OnRemoteCandidate(candidate, haRoutes)
|
||||
}
|
||||
|
||||
// SetOnConnected sets a handler function to be triggered by Conn when a new connection to a remote peer established
|
||||
@@ -715,50 +712,33 @@ func (conn *Conn) evalStatus() ConnStatus {
|
||||
return StatusConnecting
|
||||
}
|
||||
|
||||
// isConnectedOnAllWay evaluates the overall connection status based on ICE and Relay transports.
|
||||
//
|
||||
// The result is a tri-state:
|
||||
// - ConnStatusConnected: all available transports are up
|
||||
// - ConnStatusPartiallyConnected: relay is up but ICE is still pending/reconnecting
|
||||
// - ConnStatusDisconnected: no working transport
|
||||
func (conn *Conn) isConnectedOnAllWay() (status guard.ConnStatus) {
|
||||
func (conn *Conn) isConnectedOnAllWay() (connected bool) {
|
||||
// would be better to protect this with a mutex, but it could cause deadlock with Close function
|
||||
|
||||
defer func() {
|
||||
if status != guard.ConnStatusConnected {
|
||||
if !connected {
|
||||
conn.logTraceConnState()
|
||||
}
|
||||
}()
|
||||
|
||||
relayConnected := conn.workerRelay.IsRelayConnectionSupportedWithPeer() &&
|
||||
conn.statusRelay.Get() == worker.StatusConnected
|
||||
|
||||
// Force-relay mode (JS/WASM or NB_FORCE_RELAY): ICE is never used, relay is the only transport.
|
||||
if IsForceRelayed() {
|
||||
return boolToConnStatus(relayConnected)
|
||||
// For JS platform: only relay connection is supported
|
||||
if runtime.GOOS == "js" {
|
||||
return conn.statusRelay.Get() == worker.StatusConnected
|
||||
}
|
||||
|
||||
iceAvailable := conn.handshaker.RemoteICESupported() && conn.workerICE != nil
|
||||
|
||||
// When ICE is not available (remote peer doesn't support it or worker not yet created),
|
||||
// relay is the only possible transport.
|
||||
if !iceAvailable {
|
||||
return boolToConnStatus(relayConnected)
|
||||
// For non-JS platforms: check ICE connection status
|
||||
if conn.statusICE.Get() == worker.StatusDisconnected && !conn.workerICE.InProgress() {
|
||||
return false
|
||||
}
|
||||
|
||||
// ICE is considered "up" when it is connected or a connection attempt is in progress.
|
||||
iceConnected := conn.statusICE.Get() != worker.StatusDisconnected || conn.workerICE.InProgress()
|
||||
|
||||
// Relay is OK if the peer doesn't use relay, or if relay is actually connected.
|
||||
relayOK := !conn.workerRelay.IsRelayConnectionSupportedWithPeer() || relayConnected
|
||||
|
||||
switch {
|
||||
case iceConnected && relayOK:
|
||||
return guard.ConnStatusConnected
|
||||
case relayConnected:
|
||||
// Relay is up but ICE is down — partially connected.
|
||||
return guard.ConnStatusPartiallyConnected
|
||||
default:
|
||||
return guard.ConnStatusDisconnected
|
||||
// If relay is supported with peer, it must also be connected
|
||||
if conn.workerRelay.IsRelayConnectionSupportedWithPeer() {
|
||||
if conn.statusRelay.Get() == worker.StatusDisconnected {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (conn *Conn) enableWgWatcherIfNeeded(enabledTime time.Time) {
|
||||
@@ -946,10 +926,3 @@ func isController(config ConnConfig) bool {
|
||||
func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool {
|
||||
return remoteRosenpassPubKey != nil
|
||||
}
|
||||
|
||||
func boolToConnStatus(connected bool) guard.ConnStatus {
|
||||
if connected {
|
||||
return guard.ConnStatusConnected
|
||||
}
|
||||
return guard.ConnStatusDisconnected
|
||||
}
|
||||
|
||||
@@ -10,7 +10,7 @@ const (
|
||||
EnvKeyNBForceRelay = "NB_FORCE_RELAY"
|
||||
)
|
||||
|
||||
func IsForceRelayed() bool {
|
||||
func isForceRelayed() bool {
|
||||
if runtime.GOOS == "js" {
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -8,19 +8,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// ConnStatus represents the connection state as seen by the guard.
|
||||
type ConnStatus int
|
||||
|
||||
const (
|
||||
// ConnStatusDisconnected means neither ICE nor Relay is connected.
|
||||
ConnStatusDisconnected ConnStatus = iota
|
||||
// ConnStatusPartiallyConnected means Relay is connected but ICE is not.
|
||||
ConnStatusPartiallyConnected
|
||||
// ConnStatusConnected means all required connections are established.
|
||||
ConnStatusConnected
|
||||
)
|
||||
|
||||
type connStatusFunc func() ConnStatus
|
||||
type isConnectedFunc func() bool
|
||||
|
||||
// Guard is responsible for the reconnection logic.
|
||||
// It will trigger to send an offer to the peer then has connection issues.
|
||||
@@ -32,14 +20,14 @@ type connStatusFunc func() ConnStatus
|
||||
// - ICE candidate changes
|
||||
type Guard struct {
|
||||
log *log.Entry
|
||||
isConnectedOnAllWay connStatusFunc
|
||||
isConnectedOnAllWay isConnectedFunc
|
||||
timeout time.Duration
|
||||
srWatcher *SRWatcher
|
||||
relayedConnDisconnected chan struct{}
|
||||
iCEConnDisconnected chan struct{}
|
||||
}
|
||||
|
||||
func NewGuard(log *log.Entry, isConnectedFn connStatusFunc, timeout time.Duration, srWatcher *SRWatcher) *Guard {
|
||||
func NewGuard(log *log.Entry, isConnectedFn isConnectedFunc, timeout time.Duration, srWatcher *SRWatcher) *Guard {
|
||||
return &Guard{
|
||||
log: log,
|
||||
isConnectedOnAllWay: isConnectedFn,
|
||||
@@ -69,17 +57,8 @@ func (g *Guard) SetICEConnDisconnected() {
|
||||
}
|
||||
}
|
||||
|
||||
// reconnectLoopWithRetry periodically checks the connection status and sends offers to re-establish connectivity.
|
||||
//
|
||||
// Behavior depends on the connection state reported by isConnectedOnAllWay:
|
||||
// - Connected: no action, the peer is fully reachable.
|
||||
// - Disconnected (neither ICE nor Relay): retries aggressively with exponential backoff (800ms doubling
|
||||
// up to timeout), never gives up. This ensures rapid recovery when the peer has no connectivity at all.
|
||||
// - PartiallyConnected (Relay up, ICE not): retries up to 3 times with exponential backoff, then switches
|
||||
// to one attempt per hour. This limits signaling traffic when relay already provides connectivity.
|
||||
//
|
||||
// External events (relay/ICE disconnect, signal/relay reconnect, candidate changes) reset the retry
|
||||
// counter and backoff ticker, giving ICE a fresh chance after network conditions change.
|
||||
// reconnectLoopWithRetry periodically check the connection status.
|
||||
// Try to send offer while the P2P is not established or while the Relay is not connected if is it supported
|
||||
func (g *Guard) reconnectLoopWithRetry(ctx context.Context, callback func()) {
|
||||
srReconnectedChan := g.srWatcher.NewListener()
|
||||
defer g.srWatcher.RemoveListener(srReconnectedChan)
|
||||
@@ -89,46 +68,36 @@ func (g *Guard) reconnectLoopWithRetry(ctx context.Context, callback func()) {
|
||||
|
||||
tickerChannel := ticker.C
|
||||
|
||||
iceState := &iceRetryState{log: g.log}
|
||||
defer iceState.reset()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-tickerChannel:
|
||||
switch g.isConnectedOnAllWay() {
|
||||
case ConnStatusConnected:
|
||||
// all good, nothing to do
|
||||
case ConnStatusDisconnected:
|
||||
callback()
|
||||
case ConnStatusPartiallyConnected:
|
||||
if iceState.attempt() {
|
||||
callback()
|
||||
} else {
|
||||
ticker.Stop()
|
||||
tickerChannel = iceState.hourlyC()
|
||||
}
|
||||
case t := <-tickerChannel:
|
||||
if t.IsZero() {
|
||||
g.log.Infof("retry timed out, stop periodic offer sending")
|
||||
// after backoff timeout the ticker.C will be closed. We need to a dummy channel to avoid loop
|
||||
tickerChannel = make(<-chan time.Time)
|
||||
continue
|
||||
}
|
||||
|
||||
if !g.isConnectedOnAllWay() {
|
||||
callback()
|
||||
}
|
||||
case <-g.relayedConnDisconnected:
|
||||
g.log.Debugf("Relay connection changed, reset reconnection ticker")
|
||||
ticker.Stop()
|
||||
ticker = g.newReconnectTicker(ctx)
|
||||
ticker = g.prepareExponentTicker(ctx)
|
||||
tickerChannel = ticker.C
|
||||
iceState.reset()
|
||||
|
||||
case <-g.iCEConnDisconnected:
|
||||
g.log.Debugf("ICE connection changed, reset reconnection ticker")
|
||||
ticker.Stop()
|
||||
ticker = g.newReconnectTicker(ctx)
|
||||
ticker = g.prepareExponentTicker(ctx)
|
||||
tickerChannel = ticker.C
|
||||
iceState.reset()
|
||||
|
||||
case <-srReconnectedChan:
|
||||
g.log.Debugf("has network changes, reset reconnection ticker")
|
||||
ticker.Stop()
|
||||
ticker = g.newReconnectTicker(ctx)
|
||||
ticker = g.prepareExponentTicker(ctx)
|
||||
tickerChannel = ticker.C
|
||||
iceState.reset()
|
||||
|
||||
case <-ctx.Done():
|
||||
g.log.Debugf("context is done, stop reconnect loop")
|
||||
@@ -151,7 +120,7 @@ func (g *Guard) initialTicker(ctx context.Context) *backoff.Ticker {
|
||||
return backoff.NewTicker(bo)
|
||||
}
|
||||
|
||||
func (g *Guard) newReconnectTicker(ctx context.Context) *backoff.Ticker {
|
||||
func (g *Guard) prepareExponentTicker(ctx context.Context) *backoff.Ticker {
|
||||
bo := backoff.WithContext(&backoff.ExponentialBackOff{
|
||||
InitialInterval: 800 * time.Millisecond,
|
||||
RandomizationFactor: 0.1,
|
||||
|
||||
@@ -1,57 +0,0 @@
|
||||
package guard
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
// maxICERetries is the maximum number of ICE offer attempts when relay is connected
|
||||
maxICERetries = 3
|
||||
// iceRetryInterval is the periodic retry interval after ICE retries are exhausted
|
||||
iceRetryInterval = 1 * time.Hour
|
||||
)
|
||||
|
||||
// iceRetryState tracks the limited ICE retry attempts when relay is already connected.
|
||||
// After maxICERetries attempts it switches to a periodic hourly retry.
|
||||
type iceRetryState struct {
|
||||
log *log.Entry
|
||||
retries int
|
||||
hourly *time.Ticker
|
||||
}
|
||||
|
||||
func (s *iceRetryState) reset() {
|
||||
s.retries = 0
|
||||
if s.hourly != nil {
|
||||
s.hourly.Stop()
|
||||
s.hourly = nil
|
||||
}
|
||||
}
|
||||
|
||||
// attempt processes a single ICE retry tick. It returns true if the caller should send an offer.
|
||||
// When retries are exhausted it starts the hourly ticker and returns false once to signal the caller
|
||||
// to swap the tick channel. Subsequent calls (from the hourly ticker) return true.
|
||||
func (s *iceRetryState) attempt() bool {
|
||||
if s.hourly != nil {
|
||||
s.log.Debugf("hourly ICE retry attempt")
|
||||
return true
|
||||
}
|
||||
|
||||
s.retries++
|
||||
if s.retries <= maxICERetries {
|
||||
s.log.Debugf("ICE retry attempt %d/%d", s.retries, maxICERetries)
|
||||
return true
|
||||
}
|
||||
|
||||
s.log.Infof("ICE retries exhausted (%d/%d), switching to hourly retry", maxICERetries, maxICERetries)
|
||||
s.hourly = time.NewTicker(iceRetryInterval)
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *iceRetryState) hourlyC() <-chan time.Time {
|
||||
if s.hourly == nil {
|
||||
return nil
|
||||
}
|
||||
return s.hourly.C
|
||||
}
|
||||
@@ -39,7 +39,7 @@ func NewSRWatcher(signalClient chNotifier, relayManager chNotifier, iFaceDiscove
|
||||
return srw
|
||||
}
|
||||
|
||||
func (w *SRWatcher) Start(disableICEMonitor bool) {
|
||||
func (w *SRWatcher) Start() {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
@@ -50,10 +50,8 @@ func (w *SRWatcher) Start(disableICEMonitor bool) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
w.cancelIceMonitor = cancel
|
||||
|
||||
if !disableICEMonitor {
|
||||
iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig, GetICEMonitorPeriod())
|
||||
go iceMonitor.Start(ctx, w.onICEChanged)
|
||||
}
|
||||
iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig, GetICEMonitorPeriod())
|
||||
go iceMonitor.Start(ctx, w.onICEChanged)
|
||||
w.signalClient.SetOnReconnectedListener(w.onReconnected)
|
||||
w.relayManager.SetOnReconnectedListener(w.onReconnected)
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
@@ -60,10 +59,6 @@ type Handshaker struct {
|
||||
relayListener *AsyncOfferListener
|
||||
iceListener func(remoteOfferAnswer *OfferAnswer)
|
||||
|
||||
// remoteICESupported tracks whether the remote peer includes ICE credentials in its offers/answers.
|
||||
// When false, the local side skips ICE listener dispatch and suppresses ICE credentials in responses.
|
||||
remoteICESupported atomic.Bool
|
||||
|
||||
// remoteOffersCh is a channel used to wait for remote credentials to proceed with the connection
|
||||
remoteOffersCh chan OfferAnswer
|
||||
// remoteAnswerCh is a channel used to wait for remote credentials answer (confirmation of our offer) to proceed with the connection
|
||||
@@ -71,7 +66,7 @@ type Handshaker struct {
|
||||
}
|
||||
|
||||
func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *WorkerICE, relay *WorkerRelay, metricsStages *MetricsStages) *Handshaker {
|
||||
h := &Handshaker{
|
||||
return &Handshaker{
|
||||
log: log,
|
||||
config: config,
|
||||
signaler: signaler,
|
||||
@@ -81,13 +76,6 @@ func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *W
|
||||
remoteOffersCh: make(chan OfferAnswer),
|
||||
remoteAnswerCh: make(chan OfferAnswer),
|
||||
}
|
||||
// assume remote supports ICE until we learn otherwise from received offers
|
||||
h.remoteICESupported.Store(ice != nil)
|
||||
return h
|
||||
}
|
||||
|
||||
func (h *Handshaker) RemoteICESupported() bool {
|
||||
return h.remoteICESupported.Load()
|
||||
}
|
||||
|
||||
func (h *Handshaker) AddRelayListener(offer func(remoteOfferAnswer *OfferAnswer)) {
|
||||
@@ -109,13 +97,11 @@ func (h *Handshaker) Listen(ctx context.Context) {
|
||||
h.metricsStages.RecordSignalingReceived()
|
||||
}
|
||||
|
||||
h.updateRemoteICEState(&remoteOfferAnswer)
|
||||
|
||||
if h.relayListener != nil {
|
||||
h.relayListener.Notify(&remoteOfferAnswer)
|
||||
}
|
||||
|
||||
if h.iceListener != nil && h.RemoteICESupported() {
|
||||
if h.iceListener != nil {
|
||||
h.iceListener(&remoteOfferAnswer)
|
||||
}
|
||||
|
||||
@@ -131,13 +117,11 @@ func (h *Handshaker) Listen(ctx context.Context) {
|
||||
h.metricsStages.RecordSignalingReceived()
|
||||
}
|
||||
|
||||
h.updateRemoteICEState(&remoteOfferAnswer)
|
||||
|
||||
if h.relayListener != nil {
|
||||
h.relayListener.Notify(&remoteOfferAnswer)
|
||||
}
|
||||
|
||||
if h.iceListener != nil && h.RemoteICESupported() {
|
||||
if h.iceListener != nil {
|
||||
h.iceListener(&remoteOfferAnswer)
|
||||
}
|
||||
case <-ctx.Done():
|
||||
@@ -199,18 +183,15 @@ func (h *Handshaker) sendAnswer() error {
|
||||
}
|
||||
|
||||
func (h *Handshaker) buildOfferAnswer() OfferAnswer {
|
||||
uFrag, pwd := h.ice.GetLocalUserCredentials()
|
||||
sid := h.ice.SessionID()
|
||||
answer := OfferAnswer{
|
||||
IceCredentials: IceCredentials{uFrag, pwd},
|
||||
WgListenPort: h.config.LocalWgPort,
|
||||
Version: version.NetbirdVersion(),
|
||||
RosenpassPubKey: h.config.RosenpassConfig.PubKey,
|
||||
RosenpassAddr: h.config.RosenpassConfig.Addr,
|
||||
}
|
||||
|
||||
if h.ice != nil && h.RemoteICESupported() {
|
||||
uFrag, pwd := h.ice.GetLocalUserCredentials()
|
||||
sid := h.ice.SessionID()
|
||||
answer.IceCredentials = IceCredentials{uFrag, pwd}
|
||||
answer.SessionID = &sid
|
||||
SessionID: &sid,
|
||||
}
|
||||
|
||||
if addr, err := h.relay.RelayInstanceAddress(); err == nil {
|
||||
@@ -219,18 +200,3 @@ func (h *Handshaker) buildOfferAnswer() OfferAnswer {
|
||||
|
||||
return answer
|
||||
}
|
||||
|
||||
func (h *Handshaker) updateRemoteICEState(offer *OfferAnswer) {
|
||||
hasICE := offer.IceCredentials.UFrag != "" && offer.IceCredentials.Pwd != ""
|
||||
prev := h.remoteICESupported.Swap(hasICE)
|
||||
if prev != hasICE {
|
||||
if hasICE {
|
||||
h.log.Infof("remote peer started sending ICE credentials")
|
||||
} else {
|
||||
h.log.Infof("remote peer stopped sending ICE credentials")
|
||||
if h.ice != nil {
|
||||
h.ice.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -46,13 +46,9 @@ func (s *Signaler) Ready() bool {
|
||||
|
||||
// SignalOfferAnswer signals either an offer or an answer to remote peer
|
||||
func (s *Signaler) signalOfferAnswer(offerAnswer OfferAnswer, remoteKey string, bodyType sProto.Body_Type) error {
|
||||
var sessionIDBytes []byte
|
||||
if offerAnswer.SessionID != nil {
|
||||
var err error
|
||||
sessionIDBytes, err = offerAnswer.SessionID.Bytes()
|
||||
if err != nil {
|
||||
log.Warnf("failed to get session ID bytes: %v", err)
|
||||
}
|
||||
sessionIDBytes, err := offerAnswer.SessionID.Bytes()
|
||||
if err != nil {
|
||||
log.Warnf("failed to get session ID bytes: %v", err)
|
||||
}
|
||||
msg, err := signal.MarshalCredential(
|
||||
s.wgPrivateKey,
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||
"github.com/netbirdio/netbird/client/internal/peer/conntype"
|
||||
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||
"github.com/netbirdio/netbird/client/internal/portforward"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
@@ -61,6 +62,9 @@ type WorkerICE struct {
|
||||
|
||||
// we record the last known state of the ICE agent to avoid duplicate on disconnected events
|
||||
lastKnownState ice.ConnectionState
|
||||
|
||||
// portForwardAttempted tracks if we've already tried port forwarding this session
|
||||
portForwardAttempted bool
|
||||
}
|
||||
|
||||
func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, conn *Conn, signaler *Signaler, ifaceDiscover stdnet.ExternalIFaceDiscover, statusRecorder *Status, hasRelayOnLocally bool) (*WorkerICE, error) {
|
||||
@@ -214,6 +218,8 @@ func (w *WorkerICE) Close() {
|
||||
}
|
||||
|
||||
func (w *WorkerICE) reCreateAgent(dialerCancel context.CancelFunc, candidates []ice.CandidateType) (*icemaker.ThreadSafeAgent, error) {
|
||||
w.portForwardAttempted = false
|
||||
|
||||
agent, err := icemaker.NewAgent(w.ctx, w.iFaceDiscover, w.config.ICEConfig, candidates, w.localUfrag, w.localPwd)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create agent: %w", err)
|
||||
@@ -370,6 +376,93 @@ func (w *WorkerICE) onICECandidate(candidate ice.Candidate) {
|
||||
w.log.Errorf("failed signaling candidate to the remote peer %s %s", w.config.Key, err)
|
||||
}
|
||||
}()
|
||||
|
||||
if candidate.Type() == ice.CandidateTypeServerReflexive {
|
||||
w.injectPortForwardedCandidate(candidate)
|
||||
}
|
||||
}
|
||||
|
||||
// injectPortForwardedCandidate signals an additional candidate using the pre-created port mapping.
|
||||
func (w *WorkerICE) injectPortForwardedCandidate(srflxCandidate ice.Candidate) {
|
||||
pfManager := w.conn.portForwardManager
|
||||
if pfManager == nil {
|
||||
return
|
||||
}
|
||||
|
||||
mapping := pfManager.GetMapping()
|
||||
if mapping == nil {
|
||||
return
|
||||
}
|
||||
|
||||
w.muxAgent.Lock()
|
||||
if w.portForwardAttempted {
|
||||
w.muxAgent.Unlock()
|
||||
return
|
||||
}
|
||||
w.portForwardAttempted = true
|
||||
w.muxAgent.Unlock()
|
||||
|
||||
forwardedCandidate, err := w.createForwardedCandidate(srflxCandidate, mapping)
|
||||
if err != nil {
|
||||
w.log.Warnf("create forwarded candidate: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
w.log.Debugf("injecting port-forwarded candidate: %s (mapping: %d -> %d via %s, priority: %d)",
|
||||
forwardedCandidate.String(), mapping.InternalPort, mapping.ExternalPort, mapping.NATType, forwardedCandidate.Priority())
|
||||
|
||||
go func() {
|
||||
if err := w.signaler.SignalICECandidate(forwardedCandidate, w.config.Key); err != nil {
|
||||
w.log.Errorf("signal port-forwarded candidate: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// createForwardedCandidate creates a new server reflexive candidate with the forwarded port.
|
||||
// It uses the NAT gateway's external IP with the forwarded port.
|
||||
func (w *WorkerICE) createForwardedCandidate(srflxCandidate ice.Candidate, mapping *portforward.Mapping) (ice.Candidate, error) {
|
||||
var externalIP string
|
||||
if mapping.ExternalIP != nil && !mapping.ExternalIP.IsUnspecified() {
|
||||
externalIP = mapping.ExternalIP.String()
|
||||
} else {
|
||||
// Fallback to STUN-discovered address if NAT didn't provide external IP
|
||||
externalIP = srflxCandidate.Address()
|
||||
}
|
||||
|
||||
// Per RFC 8445, the related address for srflx is the base (host candidate address).
|
||||
// If the original srflx has unspecified related address, use its own address as base.
|
||||
relAddr := srflxCandidate.RelatedAddress().Address
|
||||
if relAddr == "" || relAddr == "0.0.0.0" || relAddr == "::" {
|
||||
relAddr = srflxCandidate.Address()
|
||||
}
|
||||
|
||||
// Arbitrary +1000 boost on top of RFC 8445 priority to favor port-forwarded candidates
|
||||
// over regular srflx during ICE connectivity checks.
|
||||
priority := srflxCandidate.Priority() + 1000
|
||||
|
||||
candidate, err := ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{
|
||||
Network: srflxCandidate.NetworkType().String(),
|
||||
Address: externalIP,
|
||||
Port: int(mapping.ExternalPort),
|
||||
Component: srflxCandidate.Component(),
|
||||
Priority: priority,
|
||||
RelAddr: relAddr,
|
||||
RelPort: int(mapping.InternalPort),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create candidate: %w", err)
|
||||
}
|
||||
|
||||
for _, e := range srflxCandidate.Extensions() {
|
||||
if e.Key == ice.ExtensionKeyCandidateID {
|
||||
e.Value = srflxCandidate.ID()
|
||||
}
|
||||
if err := candidate.AddExtension(e); err != nil {
|
||||
return nil, fmt.Errorf("add extension: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return candidate, nil
|
||||
}
|
||||
|
||||
func (w *WorkerICE) onICESelectedCandidatePair(agent *icemaker.ThreadSafeAgent, c1, c2 ice.Candidate) {
|
||||
@@ -411,10 +504,10 @@ func (w *WorkerICE) logSuccessfulPaths(agent *icemaker.ThreadSafeAgent) {
|
||||
if !lok || !rok {
|
||||
continue
|
||||
}
|
||||
w.log.Debugf("successful ICE path %s: [%s %s %s] <-> [%s %s %s] rtt=%.3fms",
|
||||
w.log.Debugf("successful ICE path %s: [%s %s %s:%d] <-> [%s %s %s:%d] rtt=%.3fms",
|
||||
sessionID,
|
||||
local.NetworkType(), local.Type(), local.Address(),
|
||||
remote.NetworkType(), remote.Type(), remote.Address(),
|
||||
local.NetworkType(), local.Type(), local.Address(), local.Port(),
|
||||
remote.NetworkType(), remote.Type(), remote.Address(), remote.Port(),
|
||||
stat.CurrentRoundTripTime*1000)
|
||||
}
|
||||
}
|
||||
|
||||
35
client/internal/portforward/env.go
Normal file
35
client/internal/portforward/env.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package portforward
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
envDisableNATMapper = "NB_DISABLE_NAT_MAPPER"
|
||||
envDisablePCPHealthCheck = "NB_DISABLE_PCP_HEALTH_CHECK"
|
||||
)
|
||||
|
||||
func isDisabledByEnv() bool {
|
||||
return parseBoolEnv(envDisableNATMapper)
|
||||
}
|
||||
|
||||
func isHealthCheckDisabled() bool {
|
||||
return parseBoolEnv(envDisablePCPHealthCheck)
|
||||
}
|
||||
|
||||
func parseBoolEnv(key string) bool {
|
||||
val := os.Getenv(key)
|
||||
if val == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
disabled, err := strconv.ParseBool(val)
|
||||
if err != nil {
|
||||
log.Warnf("failed to parse %s: %v", key, err)
|
||||
return false
|
||||
}
|
||||
return disabled
|
||||
}
|
||||
342
client/internal/portforward/manager.go
Normal file
342
client/internal/portforward/manager.go
Normal file
@@ -0,0 +1,342 @@
|
||||
//go:build !js
|
||||
|
||||
package portforward
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"regexp"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/libp2p/go-nat"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/portforward/pcp"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultMappingTTL = 2 * time.Hour
|
||||
healthCheckInterval = 1 * time.Minute
|
||||
discoveryTimeout = 10 * time.Second
|
||||
mappingDescription = "NetBird"
|
||||
)
|
||||
|
||||
// upnpErrPermanentLeaseOnly matches UPnP error 725 in SOAP fault XML,
|
||||
// allowing for whitespace/newlines between tags from different router firmware.
|
||||
var upnpErrPermanentLeaseOnly = regexp.MustCompile(`<errorCode>\s*725\s*</errorCode>`)
|
||||
|
||||
// Mapping represents an active NAT port mapping.
|
||||
type Mapping struct {
|
||||
Protocol string
|
||||
InternalPort uint16
|
||||
ExternalPort uint16
|
||||
ExternalIP net.IP
|
||||
NATType string
|
||||
// TTL is the lease duration. Zero means a permanent lease that never expires.
|
||||
TTL time.Duration
|
||||
}
|
||||
|
||||
// TODO: persist mapping state for crash recovery cleanup of permanent leases.
|
||||
// Currently not done because State.Cleanup requires NAT gateway re-discovery,
|
||||
// which blocks startup for ~10s when no gateway is present (affects all clients).
|
||||
|
||||
type Manager struct {
|
||||
cancel context.CancelFunc
|
||||
|
||||
mapping *Mapping
|
||||
mappingLock sync.Mutex
|
||||
|
||||
wgPort uint16
|
||||
|
||||
done chan struct{}
|
||||
stopCtx chan context.Context
|
||||
|
||||
// protect exported functions
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// NewManager creates a new port forwarding manager.
|
||||
func NewManager() *Manager {
|
||||
return &Manager{
|
||||
stopCtx: make(chan context.Context, 1),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) Start(ctx context.Context, wgPort uint16) {
|
||||
m.mu.Lock()
|
||||
if m.cancel != nil {
|
||||
m.mu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
if isDisabledByEnv() {
|
||||
log.Infof("NAT port mapper disabled via %s", envDisableNATMapper)
|
||||
m.mu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
if wgPort == 0 {
|
||||
log.Warnf("invalid WireGuard port 0; NAT mapping disabled")
|
||||
m.mu.Unlock()
|
||||
return
|
||||
}
|
||||
m.wgPort = wgPort
|
||||
|
||||
m.done = make(chan struct{})
|
||||
defer close(m.done)
|
||||
|
||||
ctx, m.cancel = context.WithCancel(ctx)
|
||||
m.mu.Unlock()
|
||||
|
||||
gateway, mapping, err := m.setup(ctx)
|
||||
if err != nil {
|
||||
log.Infof("port forwarding setup: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
m.mappingLock.Lock()
|
||||
m.mapping = mapping
|
||||
m.mappingLock.Unlock()
|
||||
|
||||
m.renewLoop(ctx, gateway, mapping.TTL)
|
||||
|
||||
select {
|
||||
case cleanupCtx := <-m.stopCtx:
|
||||
// block the Start while cleaned up gracefully
|
||||
m.cleanup(cleanupCtx, gateway)
|
||||
default:
|
||||
// return Start immediately and cleanup in background
|
||||
cleanupCtx, cleanupCancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
go func() {
|
||||
defer cleanupCancel()
|
||||
m.cleanup(cleanupCtx, gateway)
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
// GetMapping returns the current mapping if ready, nil otherwise
|
||||
func (m *Manager) GetMapping() *Mapping {
|
||||
m.mappingLock.Lock()
|
||||
defer m.mappingLock.Unlock()
|
||||
|
||||
if m.mapping == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
mapping := *m.mapping
|
||||
return &mapping
|
||||
}
|
||||
|
||||
// GracefullyStop cancels the manager and attempts to delete the port mapping.
|
||||
// After GracefullyStop returns, the manager cannot be restarted.
|
||||
func (m *Manager) GracefullyStop(ctx context.Context) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if m.cancel == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Send cleanup context before cancelling, so Start picks it up after renewLoop exits.
|
||||
m.startTearDown(ctx)
|
||||
|
||||
m.cancel()
|
||||
m.cancel = nil
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-m.done:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) setup(ctx context.Context) (nat.NAT, *Mapping, error) {
|
||||
discoverCtx, discoverCancel := context.WithTimeout(ctx, discoveryTimeout)
|
||||
defer discoverCancel()
|
||||
|
||||
gateway, err := discoverGateway(discoverCtx)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("discover gateway: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("discovered NAT gateway: %s", gateway.Type())
|
||||
|
||||
mapping, err := m.createMapping(ctx, gateway)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("create port mapping: %w", err)
|
||||
}
|
||||
return gateway, mapping, nil
|
||||
}
|
||||
|
||||
func (m *Manager) createMapping(ctx context.Context, gateway nat.NAT) (*Mapping, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
ttl := defaultMappingTTL
|
||||
externalPort, err := gateway.AddPortMapping(ctx, "udp", int(m.wgPort), mappingDescription, ttl)
|
||||
if err != nil {
|
||||
if !isPermanentLeaseRequired(err) {
|
||||
return nil, err
|
||||
}
|
||||
log.Infof("gateway only supports permanent leases, retrying with indefinite duration")
|
||||
ttl = 0
|
||||
externalPort, err = gateway.AddPortMapping(ctx, "udp", int(m.wgPort), mappingDescription, ttl)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
externalIP, err := gateway.GetExternalAddress()
|
||||
if err != nil {
|
||||
log.Debugf("failed to get external address: %v", err)
|
||||
}
|
||||
|
||||
mapping := &Mapping{
|
||||
Protocol: "udp",
|
||||
InternalPort: m.wgPort,
|
||||
ExternalPort: uint16(externalPort),
|
||||
ExternalIP: externalIP,
|
||||
NATType: gateway.Type(),
|
||||
TTL: ttl,
|
||||
}
|
||||
|
||||
log.Infof("created port mapping: %d -> %d via %s (external IP: %s)",
|
||||
m.wgPort, externalPort, gateway.Type(), externalIP)
|
||||
return mapping, nil
|
||||
}
|
||||
|
||||
func (m *Manager) renewLoop(ctx context.Context, gateway nat.NAT, ttl time.Duration) {
|
||||
if ttl == 0 {
|
||||
// Permanent mappings don't expire, just wait for cancellation
|
||||
// but still run health checks for PCP gateways.
|
||||
m.permanentLeaseLoop(ctx, gateway)
|
||||
return
|
||||
}
|
||||
|
||||
renewTicker := time.NewTicker(ttl / 2)
|
||||
healthTicker := time.NewTicker(healthCheckInterval)
|
||||
defer renewTicker.Stop()
|
||||
defer healthTicker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-renewTicker.C:
|
||||
if err := m.renewMapping(ctx, gateway); err != nil {
|
||||
log.Warnf("failed to renew port mapping: %v", err)
|
||||
continue
|
||||
}
|
||||
case <-healthTicker.C:
|
||||
if m.checkHealthAndRecreate(ctx, gateway) {
|
||||
renewTicker.Reset(ttl / 2)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) permanentLeaseLoop(ctx context.Context, gateway nat.NAT) {
|
||||
healthTicker := time.NewTicker(healthCheckInterval)
|
||||
defer healthTicker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-healthTicker.C:
|
||||
m.checkHealthAndRecreate(ctx, gateway)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) checkHealthAndRecreate(ctx context.Context, gateway nat.NAT) bool {
|
||||
if isHealthCheckDisabled() {
|
||||
return false
|
||||
}
|
||||
|
||||
m.mappingLock.Lock()
|
||||
hasMapping := m.mapping != nil
|
||||
m.mappingLock.Unlock()
|
||||
|
||||
if !hasMapping {
|
||||
return false
|
||||
}
|
||||
|
||||
pcpNAT, ok := gateway.(*pcp.NAT)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
epoch, serverRestarted, err := pcpNAT.CheckServerHealth(ctx)
|
||||
if err != nil {
|
||||
log.Debugf("PCP health check failed: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
if serverRestarted {
|
||||
log.Warnf("PCP server restart detected (epoch=%d), recreating port mapping", epoch)
|
||||
if err := m.renewMapping(ctx, gateway); err != nil {
|
||||
log.Errorf("failed to recreate port mapping after server restart: %v", err)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Manager) renewMapping(ctx context.Context, gateway nat.NAT) error {
|
||||
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
externalPort, err := gateway.AddPortMapping(ctx, m.mapping.Protocol, int(m.mapping.InternalPort), mappingDescription, m.mapping.TTL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("add port mapping: %w", err)
|
||||
}
|
||||
|
||||
if uint16(externalPort) != m.mapping.ExternalPort {
|
||||
log.Warnf("external port changed on renewal: %d -> %d (candidate may be stale)", m.mapping.ExternalPort, externalPort)
|
||||
m.mappingLock.Lock()
|
||||
m.mapping.ExternalPort = uint16(externalPort)
|
||||
m.mappingLock.Unlock()
|
||||
}
|
||||
|
||||
log.Debugf("renewed port mapping: %d -> %d", m.mapping.InternalPort, m.mapping.ExternalPort)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) cleanup(ctx context.Context, gateway nat.NAT) {
|
||||
m.mappingLock.Lock()
|
||||
mapping := m.mapping
|
||||
m.mapping = nil
|
||||
m.mappingLock.Unlock()
|
||||
|
||||
if mapping == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err := gateway.DeletePortMapping(ctx, mapping.Protocol, int(mapping.InternalPort)); err != nil {
|
||||
log.Warnf("delete port mapping on stop: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Infof("deleted port mapping for port %d", mapping.InternalPort)
|
||||
}
|
||||
|
||||
func (m *Manager) startTearDown(ctx context.Context) {
|
||||
select {
|
||||
case m.stopCtx <- ctx:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// isPermanentLeaseRequired checks if a UPnP error indicates the gateway only supports permanent leases (error 725).
|
||||
func isPermanentLeaseRequired(err error) bool {
|
||||
return err != nil && upnpErrPermanentLeaseOnly.MatchString(err.Error())
|
||||
}
|
||||
39
client/internal/portforward/manager_js.go
Normal file
39
client/internal/portforward/manager_js.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package portforward
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Mapping represents an active NAT port mapping.
|
||||
type Mapping struct {
|
||||
Protocol string
|
||||
InternalPort uint16
|
||||
ExternalPort uint16
|
||||
ExternalIP net.IP
|
||||
NATType string
|
||||
// TTL is the lease duration. Zero means a permanent lease that never expires.
|
||||
TTL time.Duration
|
||||
}
|
||||
|
||||
// Manager is a stub for js/wasm builds where NAT-PMP/UPnP is not supported.
|
||||
type Manager struct{}
|
||||
|
||||
// NewManager returns a stub manager for js/wasm builds.
|
||||
func NewManager() *Manager {
|
||||
return &Manager{}
|
||||
}
|
||||
|
||||
// Start is a no-op on js/wasm: NAT-PMP/UPnP is not available in browser environments.
|
||||
func (m *Manager) Start(context.Context, uint16) {
|
||||
// no NAT traversal in wasm
|
||||
}
|
||||
|
||||
// GracefullyStop is a no-op on js/wasm.
|
||||
func (m *Manager) GracefullyStop(context.Context) error { return nil }
|
||||
|
||||
// GetMapping always returns nil on js/wasm.
|
||||
func (m *Manager) GetMapping() *Mapping {
|
||||
return nil
|
||||
}
|
||||
201
client/internal/portforward/manager_test.go
Normal file
201
client/internal/portforward/manager_test.go
Normal file
@@ -0,0 +1,201 @@
|
||||
//go:build !js
|
||||
|
||||
package portforward
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type mockNAT struct {
|
||||
natType string
|
||||
deviceAddr net.IP
|
||||
externalAddr net.IP
|
||||
internalAddr net.IP
|
||||
mappings map[int]int
|
||||
addMappingErr error
|
||||
deleteMappingErr error
|
||||
onlyPermanentLeases bool
|
||||
lastTimeout time.Duration
|
||||
}
|
||||
|
||||
func newMockNAT() *mockNAT {
|
||||
return &mockNAT{
|
||||
natType: "Mock-NAT",
|
||||
deviceAddr: net.ParseIP("192.168.1.1"),
|
||||
externalAddr: net.ParseIP("203.0.113.50"),
|
||||
internalAddr: net.ParseIP("192.168.1.100"),
|
||||
mappings: make(map[int]int),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockNAT) Type() string {
|
||||
return m.natType
|
||||
}
|
||||
|
||||
func (m *mockNAT) GetDeviceAddress() (net.IP, error) {
|
||||
return m.deviceAddr, nil
|
||||
}
|
||||
|
||||
func (m *mockNAT) GetExternalAddress() (net.IP, error) {
|
||||
return m.externalAddr, nil
|
||||
}
|
||||
|
||||
func (m *mockNAT) GetInternalAddress() (net.IP, error) {
|
||||
return m.internalAddr, nil
|
||||
}
|
||||
|
||||
func (m *mockNAT) AddPortMapping(ctx context.Context, protocol string, internalPort int, description string, timeout time.Duration) (int, error) {
|
||||
if m.addMappingErr != nil {
|
||||
return 0, m.addMappingErr
|
||||
}
|
||||
if m.onlyPermanentLeases && timeout != 0 {
|
||||
return 0, fmt.Errorf("SOAP fault. Code: | Explanation: | Detail: <UPnPError xmlns=\"urn:schemas-upnp-org:control-1-0\"><errorCode>725</errorCode><errorDescription>OnlyPermanentLeasesSupported</errorDescription></UPnPError>")
|
||||
}
|
||||
externalPort := internalPort
|
||||
m.mappings[internalPort] = externalPort
|
||||
m.lastTimeout = timeout
|
||||
return externalPort, nil
|
||||
}
|
||||
|
||||
func (m *mockNAT) DeletePortMapping(ctx context.Context, protocol string, internalPort int) error {
|
||||
if m.deleteMappingErr != nil {
|
||||
return m.deleteMappingErr
|
||||
}
|
||||
delete(m.mappings, internalPort)
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestManager_CreateMapping(t *testing.T) {
|
||||
m := NewManager()
|
||||
m.wgPort = 51820
|
||||
|
||||
gateway := newMockNAT()
|
||||
mapping, err := m.createMapping(context.Background(), gateway)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, mapping)
|
||||
|
||||
assert.Equal(t, "udp", mapping.Protocol)
|
||||
assert.Equal(t, uint16(51820), mapping.InternalPort)
|
||||
assert.Equal(t, uint16(51820), mapping.ExternalPort)
|
||||
assert.Equal(t, "Mock-NAT", mapping.NATType)
|
||||
assert.Equal(t, net.ParseIP("203.0.113.50").To4(), mapping.ExternalIP.To4())
|
||||
assert.Equal(t, defaultMappingTTL, mapping.TTL)
|
||||
}
|
||||
|
||||
func TestManager_GetMapping_ReturnsNilWhenNotReady(t *testing.T) {
|
||||
m := NewManager()
|
||||
assert.Nil(t, m.GetMapping())
|
||||
}
|
||||
|
||||
func TestManager_GetMapping_ReturnsCopy(t *testing.T) {
|
||||
m := NewManager()
|
||||
m.mapping = &Mapping{
|
||||
Protocol: "udp",
|
||||
InternalPort: 51820,
|
||||
ExternalPort: 51820,
|
||||
}
|
||||
|
||||
mapping := m.GetMapping()
|
||||
require.NotNil(t, mapping)
|
||||
assert.Equal(t, uint16(51820), mapping.InternalPort)
|
||||
|
||||
// Mutating the returned copy should not affect the manager's mapping.
|
||||
mapping.ExternalPort = 9999
|
||||
assert.Equal(t, uint16(51820), m.GetMapping().ExternalPort)
|
||||
}
|
||||
|
||||
func TestManager_Cleanup_DeletesMapping(t *testing.T) {
|
||||
m := NewManager()
|
||||
m.mapping = &Mapping{
|
||||
Protocol: "udp",
|
||||
InternalPort: 51820,
|
||||
ExternalPort: 51820,
|
||||
}
|
||||
|
||||
gateway := newMockNAT()
|
||||
// Seed the mock so we can verify deletion.
|
||||
gateway.mappings[51820] = 51820
|
||||
|
||||
m.cleanup(context.Background(), gateway)
|
||||
|
||||
_, exists := gateway.mappings[51820]
|
||||
assert.False(t, exists, "mapping should be deleted from gateway")
|
||||
assert.Nil(t, m.GetMapping(), "in-memory mapping should be cleared")
|
||||
}
|
||||
|
||||
func TestManager_Cleanup_NilMapping(t *testing.T) {
|
||||
m := NewManager()
|
||||
gateway := newMockNAT()
|
||||
|
||||
// Should not panic or call gateway.
|
||||
m.cleanup(context.Background(), gateway)
|
||||
}
|
||||
|
||||
|
||||
func TestManager_CreateMapping_PermanentLeaseFallback(t *testing.T) {
|
||||
m := NewManager()
|
||||
m.wgPort = 51820
|
||||
|
||||
gateway := newMockNAT()
|
||||
gateway.onlyPermanentLeases = true
|
||||
|
||||
mapping, err := m.createMapping(context.Background(), gateway)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, mapping)
|
||||
|
||||
assert.Equal(t, uint16(51820), mapping.InternalPort)
|
||||
assert.Equal(t, time.Duration(0), mapping.TTL, "should return zero TTL for permanent lease")
|
||||
assert.Equal(t, time.Duration(0), gateway.lastTimeout, "should have retried with zero duration")
|
||||
}
|
||||
|
||||
func TestIsPermanentLeaseRequired(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "nil error",
|
||||
err: nil,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "UPnP error 725",
|
||||
err: fmt.Errorf("SOAP fault. Code: | Detail: <UPnPError><errorCode>725</errorCode><errorDescription>OnlyPermanentLeasesSupported</errorDescription></UPnPError>"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "wrapped error with 725",
|
||||
err: fmt.Errorf("add port mapping: %w", fmt.Errorf("Detail: <errorCode>725</errorCode>")),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "error 725 with newlines in XML",
|
||||
err: fmt.Errorf("<errorCode>\n 725\n</errorCode>"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "bare 725 without XML tag",
|
||||
err: fmt.Errorf("error code 725"),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "unrelated error",
|
||||
err: fmt.Errorf("connection refused"),
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equal(t, tt.expected, isPermanentLeaseRequired(tt.err))
|
||||
})
|
||||
}
|
||||
}
|
||||
408
client/internal/portforward/pcp/client.go
Normal file
408
client/internal/portforward/pcp/client.go
Normal file
@@ -0,0 +1,408 @@
|
||||
package pcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultTimeout = 3 * time.Second
|
||||
responseBufferSize = 128
|
||||
|
||||
// RFC 6887 Section 8.1.1 retry timing
|
||||
initialRetryDelay = 3 * time.Second
|
||||
maxRetryDelay = 1024 * time.Second
|
||||
maxRetries = 4 // 3s + 6s + 12s + 24s = 45s total worst case
|
||||
)
|
||||
|
||||
// Client is a PCP protocol client.
|
||||
// All methods are safe for concurrent use.
|
||||
type Client struct {
|
||||
gateway netip.Addr
|
||||
timeout time.Duration
|
||||
|
||||
mu sync.Mutex
|
||||
// localIP caches the resolved local IP address.
|
||||
localIP netip.Addr
|
||||
// lastEpoch is the last observed server epoch value.
|
||||
lastEpoch uint32
|
||||
// epochTime tracks when lastEpoch was received for state loss detection.
|
||||
epochTime time.Time
|
||||
// externalIP caches the external IP from the last successful MAP response.
|
||||
externalIP netip.Addr
|
||||
// epochStateLost is set when epoch indicates server restart.
|
||||
epochStateLost bool
|
||||
}
|
||||
|
||||
// NewClient creates a new PCP client for the gateway at the given IP.
|
||||
func NewClient(gateway net.IP) *Client {
|
||||
addr, ok := netip.AddrFromSlice(gateway)
|
||||
if !ok {
|
||||
log.Debugf("invalid gateway IP: %v", gateway)
|
||||
}
|
||||
return &Client{
|
||||
gateway: addr.Unmap(),
|
||||
timeout: defaultTimeout,
|
||||
}
|
||||
}
|
||||
|
||||
// NewClientWithTimeout creates a new PCP client with a custom timeout.
|
||||
func NewClientWithTimeout(gateway net.IP, timeout time.Duration) *Client {
|
||||
addr, ok := netip.AddrFromSlice(gateway)
|
||||
if !ok {
|
||||
log.Debugf("invalid gateway IP: %v", gateway)
|
||||
}
|
||||
return &Client{
|
||||
gateway: addr.Unmap(),
|
||||
timeout: timeout,
|
||||
}
|
||||
}
|
||||
|
||||
// SetLocalIP sets the local IP address to use in PCP requests.
|
||||
func (c *Client) SetLocalIP(ip net.IP) {
|
||||
addr, ok := netip.AddrFromSlice(ip)
|
||||
if !ok {
|
||||
log.Debugf("invalid local IP: %v", ip)
|
||||
}
|
||||
c.mu.Lock()
|
||||
c.localIP = addr.Unmap()
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
// Gateway returns the gateway IP address.
|
||||
func (c *Client) Gateway() net.IP {
|
||||
return c.gateway.AsSlice()
|
||||
}
|
||||
|
||||
// Announce sends a PCP ANNOUNCE request to discover PCP support.
|
||||
// Returns the server's epoch time on success.
|
||||
func (c *Client) Announce(ctx context.Context) (epoch uint32, err error) {
|
||||
localIP, err := c.getLocalIP()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("get local IP: %w", err)
|
||||
}
|
||||
|
||||
req := buildAnnounceRequest(localIP)
|
||||
resp, err := c.sendRequest(ctx, req)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("send announce: %w", err)
|
||||
}
|
||||
|
||||
parsed, err := parseResponse(resp)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parse announce response: %w", err)
|
||||
}
|
||||
|
||||
if parsed.ResultCode != ResultSuccess {
|
||||
return 0, fmt.Errorf("PCP ANNOUNCE failed: %s", ResultCodeString(parsed.ResultCode))
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
if c.updateEpochLocked(parsed.Epoch) {
|
||||
log.Warnf("PCP server epoch indicates state loss - mappings may need refresh")
|
||||
}
|
||||
c.mu.Unlock()
|
||||
return parsed.Epoch, nil
|
||||
}
|
||||
|
||||
// AddPortMapping requests a port mapping from the PCP server.
|
||||
func (c *Client) AddPortMapping(ctx context.Context, protocol string, internalPort int, lifetime time.Duration) (*MapResponse, error) {
|
||||
return c.addPortMappingWithHint(ctx, protocol, internalPort, internalPort, netip.Addr{}, lifetime)
|
||||
}
|
||||
|
||||
// AddPortMappingWithHint requests a port mapping with suggested external port and IP.
|
||||
// Use lifetime <= 0 to delete a mapping.
|
||||
func (c *Client) AddPortMappingWithHint(ctx context.Context, protocol string, internalPort, suggestedExtPort int, suggestedExtIP net.IP, lifetime time.Duration) (*MapResponse, error) {
|
||||
var extIP netip.Addr
|
||||
if suggestedExtIP != nil {
|
||||
var ok bool
|
||||
extIP, ok = netip.AddrFromSlice(suggestedExtIP)
|
||||
if !ok {
|
||||
log.Debugf("invalid suggested external IP: %v", suggestedExtIP)
|
||||
}
|
||||
extIP = extIP.Unmap()
|
||||
}
|
||||
return c.addPortMappingWithHint(ctx, protocol, internalPort, suggestedExtPort, extIP, lifetime)
|
||||
}
|
||||
|
||||
func (c *Client) addPortMappingWithHint(ctx context.Context, protocol string, internalPort, suggestedExtPort int, suggestedExtIP netip.Addr, lifetime time.Duration) (*MapResponse, error) {
|
||||
localIP, err := c.getLocalIP()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get local IP: %w", err)
|
||||
}
|
||||
|
||||
proto, err := protocolNumber(protocol)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse protocol: %w", err)
|
||||
}
|
||||
|
||||
var nonce [12]byte
|
||||
if _, err := rand.Read(nonce[:]); err != nil {
|
||||
return nil, fmt.Errorf("generate nonce: %w", err)
|
||||
}
|
||||
|
||||
// Convert lifetime to seconds. Lifetime 0 means delete, so only apply
|
||||
// default for positive durations that round to 0 seconds.
|
||||
var lifetimeSec uint32
|
||||
if lifetime > 0 {
|
||||
lifetimeSec = uint32(lifetime.Seconds())
|
||||
if lifetimeSec == 0 {
|
||||
lifetimeSec = DefaultLifetime
|
||||
}
|
||||
}
|
||||
|
||||
req := buildMapRequest(localIP, nonce, proto, uint16(internalPort), uint16(suggestedExtPort), suggestedExtIP, lifetimeSec)
|
||||
|
||||
resp, err := c.sendRequest(ctx, req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("send map request: %w", err)
|
||||
}
|
||||
|
||||
mapResp, err := parseMapResponse(resp)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse map response: %w", err)
|
||||
}
|
||||
|
||||
if mapResp.Nonce != nonce {
|
||||
return nil, fmt.Errorf("nonce mismatch in response")
|
||||
}
|
||||
|
||||
if mapResp.Protocol != proto {
|
||||
return nil, fmt.Errorf("protocol mismatch: requested %d, got %d", proto, mapResp.Protocol)
|
||||
}
|
||||
if mapResp.InternalPort != uint16(internalPort) {
|
||||
return nil, fmt.Errorf("internal port mismatch: requested %d, got %d", internalPort, mapResp.InternalPort)
|
||||
}
|
||||
|
||||
if mapResp.ResultCode != ResultSuccess {
|
||||
return nil, &Error{
|
||||
Code: mapResp.ResultCode,
|
||||
Message: ResultCodeString(mapResp.ResultCode),
|
||||
}
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
if c.updateEpochLocked(mapResp.Epoch) {
|
||||
log.Warnf("PCP server epoch indicates state loss - mappings may need refresh")
|
||||
}
|
||||
c.cacheExternalIPLocked(mapResp.ExternalIP)
|
||||
c.mu.Unlock()
|
||||
return mapResp, nil
|
||||
}
|
||||
|
||||
// DeletePortMapping removes a port mapping by requesting zero lifetime.
|
||||
func (c *Client) DeletePortMapping(ctx context.Context, protocol string, internalPort int) error {
|
||||
if _, err := c.addPortMappingWithHint(ctx, protocol, internalPort, 0, netip.Addr{}, 0); err != nil {
|
||||
var pcpErr *Error
|
||||
if errors.As(err, &pcpErr) && pcpErr.Code == ResultNotAuthorized {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("delete mapping: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetExternalAddress returns the external IP address.
|
||||
// First checks for a cached value from previous MAP responses.
|
||||
// If not cached, creates a short-lived mapping to discover the external IP.
|
||||
func (c *Client) GetExternalAddress(ctx context.Context) (net.IP, error) {
|
||||
c.mu.Lock()
|
||||
if c.externalIP.IsValid() {
|
||||
ip := c.externalIP.AsSlice()
|
||||
c.mu.Unlock()
|
||||
return ip, nil
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
// Use an ephemeral port in the dynamic range (49152-65535).
|
||||
// Port 0 is not valid with UDP/TCP protocols per RFC 6887.
|
||||
ephemeralPort := 49152 + int(uint16(time.Now().UnixNano()))%(65535-49152)
|
||||
|
||||
// Use minimal lifetime (1 second) for discovery.
|
||||
resp, err := c.AddPortMapping(ctx, "udp", ephemeralPort, time.Second)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create temporary mapping: %w", err)
|
||||
}
|
||||
|
||||
if err := c.DeletePortMapping(ctx, "udp", ephemeralPort); err != nil {
|
||||
log.Debugf("cleanup temporary PCP mapping: %v", err)
|
||||
}
|
||||
|
||||
return resp.ExternalIP.AsSlice(), nil
|
||||
}
|
||||
|
||||
// LastEpoch returns the last observed server epoch value.
|
||||
// A decrease in epoch indicates the server may have restarted and mappings may be lost.
|
||||
func (c *Client) LastEpoch() uint32 {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return c.lastEpoch
|
||||
}
|
||||
|
||||
// EpochStateLost returns true if epoch state loss was detected and clears the flag.
|
||||
func (c *Client) EpochStateLost() bool {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
lost := c.epochStateLost
|
||||
c.epochStateLost = false
|
||||
return lost
|
||||
}
|
||||
|
||||
// updateEpoch updates the epoch tracking and detects potential state loss.
|
||||
// Returns true if state loss was detected (server likely restarted).
|
||||
// Caller must hold c.mu.
|
||||
func (c *Client) updateEpochLocked(newEpoch uint32) bool {
|
||||
now := time.Now()
|
||||
stateLost := false
|
||||
|
||||
// RFC 6887 Section 8.5: Detect invalid epoch indicating server state loss.
|
||||
// client_delta = time since last response
|
||||
// server_delta = epoch change since last response
|
||||
// Invalid if: client_delta+2 < server_delta - server_delta/16
|
||||
// OR: server_delta+2 < client_delta - client_delta/16
|
||||
// The +2 handles quantization, /16 (6.25%) handles clock drift.
|
||||
if !c.epochTime.IsZero() && c.lastEpoch > 0 {
|
||||
clientDelta := uint32(now.Sub(c.epochTime).Seconds())
|
||||
serverDelta := newEpoch - c.lastEpoch
|
||||
|
||||
// Check for epoch going backwards or jumping unexpectedly.
|
||||
// Subtraction is safe: serverDelta/16 is always <= serverDelta.
|
||||
if clientDelta+2 < serverDelta-(serverDelta/16) ||
|
||||
serverDelta+2 < clientDelta-(clientDelta/16) {
|
||||
stateLost = true
|
||||
c.epochStateLost = true
|
||||
}
|
||||
}
|
||||
|
||||
c.lastEpoch = newEpoch
|
||||
c.epochTime = now
|
||||
return stateLost
|
||||
}
|
||||
|
||||
// cacheExternalIP stores the external IP from a successful MAP response.
|
||||
// Caller must hold c.mu.
|
||||
func (c *Client) cacheExternalIPLocked(ip netip.Addr) {
|
||||
if ip.IsValid() && !ip.IsUnspecified() {
|
||||
c.externalIP = ip
|
||||
}
|
||||
}
|
||||
|
||||
// sendRequest sends a PCP request with retries per RFC 6887 Section 8.1.1.
|
||||
func (c *Client) sendRequest(ctx context.Context, req []byte) ([]byte, error) {
|
||||
addr := &net.UDPAddr{IP: c.gateway.AsSlice(), Port: Port}
|
||||
|
||||
var lastErr error
|
||||
delay := initialRetryDelay
|
||||
|
||||
for range maxRetries {
|
||||
resp, err := c.sendOnce(ctx, addr, req)
|
||||
if err == nil {
|
||||
return resp, nil
|
||||
}
|
||||
lastErr = err
|
||||
|
||||
if ctx.Err() != nil {
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
|
||||
// RFC 6887 Section 8.1.1: RT = (1 + RAND) * MIN(2 * RTprev, MRT)
|
||||
// RAND is random between -0.1 and +0.1
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-time.After(retryDelayWithJitter(delay)):
|
||||
}
|
||||
delay = min(delay*2, maxRetryDelay)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("PCP request failed after %d retries: %w", maxRetries, lastErr)
|
||||
}
|
||||
|
||||
// retryDelayWithJitter applies RFC 6887 jitter: multiply by (1 + RAND) where RAND is [-0.1, +0.1].
|
||||
func retryDelayWithJitter(d time.Duration) time.Duration {
|
||||
var b [1]byte
|
||||
_, _ = rand.Read(b[:])
|
||||
// Convert byte to range [-0.1, +0.1]: (b/255 * 0.2) - 0.1
|
||||
jitter := (float64(b[0])/255.0)*0.2 - 0.1
|
||||
return time.Duration(float64(d) * (1 + jitter))
|
||||
}
|
||||
|
||||
func (c *Client) sendOnce(ctx context.Context, addr *net.UDPAddr, req []byte) ([]byte, error) {
|
||||
// Use ListenUDP instead of DialUDP to validate response source address per RFC 6887 §8.3.
|
||||
conn, err := net.ListenUDP("udp", nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("listen: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Debugf("close UDP connection: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
timeout := c.timeout
|
||||
if deadline, ok := ctx.Deadline(); ok {
|
||||
if remaining := time.Until(deadline); remaining < timeout {
|
||||
timeout = remaining
|
||||
}
|
||||
}
|
||||
|
||||
if err := conn.SetDeadline(time.Now().Add(timeout)); err != nil {
|
||||
return nil, fmt.Errorf("set deadline: %w", err)
|
||||
}
|
||||
|
||||
if _, err := conn.WriteToUDP(req, addr); err != nil {
|
||||
return nil, fmt.Errorf("write: %w", err)
|
||||
}
|
||||
|
||||
resp := make([]byte, responseBufferSize)
|
||||
n, from, err := conn.ReadFromUDP(resp)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read: %w", err)
|
||||
}
|
||||
|
||||
// RFC 6887 §8.3: Validate response came from expected PCP server.
|
||||
if !from.IP.Equal(addr.IP) {
|
||||
return nil, fmt.Errorf("response from unexpected source %s (expected %s)", from.IP, addr.IP)
|
||||
}
|
||||
|
||||
return resp[:n], nil
|
||||
}
|
||||
|
||||
func (c *Client) getLocalIP() (netip.Addr, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if !c.localIP.IsValid() {
|
||||
return netip.Addr{}, fmt.Errorf("local IP not set for gateway %s", c.gateway)
|
||||
}
|
||||
return c.localIP, nil
|
||||
}
|
||||
|
||||
func protocolNumber(protocol string) (uint8, error) {
|
||||
switch protocol {
|
||||
case "udp", "UDP":
|
||||
return ProtoUDP, nil
|
||||
case "tcp", "TCP":
|
||||
return ProtoTCP, nil
|
||||
default:
|
||||
return 0, fmt.Errorf("unsupported protocol: %s", protocol)
|
||||
}
|
||||
}
|
||||
|
||||
// Error represents a PCP error response.
|
||||
type Error struct {
|
||||
Code uint8
|
||||
Message string
|
||||
}
|
||||
|
||||
func (e *Error) Error() string {
|
||||
return fmt.Sprintf("PCP error: %s (%d)", e.Message, e.Code)
|
||||
}
|
||||
187
client/internal/portforward/pcp/client_test.go
Normal file
187
client/internal/portforward/pcp/client_test.go
Normal file
@@ -0,0 +1,187 @@
|
||||
package pcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAddrConversion(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
addr netip.Addr
|
||||
}{
|
||||
{"IPv4", netip.MustParseAddr("192.168.1.100")},
|
||||
{"IPv4 loopback", netip.MustParseAddr("127.0.0.1")},
|
||||
{"IPv6", netip.MustParseAddr("2001:db8::1")},
|
||||
{"IPv6 loopback", netip.MustParseAddr("::1")},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
b16 := addrTo16(tt.addr)
|
||||
|
||||
recovered := addrFrom16(b16)
|
||||
assert.Equal(t, tt.addr, recovered, "address should round-trip")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAnnounceRequest(t *testing.T) {
|
||||
clientIP := netip.MustParseAddr("192.168.1.100")
|
||||
req := buildAnnounceRequest(clientIP)
|
||||
|
||||
require.Len(t, req, headerSize)
|
||||
assert.Equal(t, byte(Version), req[0], "version")
|
||||
assert.Equal(t, byte(OpAnnounce), req[1], "opcode")
|
||||
|
||||
// Check client IP is properly encoded as IPv4-mapped IPv6
|
||||
assert.Equal(t, byte(0xff), req[18], "IPv4-mapped prefix byte 10")
|
||||
assert.Equal(t, byte(0xff), req[19], "IPv4-mapped prefix byte 11")
|
||||
assert.Equal(t, byte(192), req[20], "IP octet 1")
|
||||
assert.Equal(t, byte(168), req[21], "IP octet 2")
|
||||
assert.Equal(t, byte(1), req[22], "IP octet 3")
|
||||
assert.Equal(t, byte(100), req[23], "IP octet 4")
|
||||
}
|
||||
|
||||
func TestBuildMapRequest(t *testing.T) {
|
||||
clientIP := netip.MustParseAddr("192.168.1.100")
|
||||
nonce := [12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}
|
||||
req := buildMapRequest(clientIP, nonce, ProtoUDP, 51820, 51820, netip.Addr{}, 3600)
|
||||
|
||||
require.Len(t, req, mapRequestSize)
|
||||
assert.Equal(t, byte(Version), req[0], "version")
|
||||
assert.Equal(t, byte(OpMap), req[1], "opcode")
|
||||
|
||||
// Lifetime at bytes 4-7
|
||||
assert.Equal(t, uint32(3600), (uint32(req[4])<<24)|(uint32(req[5])<<16)|(uint32(req[6])<<8)|uint32(req[7]), "lifetime")
|
||||
|
||||
// Nonce at bytes 24-35
|
||||
assert.Equal(t, nonce[:], req[24:36], "nonce")
|
||||
|
||||
// Protocol at byte 36
|
||||
assert.Equal(t, byte(ProtoUDP), req[36], "protocol")
|
||||
|
||||
// Internal port at bytes 40-41
|
||||
assert.Equal(t, uint16(51820), (uint16(req[40])<<8)|uint16(req[41]), "internal port")
|
||||
|
||||
// External port at bytes 42-43
|
||||
assert.Equal(t, uint16(51820), (uint16(req[42])<<8)|uint16(req[43]), "external port")
|
||||
}
|
||||
|
||||
func TestParseResponse(t *testing.T) {
|
||||
// Construct a valid ANNOUNCE response
|
||||
resp := make([]byte, headerSize)
|
||||
resp[0] = Version
|
||||
resp[1] = OpAnnounce | OpReply
|
||||
// Result code = 0 (success)
|
||||
// Lifetime = 0
|
||||
// Epoch = 12345
|
||||
resp[8] = 0
|
||||
resp[9] = 0
|
||||
resp[10] = 0x30
|
||||
resp[11] = 0x39
|
||||
|
||||
parsed, err := parseResponse(resp)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, uint8(Version), parsed.Version)
|
||||
assert.Equal(t, uint8(OpAnnounce|OpReply), parsed.Opcode)
|
||||
assert.Equal(t, uint8(ResultSuccess), parsed.ResultCode)
|
||||
assert.Equal(t, uint32(12345), parsed.Epoch)
|
||||
}
|
||||
|
||||
func TestParseResponseErrors(t *testing.T) {
|
||||
t.Run("too short", func(t *testing.T) {
|
||||
_, err := parseResponse([]byte{1, 2, 3})
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("wrong version", func(t *testing.T) {
|
||||
resp := make([]byte, headerSize)
|
||||
resp[0] = 1 // Wrong version
|
||||
resp[1] = OpReply
|
||||
_, err := parseResponse(resp)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("missing reply bit", func(t *testing.T) {
|
||||
resp := make([]byte, headerSize)
|
||||
resp[0] = Version
|
||||
resp[1] = OpAnnounce // Missing OpReply bit
|
||||
_, err := parseResponse(resp)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestResultCodeString(t *testing.T) {
|
||||
assert.Equal(t, "SUCCESS", ResultCodeString(ResultSuccess))
|
||||
assert.Equal(t, "NOT_AUTHORIZED", ResultCodeString(ResultNotAuthorized))
|
||||
assert.Equal(t, "ADDRESS_MISMATCH", ResultCodeString(ResultAddressMismatch))
|
||||
assert.Contains(t, ResultCodeString(255), "UNKNOWN")
|
||||
}
|
||||
|
||||
func TestProtocolNumber(t *testing.T) {
|
||||
proto, err := protocolNumber("udp")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, uint8(ProtoUDP), proto)
|
||||
|
||||
proto, err = protocolNumber("tcp")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, uint8(ProtoTCP), proto)
|
||||
|
||||
proto, err = protocolNumber("UDP")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, uint8(ProtoUDP), proto)
|
||||
|
||||
_, err = protocolNumber("icmp")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestClientCreation(t *testing.T) {
|
||||
gateway := netip.MustParseAddr("192.168.1.1").AsSlice()
|
||||
|
||||
client := NewClient(gateway)
|
||||
assert.Equal(t, net.IP(gateway), client.Gateway())
|
||||
assert.Equal(t, defaultTimeout, client.timeout)
|
||||
|
||||
clientWithTimeout := NewClientWithTimeout(gateway, 5*time.Second)
|
||||
assert.Equal(t, 5*time.Second, clientWithTimeout.timeout)
|
||||
}
|
||||
|
||||
func TestNATType(t *testing.T) {
|
||||
n := NewNAT(netip.MustParseAddr("192.168.1.1").AsSlice(), netip.MustParseAddr("192.168.1.100").AsSlice())
|
||||
assert.Equal(t, "PCP", n.Type())
|
||||
}
|
||||
|
||||
// Integration test - skipped unless PCP_TEST_GATEWAY env is set
|
||||
func TestClientIntegration(t *testing.T) {
|
||||
t.Skip("Integration test - run manually with PCP_TEST_GATEWAY=<gateway-ip>")
|
||||
|
||||
gateway := netip.MustParseAddr("10.0.1.1").AsSlice() // Change to your test gateway
|
||||
localIP := netip.MustParseAddr("10.0.1.100").AsSlice() // Change to your local IP
|
||||
|
||||
client := NewClient(gateway)
|
||||
client.SetLocalIP(localIP)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Test ANNOUNCE
|
||||
epoch, err := client.Announce(ctx)
|
||||
require.NoError(t, err)
|
||||
t.Logf("Server epoch: %d", epoch)
|
||||
|
||||
// Test MAP
|
||||
resp, err := client.AddPortMapping(ctx, "udp", 51820, 1*time.Hour)
|
||||
require.NoError(t, err)
|
||||
t.Logf("Mapping: internal=%d external=%d externalIP=%s",
|
||||
resp.InternalPort, resp.ExternalPort, resp.ExternalIP)
|
||||
|
||||
// Cleanup
|
||||
err = client.DeletePortMapping(ctx, "udp", 51820)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
209
client/internal/portforward/pcp/nat.go
Normal file
209
client/internal/portforward/pcp/nat.go
Normal file
@@ -0,0 +1,209 @@
|
||||
package pcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/libp2p/go-nat"
|
||||
"github.com/libp2p/go-netroute"
|
||||
)
|
||||
|
||||
var _ nat.NAT = (*NAT)(nil)
|
||||
|
||||
// NAT implements the go-nat NAT interface using PCP.
|
||||
// Supports dual-stack (IPv4 and IPv6) when available.
|
||||
// All methods are safe for concurrent use.
|
||||
//
|
||||
// TODO: IPv6 pinholes use the local IPv6 address. If the address changes
|
||||
// (e.g., due to SLAAC rotation or network change), the pinhole becomes stale
|
||||
// and needs to be recreated with the new address.
|
||||
type NAT struct {
|
||||
client *Client
|
||||
|
||||
mu sync.RWMutex
|
||||
// client6 is the IPv6 PCP client, nil if IPv6 is unavailable.
|
||||
client6 *Client
|
||||
// localIP6 caches the local IPv6 address used for PCP requests.
|
||||
localIP6 netip.Addr
|
||||
}
|
||||
|
||||
// NewNAT creates a new NAT instance backed by PCP.
|
||||
func NewNAT(gateway, localIP net.IP) *NAT {
|
||||
client := NewClient(gateway)
|
||||
client.SetLocalIP(localIP)
|
||||
return &NAT{
|
||||
client: client,
|
||||
}
|
||||
}
|
||||
|
||||
// Type returns "PCP" as the NAT type.
|
||||
func (n *NAT) Type() string {
|
||||
return "PCP"
|
||||
}
|
||||
|
||||
// GetDeviceAddress returns the gateway IP address.
|
||||
func (n *NAT) GetDeviceAddress() (net.IP, error) {
|
||||
return n.client.Gateway(), nil
|
||||
}
|
||||
|
||||
// GetExternalAddress returns the external IP address.
|
||||
func (n *NAT) GetExternalAddress() (net.IP, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
return n.client.GetExternalAddress(ctx)
|
||||
}
|
||||
|
||||
// GetInternalAddress returns the local IP address used to communicate with the gateway.
|
||||
func (n *NAT) GetInternalAddress() (net.IP, error) {
|
||||
addr, err := n.client.getLocalIP()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return addr.AsSlice(), nil
|
||||
}
|
||||
|
||||
// AddPortMapping creates a port mapping on both IPv4 and IPv6 (if available).
|
||||
func (n *NAT) AddPortMapping(ctx context.Context, protocol string, internalPort int, _ string, timeout time.Duration) (int, error) {
|
||||
resp, err := n.client.AddPortMapping(ctx, protocol, internalPort, timeout)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("add mapping: %w", err)
|
||||
}
|
||||
|
||||
n.mu.RLock()
|
||||
client6 := n.client6
|
||||
localIP6 := n.localIP6
|
||||
n.mu.RUnlock()
|
||||
|
||||
if client6 == nil {
|
||||
return int(resp.ExternalPort), nil
|
||||
}
|
||||
|
||||
if _, err := client6.AddPortMapping(ctx, protocol, internalPort, timeout); err != nil {
|
||||
log.Warnf("IPv6 PCP mapping failed (continuing with IPv4): %v", err)
|
||||
return int(resp.ExternalPort), nil
|
||||
}
|
||||
|
||||
log.Infof("created IPv6 PCP pinhole: %s:%d", localIP6, internalPort)
|
||||
return int(resp.ExternalPort), nil
|
||||
}
|
||||
|
||||
// DeletePortMapping removes a port mapping from both IPv4 and IPv6.
|
||||
func (n *NAT) DeletePortMapping(ctx context.Context, protocol string, internalPort int) error {
|
||||
err := n.client.DeletePortMapping(ctx, protocol, internalPort)
|
||||
|
||||
n.mu.RLock()
|
||||
client6 := n.client6
|
||||
n.mu.RUnlock()
|
||||
|
||||
if client6 != nil {
|
||||
if err6 := client6.DeletePortMapping(ctx, protocol, internalPort); err6 != nil {
|
||||
log.Warnf("IPv6 PCP delete mapping failed: %v", err6)
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete mapping: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CheckServerHealth sends an ANNOUNCE to verify the server is still responsive.
|
||||
// Returns the current epoch and whether the server may have restarted (epoch state loss detected).
|
||||
func (n *NAT) CheckServerHealth(ctx context.Context) (epoch uint32, serverRestarted bool, err error) {
|
||||
epoch, err = n.client.Announce(ctx)
|
||||
if err != nil {
|
||||
return 0, false, fmt.Errorf("announce: %w", err)
|
||||
}
|
||||
return epoch, n.client.EpochStateLost(), nil
|
||||
}
|
||||
|
||||
// DiscoverPCP attempts to discover a PCP-capable gateway.
|
||||
// Returns a NAT interface if PCP is supported, or an error otherwise.
|
||||
// Discovers both IPv4 and IPv6 gateways when available.
|
||||
func DiscoverPCP(ctx context.Context) (nat.NAT, error) {
|
||||
gateway, localIP, err := getDefaultGateway()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get default gateway: %w", err)
|
||||
}
|
||||
|
||||
client := NewClient(gateway)
|
||||
client.SetLocalIP(localIP)
|
||||
if _, err := client.Announce(ctx); err != nil {
|
||||
return nil, fmt.Errorf("PCP announce: %w", err)
|
||||
}
|
||||
|
||||
result := &NAT{client: client}
|
||||
discoverIPv6(ctx, result)
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func discoverIPv6(ctx context.Context, result *NAT) {
|
||||
gateway6, localIP6, err := getDefaultGateway6()
|
||||
if err != nil {
|
||||
log.Debugf("IPv6 gateway discovery failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
client6 := NewClient(gateway6)
|
||||
client6.SetLocalIP(localIP6)
|
||||
if _, err := client6.Announce(ctx); err != nil {
|
||||
log.Debugf("PCP IPv6 announce failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
addr, ok := netip.AddrFromSlice(localIP6)
|
||||
if !ok {
|
||||
log.Debugf("invalid IPv6 local IP: %v", localIP6)
|
||||
return
|
||||
}
|
||||
result.mu.Lock()
|
||||
result.client6 = client6
|
||||
result.localIP6 = addr
|
||||
result.mu.Unlock()
|
||||
log.Debugf("PCP IPv6 gateway discovered: %s (local: %s)", gateway6, localIP6)
|
||||
}
|
||||
|
||||
// getDefaultGateway returns the default IPv4 gateway and local IP using the system routing table.
|
||||
func getDefaultGateway() (gateway net.IP, localIP net.IP, err error) {
|
||||
router, err := netroute.New()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
_, gateway, localIP, err = router.Route(net.IPv4zero)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if gateway == nil {
|
||||
return nil, nil, nat.ErrNoNATFound
|
||||
}
|
||||
|
||||
return gateway, localIP, nil
|
||||
}
|
||||
|
||||
// getDefaultGateway6 returns the default IPv6 gateway IP address using the system routing table.
|
||||
func getDefaultGateway6() (gateway net.IP, localIP net.IP, err error) {
|
||||
router, err := netroute.New()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
_, gateway, localIP, err = router.Route(net.IPv6zero)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if gateway == nil {
|
||||
return nil, nil, nat.ErrNoNATFound
|
||||
}
|
||||
|
||||
return gateway, localIP, nil
|
||||
}
|
||||
225
client/internal/portforward/pcp/protocol.go
Normal file
225
client/internal/portforward/pcp/protocol.go
Normal file
@@ -0,0 +1,225 @@
|
||||
// Package pcp implements the Port Control Protocol (RFC 6887).
|
||||
//
|
||||
// # Implemented Features
|
||||
//
|
||||
// - ANNOUNCE opcode: Discovers PCP server support
|
||||
// - MAP opcode: Creates/deletes port mappings (IPv4 NAT) and firewall pinholes (IPv6)
|
||||
// - Dual-stack: Simultaneous IPv4 and IPv6 support via separate clients
|
||||
// - Nonce validation: Prevents response spoofing
|
||||
// - Epoch tracking: Detects server restarts per Section 8.5
|
||||
// - RFC-compliant retry timing: 3s initial, exponential backoff to 1024s max (Section 8.1.1)
|
||||
//
|
||||
// # Not Implemented
|
||||
//
|
||||
// - PEER opcode: For outbound peer connections (not needed for inbound NAT traversal)
|
||||
// - THIRD_PARTY option: For managing mappings on behalf of other devices
|
||||
// - PREFER_FAILURE option: Requires exact external port or fail (IPv4 NAT only, not needed for IPv6 pinholing)
|
||||
// - FILTER option: To restrict remote peer addresses
|
||||
//
|
||||
// These optional features are omitted because the primary use case is simple
|
||||
// port forwarding for WireGuard, which only requires MAP with default behavior.
|
||||
package pcp
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
const (
|
||||
// Version is the PCP protocol version (RFC 6887).
|
||||
Version = 2
|
||||
|
||||
// Port is the standard PCP server port.
|
||||
Port = 5351
|
||||
|
||||
// DefaultLifetime is the default requested mapping lifetime in seconds.
|
||||
DefaultLifetime = 7200 // 2 hours
|
||||
|
||||
// Header sizes
|
||||
headerSize = 24
|
||||
mapPayloadSize = 36
|
||||
mapRequestSize = headerSize + mapPayloadSize // 60 bytes
|
||||
)
|
||||
|
||||
// Opcodes
|
||||
const (
|
||||
OpAnnounce = 0
|
||||
OpMap = 1
|
||||
OpPeer = 2
|
||||
OpReply = 0x80 // OR'd with opcode in responses
|
||||
)
|
||||
|
||||
// Protocol numbers for MAP requests
|
||||
const (
|
||||
ProtoUDP = 17
|
||||
ProtoTCP = 6
|
||||
)
|
||||
|
||||
// Result codes (RFC 6887 Section 7.4)
|
||||
const (
|
||||
ResultSuccess = 0
|
||||
ResultUnsuppVersion = 1
|
||||
ResultNotAuthorized = 2
|
||||
ResultMalformedRequest = 3
|
||||
ResultUnsuppOpcode = 4
|
||||
ResultUnsuppOption = 5
|
||||
ResultMalformedOption = 6
|
||||
ResultNetworkFailure = 7
|
||||
ResultNoResources = 8
|
||||
ResultUnsuppProtocol = 9
|
||||
ResultUserExQuota = 10
|
||||
ResultCannotProvideExt = 11
|
||||
ResultAddressMismatch = 12
|
||||
ResultExcessiveRemotePeers = 13
|
||||
)
|
||||
|
||||
// ResultCodeString returns a human-readable string for a result code.
|
||||
func ResultCodeString(code uint8) string {
|
||||
switch code {
|
||||
case ResultSuccess:
|
||||
return "SUCCESS"
|
||||
case ResultUnsuppVersion:
|
||||
return "UNSUPP_VERSION"
|
||||
case ResultNotAuthorized:
|
||||
return "NOT_AUTHORIZED"
|
||||
case ResultMalformedRequest:
|
||||
return "MALFORMED_REQUEST"
|
||||
case ResultUnsuppOpcode:
|
||||
return "UNSUPP_OPCODE"
|
||||
case ResultUnsuppOption:
|
||||
return "UNSUPP_OPTION"
|
||||
case ResultMalformedOption:
|
||||
return "MALFORMED_OPTION"
|
||||
case ResultNetworkFailure:
|
||||
return "NETWORK_FAILURE"
|
||||
case ResultNoResources:
|
||||
return "NO_RESOURCES"
|
||||
case ResultUnsuppProtocol:
|
||||
return "UNSUPP_PROTOCOL"
|
||||
case ResultUserExQuota:
|
||||
return "USER_EX_QUOTA"
|
||||
case ResultCannotProvideExt:
|
||||
return "CANNOT_PROVIDE_EXTERNAL"
|
||||
case ResultAddressMismatch:
|
||||
return "ADDRESS_MISMATCH"
|
||||
case ResultExcessiveRemotePeers:
|
||||
return "EXCESSIVE_REMOTE_PEERS"
|
||||
default:
|
||||
return fmt.Sprintf("UNKNOWN(%d)", code)
|
||||
}
|
||||
}
|
||||
|
||||
// Response represents a parsed PCP response header.
|
||||
type Response struct {
|
||||
Version uint8
|
||||
Opcode uint8
|
||||
ResultCode uint8
|
||||
Lifetime uint32
|
||||
Epoch uint32
|
||||
}
|
||||
|
||||
// MapResponse contains the full response to a MAP request.
|
||||
type MapResponse struct {
|
||||
Response
|
||||
Nonce [12]byte
|
||||
Protocol uint8
|
||||
InternalPort uint16
|
||||
ExternalPort uint16
|
||||
ExternalIP netip.Addr
|
||||
}
|
||||
|
||||
// addrTo16 converts an address to its 16-byte IPv4-mapped IPv6 representation.
|
||||
func addrTo16(addr netip.Addr) [16]byte {
|
||||
if addr.Is4() {
|
||||
return netip.AddrFrom4(addr.As4()).As16()
|
||||
}
|
||||
return addr.As16()
|
||||
}
|
||||
|
||||
// addrFrom16 extracts an address from a 16-byte representation, unmapping IPv4.
|
||||
func addrFrom16(b [16]byte) netip.Addr {
|
||||
return netip.AddrFrom16(b).Unmap()
|
||||
}
|
||||
|
||||
// buildAnnounceRequest creates a PCP ANNOUNCE request packet.
|
||||
func buildAnnounceRequest(clientIP netip.Addr) []byte {
|
||||
req := make([]byte, headerSize)
|
||||
req[0] = Version
|
||||
req[1] = OpAnnounce
|
||||
mapped := addrTo16(clientIP)
|
||||
copy(req[8:24], mapped[:])
|
||||
return req
|
||||
}
|
||||
|
||||
// buildMapRequest creates a PCP MAP request packet.
|
||||
func buildMapRequest(clientIP netip.Addr, nonce [12]byte, protocol uint8, internalPort, suggestedExtPort uint16, suggestedExtIP netip.Addr, lifetime uint32) []byte {
|
||||
req := make([]byte, mapRequestSize)
|
||||
|
||||
// Header
|
||||
req[0] = Version
|
||||
req[1] = OpMap
|
||||
binary.BigEndian.PutUint32(req[4:8], lifetime)
|
||||
mapped := addrTo16(clientIP)
|
||||
copy(req[8:24], mapped[:])
|
||||
|
||||
// MAP payload
|
||||
copy(req[24:36], nonce[:])
|
||||
req[36] = protocol
|
||||
binary.BigEndian.PutUint16(req[40:42], internalPort)
|
||||
binary.BigEndian.PutUint16(req[42:44], suggestedExtPort)
|
||||
if suggestedExtIP.IsValid() {
|
||||
extMapped := addrTo16(suggestedExtIP)
|
||||
copy(req[44:60], extMapped[:])
|
||||
}
|
||||
|
||||
return req
|
||||
}
|
||||
|
||||
// parseResponse parses the common PCP response header.
|
||||
func parseResponse(data []byte) (*Response, error) {
|
||||
if len(data) < headerSize {
|
||||
return nil, fmt.Errorf("response too short: %d bytes", len(data))
|
||||
}
|
||||
|
||||
resp := &Response{
|
||||
Version: data[0],
|
||||
Opcode: data[1],
|
||||
ResultCode: data[3], // Byte 2 is reserved, byte 3 is result code (RFC 6887 §7.2)
|
||||
Lifetime: binary.BigEndian.Uint32(data[4:8]),
|
||||
Epoch: binary.BigEndian.Uint32(data[8:12]),
|
||||
}
|
||||
|
||||
if resp.Version != Version {
|
||||
return nil, fmt.Errorf("unsupported PCP version: %d", resp.Version)
|
||||
}
|
||||
|
||||
if resp.Opcode&OpReply == 0 {
|
||||
return nil, fmt.Errorf("response missing reply bit: opcode=0x%02x", resp.Opcode)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// parseMapResponse parses a complete MAP response.
|
||||
func parseMapResponse(data []byte) (*MapResponse, error) {
|
||||
if len(data) < mapRequestSize {
|
||||
return nil, fmt.Errorf("MAP response too short: %d bytes", len(data))
|
||||
}
|
||||
|
||||
resp, err := parseResponse(data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse header: %w", err)
|
||||
}
|
||||
|
||||
mapResp := &MapResponse{
|
||||
Response: *resp,
|
||||
Protocol: data[36],
|
||||
InternalPort: binary.BigEndian.Uint16(data[40:42]),
|
||||
ExternalPort: binary.BigEndian.Uint16(data[42:44]),
|
||||
ExternalIP: addrFrom16([16]byte(data[44:60])),
|
||||
}
|
||||
copy(mapResp.Nonce[:], data[24:36])
|
||||
|
||||
return mapResp, nil
|
||||
}
|
||||
63
client/internal/portforward/state.go
Normal file
63
client/internal/portforward/state.go
Normal file
@@ -0,0 +1,63 @@
|
||||
//go:build !js
|
||||
|
||||
package portforward
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/libp2p/go-nat"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/portforward/pcp"
|
||||
)
|
||||
|
||||
// discoverGateway is the function used for NAT gateway discovery.
|
||||
// It can be replaced in tests to avoid real network operations.
|
||||
// Tries PCP first, then falls back to NAT-PMP/UPnP.
|
||||
var discoverGateway = defaultDiscoverGateway
|
||||
|
||||
func defaultDiscoverGateway(ctx context.Context) (nat.NAT, error) {
|
||||
pcpGateway, err := pcp.DiscoverPCP(ctx)
|
||||
if err == nil {
|
||||
return pcpGateway, nil
|
||||
}
|
||||
log.Debugf("PCP discovery failed: %v, trying NAT-PMP/UPnP", err)
|
||||
|
||||
return nat.DiscoverGateway(ctx)
|
||||
}
|
||||
|
||||
// State is persisted only for crash recovery cleanup
|
||||
type State struct {
|
||||
InternalPort uint16 `json:"internal_port,omitempty"`
|
||||
Protocol string `json:"protocol,omitempty"`
|
||||
}
|
||||
|
||||
func (s *State) Name() string {
|
||||
return "port_forward_state"
|
||||
}
|
||||
|
||||
// Cleanup implements statemanager.CleanableState for crash recovery
|
||||
func (s *State) Cleanup() error {
|
||||
if s.InternalPort == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Infof("cleaning up stale port mapping for port %d", s.InternalPort)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), discoveryTimeout)
|
||||
defer cancel()
|
||||
|
||||
gateway, err := discoverGateway(ctx)
|
||||
if err != nil {
|
||||
// Discovery failure is not an error - gateway may not exist
|
||||
log.Debugf("cleanup: no gateway found: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := gateway.DeletePortMapping(ctx, s.Protocol, int(s.InternalPort)); err != nil {
|
||||
return fmt.Errorf("delete port mapping: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -168,6 +168,7 @@ func (m *DefaultManager) setupAndroidRoutes(config ManagerConfig) {
|
||||
NetworkType: route.IPv4Network,
|
||||
}
|
||||
cr = append(cr, fakeIPRoute)
|
||||
m.notifier.SetFakeIPRoute(fakeIPRoute)
|
||||
}
|
||||
|
||||
m.notifier.SetInitialClientRoutes(cr, routesForComparison)
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
type Notifier struct {
|
||||
initialRoutes []*route.Route
|
||||
currentRoutes []*route.Route
|
||||
fakeIPRoute *route.Route
|
||||
|
||||
listener listener.NetworkChangeListener
|
||||
listenerMux sync.Mutex
|
||||
@@ -31,13 +32,17 @@ func (n *Notifier) SetListener(listener listener.NetworkChangeListener) {
|
||||
n.listener = listener
|
||||
}
|
||||
|
||||
// SetInitialClientRoutes stores the full initial route set (including fake IP blocks)
|
||||
// and a separate comparison set (without fake IP blocks) for diff detection.
|
||||
// SetInitialClientRoutes stores the initial route sets for TUN configuration.
|
||||
func (n *Notifier) SetInitialClientRoutes(initialRoutes []*route.Route, routesForComparison []*route.Route) {
|
||||
n.initialRoutes = filterStatic(initialRoutes)
|
||||
n.currentRoutes = filterStatic(routesForComparison)
|
||||
}
|
||||
|
||||
// SetFakeIPRoute stores the fake IP route to be included in every TUN rebuild.
|
||||
func (n *Notifier) SetFakeIPRoute(r *route.Route) {
|
||||
n.fakeIPRoute = r
|
||||
}
|
||||
|
||||
func (n *Notifier) OnNewRoutes(idMap route.HAMap) {
|
||||
var newRoutes []*route.Route
|
||||
for _, routes := range idMap {
|
||||
@@ -69,7 +74,9 @@ func (n *Notifier) notify() {
|
||||
}
|
||||
|
||||
allRoutes := slices.Clone(n.currentRoutes)
|
||||
allRoutes = append(allRoutes, n.extraInitialRoutes()...)
|
||||
if n.fakeIPRoute != nil {
|
||||
allRoutes = append(allRoutes, n.fakeIPRoute)
|
||||
}
|
||||
|
||||
routeStrings := n.routesToStrings(allRoutes)
|
||||
sort.Strings(routeStrings)
|
||||
@@ -78,23 +85,6 @@ func (n *Notifier) notify() {
|
||||
}(n.listener)
|
||||
}
|
||||
|
||||
// extraInitialRoutes returns initialRoutes whose network prefix is absent
|
||||
// from currentRoutes (e.g. the fake IP block added at setup time).
|
||||
func (n *Notifier) extraInitialRoutes() []*route.Route {
|
||||
currentNets := make(map[netip.Prefix]struct{}, len(n.currentRoutes))
|
||||
for _, r := range n.currentRoutes {
|
||||
currentNets[r.Network] = struct{}{}
|
||||
}
|
||||
|
||||
var extra []*route.Route
|
||||
for _, r := range n.initialRoutes {
|
||||
if _, ok := currentNets[r.Network]; !ok {
|
||||
extra = append(extra, r)
|
||||
}
|
||||
}
|
||||
return extra
|
||||
}
|
||||
|
||||
func filterStatic(routes []*route.Route) []*route.Route {
|
||||
out := make([]*route.Route, 0, len(routes))
|
||||
for _, r := range routes {
|
||||
|
||||
@@ -34,6 +34,10 @@ func (n *Notifier) SetInitialClientRoutes([]*route.Route, []*route.Route) {
|
||||
// iOS doesn't care about initial routes
|
||||
}
|
||||
|
||||
func (n *Notifier) SetFakeIPRoute(*route.Route) {
|
||||
// Not used on iOS
|
||||
}
|
||||
|
||||
func (n *Notifier) OnNewRoutes(route.HAMap) {
|
||||
// Not used on iOS
|
||||
}
|
||||
@@ -53,7 +57,6 @@ func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) {
|
||||
n.currentPrefixes = newNets
|
||||
n.notify()
|
||||
}
|
||||
|
||||
func (n *Notifier) notify() {
|
||||
n.listenerMux.Lock()
|
||||
defer n.listenerMux.Unlock()
|
||||
|
||||
@@ -23,6 +23,10 @@ func (n *Notifier) SetInitialClientRoutes([]*route.Route, []*route.Route) {
|
||||
// Not used on non-mobile platforms
|
||||
}
|
||||
|
||||
func (n *Notifier) SetFakeIPRoute(*route.Route) {
|
||||
// Not used on non-mobile platforms
|
||||
}
|
||||
|
||||
func (n *Notifier) OnNewRoutes(idMap route.HAMap) {
|
||||
// Not used on non-mobile platforms
|
||||
}
|
||||
|
||||
@@ -0,0 +1,9 @@
|
||||
//go:build (dragonfly || freebsd || netbsd || openbsd) && !darwin
|
||||
|
||||
package systemops
|
||||
|
||||
// Non-darwin BSDs don't support the IP_BOUND_IF + scoped default model. They
|
||||
// always fall through to the ref-counter exclusion-route path; these stubs
|
||||
// exist only so systemops_unix.go compiles.
|
||||
func (r *SysOps) setupAdvancedRouting() error { return nil }
|
||||
func (r *SysOps) cleanupAdvancedRouting() error { return nil }
|
||||
236
client/internal/routemanager/systemops/scopedroute_darwin.go
Normal file
236
client/internal/routemanager/systemops/scopedroute_darwin.go
Normal file
@@ -0,0 +1,236 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"syscall"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/net/route"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
||||
)
|
||||
|
||||
// setupAdvancedRouting installs an RTF_IFSCOPE default route per address family
|
||||
// pinned to the current physical egress, so IP_BOUND_IF scoped lookups can
|
||||
// resolve gateway'd destinations while the VPN's split default owns the
|
||||
// unscoped table.
|
||||
func (r *SysOps) setupAdvancedRouting() error {
|
||||
if err := r.flushScopedDefaults(); err != nil {
|
||||
log.Warnf("flush residual scoped defaults: %v", err)
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
installed := 0
|
||||
|
||||
for _, unspec := range []netip.Addr{netip.IPv4Unspecified(), netip.IPv6Unspecified()} {
|
||||
nexthop, err := GetNextHop(unspec)
|
||||
if err != nil {
|
||||
if !errors.Is(err, vars.ErrRouteNotFound) {
|
||||
merr = multierror.Append(merr, fmt.Errorf("get default nexthop for %s: %w", unspec, err))
|
||||
}
|
||||
continue
|
||||
}
|
||||
if nexthop.Intf == nil || !nexthop.IP.IsValid() {
|
||||
log.Debugf("no usable default nexthop for %s (iface=%v gw=%v), skipping scoped default",
|
||||
unspec, nexthop.Intf, nexthop.IP)
|
||||
continue
|
||||
}
|
||||
|
||||
if err := r.addScopedDefault(unspec, nexthop); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("add scoped default via %s on %s: %w",
|
||||
nexthop.IP, nexthop.Intf.Name, err))
|
||||
continue
|
||||
}
|
||||
installed++
|
||||
log.Infof("installed scoped default route via %s on %s for %s",
|
||||
nexthop.IP, nexthop.Intf.Name, afOf(unspec))
|
||||
}
|
||||
|
||||
if installed == 0 && merr != nil {
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
if merr != nil {
|
||||
log.Warnf("advanced routing setup partially succeeded: %v", nberrors.FormatErrorOrNil(merr))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *SysOps) cleanupAdvancedRouting() error {
|
||||
return r.flushScopedDefaults()
|
||||
}
|
||||
|
||||
// flushScopedDefaults removes any scoped default routes tagged with routeProtoFlag.
|
||||
// Safe to call at startup to clear residual entries from a prior session.
|
||||
func (r *SysOps) flushScopedDefaults() error {
|
||||
rib, err := retryFetchRIB()
|
||||
if err != nil {
|
||||
return fmt.Errorf("fetch routing table: %w", err)
|
||||
}
|
||||
|
||||
msgs, err := route.ParseRIB(route.RIBTypeRoute, rib)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse routing table: %w", err)
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
removed := 0
|
||||
|
||||
for _, msg := range msgs {
|
||||
rtMsg, ok := msg.(*route.RouteMessage)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if rtMsg.Flags&routeProtoFlag == 0 {
|
||||
continue
|
||||
}
|
||||
if rtMsg.Flags&unix.RTF_IFSCOPE == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
info, err := MsgToRoute(rtMsg)
|
||||
if err != nil {
|
||||
log.Debugf("skip scoped flush: %v", err)
|
||||
continue
|
||||
}
|
||||
if !info.Dst.IsValid() || info.Dst.Bits() != 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := r.deleteScopedRoute(rtMsg); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("delete scoped default %s on index %d: %w",
|
||||
info.Dst, rtMsg.Index, err))
|
||||
continue
|
||||
}
|
||||
removed++
|
||||
log.Debugf("flushed residual scoped default %s on index %d", info.Dst, rtMsg.Index)
|
||||
}
|
||||
|
||||
if removed > 0 {
|
||||
log.Infof("flushed %d residual scoped default route(s)", removed)
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *SysOps) addScopedDefault(unspec netip.Addr, nexthop Nexthop) error {
|
||||
return r.scopedRouteSocket(unix.RTM_ADD, unspec, nexthop)
|
||||
}
|
||||
|
||||
func (r *SysOps) deleteScopedRoute(rtMsg *route.RouteMessage) error {
|
||||
del := &route.RouteMessage{
|
||||
Type: unix.RTM_DELETE,
|
||||
Flags: rtMsg.Flags,
|
||||
Version: unix.RTM_VERSION,
|
||||
Seq: r.getSeq(),
|
||||
Index: rtMsg.Index,
|
||||
Addrs: rtMsg.Addrs,
|
||||
}
|
||||
return r.writeRouteMessage(del)
|
||||
}
|
||||
|
||||
func (r *SysOps) scopedRouteSocket(action int, unspec netip.Addr, nexthop Nexthop) error {
|
||||
flags := unix.RTF_UP | unix.RTF_STATIC | unix.RTF_GATEWAY | unix.RTF_IFSCOPE | routeProtoFlag
|
||||
|
||||
msg := &route.RouteMessage{
|
||||
Type: action,
|
||||
Flags: flags,
|
||||
Version: unix.RTM_VERSION,
|
||||
Seq: r.getSeq(),
|
||||
Index: nexthop.Intf.Index,
|
||||
}
|
||||
|
||||
const numAddrs = unix.RTAX_NETMASK + 1
|
||||
addrs := make([]route.Addr, numAddrs)
|
||||
|
||||
dst, err := addrToRouteAddr(unspec)
|
||||
if err != nil {
|
||||
return fmt.Errorf("build destination: %w", err)
|
||||
}
|
||||
mask, err := prefixToRouteNetmask(netip.PrefixFrom(unspec, 0))
|
||||
if err != nil {
|
||||
return fmt.Errorf("build netmask: %w", err)
|
||||
}
|
||||
gw, err := addrToRouteAddr(nexthop.IP.Unmap())
|
||||
if err != nil {
|
||||
return fmt.Errorf("build gateway: %w", err)
|
||||
}
|
||||
addrs[unix.RTAX_DST] = dst
|
||||
addrs[unix.RTAX_NETMASK] = mask
|
||||
addrs[unix.RTAX_GATEWAY] = gw
|
||||
msg.Addrs = addrs
|
||||
|
||||
return r.writeRouteMessage(msg)
|
||||
}
|
||||
|
||||
func (r *SysOps) writeRouteMessage(msg *route.RouteMessage) error {
|
||||
expBackOff := backoff.NewExponentialBackOff()
|
||||
expBackOff.InitialInterval = 50 * time.Millisecond
|
||||
expBackOff.MaxInterval = 500 * time.Millisecond
|
||||
expBackOff.MaxElapsedTime = 1 * time.Second
|
||||
|
||||
return backoff.Retry(func() error { return r.routeMessageRoundtrip(msg) }, expBackOff)
|
||||
}
|
||||
|
||||
func (r *SysOps) routeMessageRoundtrip(msg *route.RouteMessage) error {
|
||||
fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open routing socket: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := unix.Close(fd); err != nil && !errors.Is(err, unix.EBADF) {
|
||||
log.Warnf("close routing socket: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
bytes, err := msg.Marshal()
|
||||
if err != nil {
|
||||
return backoff.Permanent(fmt.Errorf("marshal: %w", err))
|
||||
}
|
||||
|
||||
if err := writeRoute(fd, bytes); err != nil {
|
||||
return err
|
||||
}
|
||||
return readRouteResponse(fd)
|
||||
}
|
||||
|
||||
func writeRoute(fd int, bytes []byte) error {
|
||||
if _, err := unix.Write(fd, bytes); err != nil {
|
||||
if errors.Is(err, unix.ENOBUFS) || errors.Is(err, unix.EAGAIN) {
|
||||
return fmt.Errorf("write: %w", err)
|
||||
}
|
||||
return backoff.Permanent(fmt.Errorf("write: %w", err))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func readRouteResponse(fd int) error {
|
||||
resp := make([]byte, 2048)
|
||||
n, err := unix.Read(fd, resp)
|
||||
if err != nil {
|
||||
return backoff.Permanent(fmt.Errorf("read: %w", err))
|
||||
}
|
||||
if n < int(unsafe.Sizeof(unix.RtMsghdr{})) {
|
||||
return nil
|
||||
}
|
||||
hdr := (*unix.RtMsghdr)(unsafe.Pointer(&resp[0]))
|
||||
if hdr.Errno != 0 {
|
||||
return backoff.Permanent(fmt.Errorf("kernel errno %d", hdr.Errno))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func afOf(a netip.Addr) string {
|
||||
if a.Is4() {
|
||||
return "IPv4"
|
||||
}
|
||||
return "IPv6"
|
||||
}
|
||||
60
client/internal/routemanager/systemops/systemops_darwin.go
Normal file
60
client/internal/routemanager/systemops/systemops_darwin.go
Normal file
@@ -0,0 +1,60 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"github.com/libp2p/go-netroute"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
func init() {
|
||||
nbnet.GetBestInterfaceFunc = GetBestInterface
|
||||
}
|
||||
|
||||
// GetBestInterface finds the best interface for reaching a destination,
|
||||
// excluding the VPN interface to avoid routing loops.
|
||||
func GetBestInterface(dest netip.Addr, vpnIntf string) (*net.Interface, error) {
|
||||
var skipInterfaceIndex int
|
||||
if vpnIntf != "" {
|
||||
if iface, err := net.InterfaceByName(vpnIntf); err == nil {
|
||||
skipInterfaceIndex = iface.Index
|
||||
} else {
|
||||
log.Warnf("get VPN interface %s: %v", vpnIntf, err)
|
||||
}
|
||||
}
|
||||
|
||||
r, err := netroute.New()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new netroute: %w", err)
|
||||
}
|
||||
|
||||
router, ok := r.(interface {
|
||||
RouteExcluding(dst net.IP, skipIfaceIndex int) (iface *net.Interface, gateway, preferredSrc net.IP, err error)
|
||||
})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("router does not support RouteExcluding")
|
||||
}
|
||||
|
||||
intf, gateway, preferredSrc, err := router.RouteExcluding(dest.AsSlice(), skipInterfaceIndex)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("find route for %s: %w", dest, err)
|
||||
}
|
||||
|
||||
if intf.Flags&net.FlagLoopback != 0 && !dest.IsLoopback() {
|
||||
return nil, fmt.Errorf("route points to loopback interface")
|
||||
}
|
||||
if intf.Flags&net.FlagUp == 0 {
|
||||
return nil, fmt.Errorf("interface %s is down", intf.Name)
|
||||
}
|
||||
|
||||
log.Debugf("route lookup for %s: selected interface %s (index %d), gateway %v, src %v",
|
||||
dest, intf.Name, intf.Index, gateway, preferredSrc)
|
||||
|
||||
return intf, nil
|
||||
}
|
||||
@@ -41,10 +41,19 @@ func init() {
|
||||
}
|
||||
|
||||
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) error {
|
||||
if advancedRouting {
|
||||
return r.setupAdvancedRouting()
|
||||
}
|
||||
|
||||
log.Infof("Using legacy routing setup with ref counters")
|
||||
return r.setupRefCounter(initAddresses, stateManager)
|
||||
}
|
||||
|
||||
func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRouting bool) error {
|
||||
if advancedRouting {
|
||||
return r.cleanupAdvancedRouting()
|
||||
}
|
||||
|
||||
return r.cleanupRefCounter(stateManager)
|
||||
}
|
||||
|
||||
|
||||
@@ -161,7 +161,11 @@ func (c *Client) Run(fd int32, interfaceName string, envList *EnvList) error {
|
||||
cfg.WgIface = interfaceName
|
||||
|
||||
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||
return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, c.stateFile)
|
||||
hostDNS := []netip.AddrPort{
|
||||
netip.MustParseAddrPort("9.9.9.9:53"),
|
||||
netip.MustParseAddrPort("149.112.112.112:53"),
|
||||
}
|
||||
return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, hostDNS, c.stateFile)
|
||||
}
|
||||
|
||||
// Stop the internal client and free the resources
|
||||
|
||||
5
client/net/dialer_init_darwin.go
Normal file
5
client/net/dialer_init_darwin.go
Normal file
@@ -0,0 +1,5 @@
|
||||
package net
|
||||
|
||||
func (d *Dialer) init() {
|
||||
d.Dialer.Control = applyBoundIfToSocket
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build !linux && !windows
|
||||
//go:build !linux && !windows && !darwin
|
||||
|
||||
package net
|
||||
|
||||
|
||||
67
client/net/env_darwin.go
Normal file
67
client/net/env_darwin.go
Normal file
@@ -0,0 +1,67 @@
|
||||
//go:build darwin
|
||||
|
||||
package net
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
)
|
||||
|
||||
var (
|
||||
vpnInterfaceName string
|
||||
vpnInitMutex sync.RWMutex
|
||||
|
||||
advancedRoutingSupported bool
|
||||
)
|
||||
|
||||
func Init() {
|
||||
advancedRoutingSupported = checkAdvancedRoutingSupport()
|
||||
}
|
||||
|
||||
func checkAdvancedRoutingSupport() bool {
|
||||
var err error
|
||||
var legacyRouting bool
|
||||
if val := os.Getenv(envUseLegacyRouting); val != "" {
|
||||
legacyRouting, err = strconv.ParseBool(val)
|
||||
if err != nil {
|
||||
log.Warnf("failed to parse %s: %v", envUseLegacyRouting, err)
|
||||
}
|
||||
}
|
||||
|
||||
if legacyRouting || netstack.IsEnabled() {
|
||||
log.Info("advanced routing has been requested to be disabled")
|
||||
return false
|
||||
}
|
||||
|
||||
log.Info("system supports advanced routing")
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// AdvancedRouting reports whether routing loops can be avoided without using exclusion routes
|
||||
func AdvancedRouting() bool {
|
||||
return advancedRoutingSupported
|
||||
}
|
||||
|
||||
// GetVPNInterfaceName returns the stored VPN interface name
|
||||
func GetVPNInterfaceName() string {
|
||||
vpnInitMutex.RLock()
|
||||
defer vpnInitMutex.RUnlock()
|
||||
return vpnInterfaceName
|
||||
}
|
||||
|
||||
// SetVPNInterfaceName sets the VPN interface name for lazy initialization
|
||||
func SetVPNInterfaceName(name string) {
|
||||
vpnInitMutex.Lock()
|
||||
defer vpnInitMutex.Unlock()
|
||||
vpnInterfaceName = name
|
||||
|
||||
if name != "" {
|
||||
log.Infof("VPN interface name set to %s for route exclusion", name)
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build !linux && !windows && !android
|
||||
//go:build !linux && !windows && !android && !darwin
|
||||
|
||||
package net
|
||||
|
||||
|
||||
5
client/net/listener_init_darwin.go
Normal file
5
client/net/listener_init_darwin.go
Normal file
@@ -0,0 +1,5 @@
|
||||
package net
|
||||
|
||||
func (l *ListenerConfig) init() {
|
||||
l.ListenConfig.Control = applyBoundIfToSocket
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build !linux && !windows
|
||||
//go:build !linux && !windows && !darwin
|
||||
|
||||
package net
|
||||
|
||||
|
||||
212
client/net/net_darwin.go
Normal file
212
client/net/net_darwin.go
Normal file
@@ -0,0 +1,212 @@
|
||||
package net
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
const (
|
||||
// IP_BOUND_IF binds a socket to a specific interface for egress.
|
||||
IP_BOUND_IF = 25
|
||||
// IPV6_BOUND_IF is the IPPROTO_IPV6 equivalent.
|
||||
IPV6_BOUND_IF = 125
|
||||
)
|
||||
|
||||
// GetBestInterfaceFunc is set at runtime by the routemanager to avoid an import cycle.
|
||||
var GetBestInterfaceFunc func(dest netip.Addr, vpnIntf string) (*net.Interface, error)
|
||||
|
||||
func parseDestinationAddress(network, address string) (netip.Addr, error) {
|
||||
if address == "" {
|
||||
if strings.HasSuffix(network, "6") {
|
||||
return netip.IPv6Unspecified(), nil
|
||||
}
|
||||
return netip.IPv4Unspecified(), nil
|
||||
}
|
||||
|
||||
if addrPort, err := netip.ParseAddrPort(address); err == nil {
|
||||
return addrPort.Addr(), nil
|
||||
}
|
||||
if dest, err := netip.ParseAddr(address); err == nil {
|
||||
return dest, nil
|
||||
}
|
||||
|
||||
host, _, err := net.SplitHostPort(address)
|
||||
if err != nil {
|
||||
host = address
|
||||
}
|
||||
if host == "" {
|
||||
if strings.HasSuffix(network, "6") {
|
||||
return netip.IPv6Unspecified(), nil
|
||||
}
|
||||
return netip.IPv4Unspecified(), nil
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
ips, err := net.DefaultResolver.LookupIPAddr(ctx, host)
|
||||
if err != nil || len(ips) == 0 {
|
||||
return netip.Addr{}, fmt.Errorf("resolve destination %s: %w", host, err)
|
||||
}
|
||||
|
||||
dest, ok := netip.AddrFromSlice(ips[0].IP)
|
||||
if !ok {
|
||||
return netip.Addr{}, fmt.Errorf("convert IP %v to netip.Addr", ips[0].IP)
|
||||
}
|
||||
if ips[0].Zone != "" {
|
||||
dest = dest.WithZone(ips[0].Zone)
|
||||
}
|
||||
return dest, nil
|
||||
}
|
||||
|
||||
func getInterfaceFromZone(zone string) *net.Interface {
|
||||
if zone == "" {
|
||||
return nil
|
||||
}
|
||||
if iface, err := net.InterfaceByName(zone); err == nil {
|
||||
return iface
|
||||
}
|
||||
if idx, err := strconv.Atoi(zone); err == nil {
|
||||
if iface, err := net.InterfaceByIndex(idx); err == nil {
|
||||
return iface
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type interfaceSelection struct {
|
||||
iface4 *net.Interface
|
||||
iface6 *net.Interface
|
||||
}
|
||||
|
||||
func selectInterfaceForUnspecified() (*interfaceSelection, error) {
|
||||
if GetBestInterfaceFunc == nil {
|
||||
return nil, errors.New("GetBestInterfaceFunc not initialized")
|
||||
}
|
||||
|
||||
vpn := GetVPNInterfaceName()
|
||||
var result interfaceSelection
|
||||
|
||||
if iface4, err := GetBestInterfaceFunc(netip.IPv4Unspecified(), vpn); err == nil {
|
||||
result.iface4 = iface4
|
||||
} else {
|
||||
log.Debugf("no IPv4 default route found: %v", err)
|
||||
}
|
||||
if iface6, err := GetBestInterfaceFunc(netip.IPv6Unspecified(), vpn); err == nil {
|
||||
result.iface6 = iface6
|
||||
} else {
|
||||
log.Debugf("no IPv6 default route found: %v", err)
|
||||
}
|
||||
|
||||
if result.iface4 == nil && result.iface6 == nil {
|
||||
return nil, errors.New("no default routes found")
|
||||
}
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
func selectInterface(dest netip.Addr) (*interfaceSelection, error) {
|
||||
if zone := dest.Zone(); zone != "" {
|
||||
if iface := getInterfaceFromZone(zone); iface != nil {
|
||||
if dest.Is6() {
|
||||
return &interfaceSelection{iface6: iface}, nil
|
||||
}
|
||||
return &interfaceSelection{iface4: iface}, nil
|
||||
}
|
||||
}
|
||||
|
||||
if dest.IsUnspecified() {
|
||||
return selectInterfaceForUnspecified()
|
||||
}
|
||||
|
||||
if GetBestInterfaceFunc == nil {
|
||||
return nil, errors.New("GetBestInterfaceFunc not initialized")
|
||||
}
|
||||
|
||||
iface, err := GetBestInterfaceFunc(dest, GetVPNInterfaceName())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("find route for %s: %w", dest, err)
|
||||
}
|
||||
if dest.Is6() {
|
||||
return &interfaceSelection{iface6: iface}, nil
|
||||
}
|
||||
return &interfaceSelection{iface4: iface}, nil
|
||||
}
|
||||
|
||||
func setIPv4BoundIf(fd uintptr, iface *net.Interface) error {
|
||||
if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, IP_BOUND_IF, iface.Index); err != nil {
|
||||
return fmt.Errorf("set IP_BOUND_IF: %w (interface: %s, index: %d)", err, iface.Name, iface.Index)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func setIPv6BoundIf(fd uintptr, iface *net.Interface) error {
|
||||
if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, IPV6_BOUND_IF, iface.Index); err != nil {
|
||||
return fmt.Errorf("set IPV6_BOUND_IF: %w (interface: %s, index: %d)", err, iface.Name, iface.Index)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func setBoundIf(fd uintptr, network string, selection *interfaceSelection, address string) error {
|
||||
switch {
|
||||
case strings.HasSuffix(network, "4"):
|
||||
if selection.iface4 == nil {
|
||||
return nil
|
||||
}
|
||||
if err := setIPv4BoundIf(fd, selection.iface4); err != nil {
|
||||
return err
|
||||
}
|
||||
log.Debugf("set IP_BOUND_IF=%d on %s for %s to %s", selection.iface4.Index, selection.iface4.Name, network, address)
|
||||
return nil
|
||||
case strings.HasSuffix(network, "6"):
|
||||
if selection.iface6 == nil {
|
||||
return nil
|
||||
}
|
||||
if err := setIPv6BoundIf(fd, selection.iface6); err != nil {
|
||||
return err
|
||||
}
|
||||
log.Debugf("set IPV6_BOUND_IF=%d on %s for %s to %s", selection.iface6.Index, selection.iface6.Name, network, address)
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("unexpected network type: %s", network)
|
||||
}
|
||||
|
||||
// applyBoundIfToSocket binds the socket to the physical egress interface for the
|
||||
// destination, bypassing the kernel routing decision that would otherwise send
|
||||
// the traffic through the VPN utun and loop forever.
|
||||
func applyBoundIfToSocket(network, address string, c syscall.RawConn) error {
|
||||
if !AdvancedRouting() {
|
||||
return nil
|
||||
}
|
||||
|
||||
dest, err := parseDestinationAddress(network, address)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dest = dest.Unmap()
|
||||
if !dest.IsValid() {
|
||||
return fmt.Errorf("invalid destination address for %s", address)
|
||||
}
|
||||
|
||||
selection, err := selectInterface(dest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var controlErr error
|
||||
if err := c.Control(func(fd uintptr) {
|
||||
controlErr = setBoundIf(fd, network, selection, address)
|
||||
}); err != nil {
|
||||
return fmt.Errorf("control: %w", err)
|
||||
}
|
||||
return controlErr
|
||||
}
|
||||
@@ -4979,6 +4979,7 @@ type GetFeaturesResponse struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
DisableProfiles bool `protobuf:"varint,1,opt,name=disable_profiles,json=disableProfiles,proto3" json:"disable_profiles,omitempty"`
|
||||
DisableUpdateSettings bool `protobuf:"varint,2,opt,name=disable_update_settings,json=disableUpdateSettings,proto3" json:"disable_update_settings,omitempty"`
|
||||
DisableNetworks bool `protobuf:"varint,3,opt,name=disable_networks,json=disableNetworks,proto3" json:"disable_networks,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
@@ -5027,6 +5028,13 @@ func (x *GetFeaturesResponse) GetDisableUpdateSettings() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (x *GetFeaturesResponse) GetDisableNetworks() bool {
|
||||
if x != nil {
|
||||
return x.DisableNetworks
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type TriggerUpdateRequest struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
@@ -6472,10 +6480,11 @@ const file_daemon_proto_rawDesc = "" +
|
||||
"\f_profileNameB\v\n" +
|
||||
"\t_username\"\x10\n" +
|
||||
"\x0eLogoutResponse\"\x14\n" +
|
||||
"\x12GetFeaturesRequest\"x\n" +
|
||||
"\x12GetFeaturesRequest\"\xa3\x01\n" +
|
||||
"\x13GetFeaturesResponse\x12)\n" +
|
||||
"\x10disable_profiles\x18\x01 \x01(\bR\x0fdisableProfiles\x126\n" +
|
||||
"\x17disable_update_settings\x18\x02 \x01(\bR\x15disableUpdateSettings\"\x16\n" +
|
||||
"\x17disable_update_settings\x18\x02 \x01(\bR\x15disableUpdateSettings\x12)\n" +
|
||||
"\x10disable_networks\x18\x03 \x01(\bR\x0fdisableNetworks\"\x16\n" +
|
||||
"\x14TriggerUpdateRequest\"M\n" +
|
||||
"\x15TriggerUpdateResponse\x12\x18\n" +
|
||||
"\asuccess\x18\x01 \x01(\bR\asuccess\x12\x1a\n" +
|
||||
|
||||
@@ -727,6 +727,7 @@ message GetFeaturesRequest{}
|
||||
message GetFeaturesResponse{
|
||||
bool disable_profiles = 1;
|
||||
bool disable_update_settings = 2;
|
||||
bool disable_networks = 3;
|
||||
}
|
||||
|
||||
message TriggerUpdateRequest {}
|
||||
|
||||
@@ -9,6 +9,8 @@ import (
|
||||
"strings"
|
||||
|
||||
"golang.org/x/exp/maps"
|
||||
"google.golang.org/grpc/codes"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
@@ -27,6 +29,10 @@ func (s *Server) ListNetworks(context.Context, *proto.ListNetworksRequest) (*pro
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.networksDisabled {
|
||||
return nil, gstatus.Errorf(codes.Unavailable, errNetworksDisabled)
|
||||
}
|
||||
|
||||
if s.connectClient == nil {
|
||||
return nil, fmt.Errorf("not connected")
|
||||
}
|
||||
@@ -118,6 +124,10 @@ func (s *Server) SelectNetworks(_ context.Context, req *proto.SelectNetworksRequ
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.networksDisabled {
|
||||
return nil, gstatus.Errorf(codes.Unavailable, errNetworksDisabled)
|
||||
}
|
||||
|
||||
if s.connectClient == nil {
|
||||
return nil, fmt.Errorf("not connected")
|
||||
}
|
||||
@@ -164,6 +174,10 @@ func (s *Server) DeselectNetworks(_ context.Context, req *proto.SelectNetworksRe
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.networksDisabled {
|
||||
return nil, gstatus.Errorf(codes.Unavailable, errNetworksDisabled)
|
||||
}
|
||||
|
||||
if s.connectClient == nil {
|
||||
return nil, fmt.Errorf("not connected")
|
||||
}
|
||||
|
||||
@@ -53,6 +53,7 @@ const (
|
||||
errRestoreResidualState = "failed to restore residual state: %v"
|
||||
errProfilesDisabled = "profiles are disabled, you cannot use this feature without profiles enabled"
|
||||
errUpdateSettingsDisabled = "update settings are disabled, you cannot use this feature without update settings enabled"
|
||||
errNetworksDisabled = "network selection is disabled by the administrator"
|
||||
)
|
||||
|
||||
var ErrServiceNotUp = errors.New("service is not up")
|
||||
@@ -88,6 +89,7 @@ type Server struct {
|
||||
profileManager *profilemanager.ServiceManager
|
||||
profilesDisabled bool
|
||||
updateSettingsDisabled bool
|
||||
networksDisabled bool
|
||||
|
||||
sleepHandler *sleephandler.SleepHandler
|
||||
|
||||
@@ -104,7 +106,7 @@ type oauthAuthFlow struct {
|
||||
}
|
||||
|
||||
// New server instance constructor.
|
||||
func New(ctx context.Context, logFile string, configFile string, profilesDisabled bool, updateSettingsDisabled bool) *Server {
|
||||
func New(ctx context.Context, logFile string, configFile string, profilesDisabled bool, updateSettingsDisabled bool, networksDisabled bool) *Server {
|
||||
s := &Server{
|
||||
rootCtx: ctx,
|
||||
logFile: logFile,
|
||||
@@ -113,6 +115,7 @@ func New(ctx context.Context, logFile string, configFile string, profilesDisable
|
||||
profileManager: profilemanager.NewServiceManager(configFile),
|
||||
profilesDisabled: profilesDisabled,
|
||||
updateSettingsDisabled: updateSettingsDisabled,
|
||||
networksDisabled: networksDisabled,
|
||||
jwtCache: newJWTCache(),
|
||||
}
|
||||
agent := &serverAgent{s}
|
||||
@@ -1359,6 +1362,10 @@ func (s *Server) ExposeService(req *proto.ExposeServiceRequest, srv proto.Daemon
|
||||
return gstatus.Errorf(codes.FailedPrecondition, "engine not initialized")
|
||||
}
|
||||
|
||||
if engine.IsBlockInbound() {
|
||||
return gstatus.Errorf(codes.FailedPrecondition, "expose requires inbound connections but 'block inbound' is enabled, disable it first")
|
||||
}
|
||||
|
||||
mgr := engine.GetExposeManager()
|
||||
if mgr == nil {
|
||||
return gstatus.Errorf(codes.Internal, "expose manager not available")
|
||||
@@ -1624,6 +1631,7 @@ func (s *Server) GetFeatures(ctx context.Context, msg *proto.GetFeaturesRequest)
|
||||
features := &proto.GetFeaturesResponse{
|
||||
DisableProfiles: s.checkProfilesDisabled(),
|
||||
DisableUpdateSettings: s.checkUpdateSettingsDisabled(),
|
||||
DisableNetworks: s.networksDisabled,
|
||||
}
|
||||
|
||||
return features, nil
|
||||
|
||||
@@ -36,6 +36,7 @@ import (
|
||||
daemonProto "github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
@@ -103,7 +104,7 @@ func TestConnectWithRetryRuns(t *testing.T) {
|
||||
t.Fatalf("failed to set active profile state: %v", err)
|
||||
}
|
||||
|
||||
s := New(ctx, "debug", "", false, false)
|
||||
s := New(ctx, "debug", "", false, false, false)
|
||||
|
||||
s.config = config
|
||||
|
||||
@@ -164,7 +165,7 @@ func TestServer_Up(t *testing.T) {
|
||||
t.Fatalf("failed to set active profile state: %v", err)
|
||||
}
|
||||
|
||||
s := New(ctx, "console", "", false, false)
|
||||
s := New(ctx, "console", "", false, false, false)
|
||||
err = s.Start()
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -234,7 +235,7 @@ func TestServer_SubcribeEvents(t *testing.T) {
|
||||
t.Fatalf("failed to set active profile state: %v", err)
|
||||
}
|
||||
|
||||
s := New(ctx, "console", "", false, false)
|
||||
s := New(ctx, "console", "", false, false, false)
|
||||
|
||||
err = s.Start()
|
||||
require.NoError(t, err)
|
||||
@@ -309,7 +310,12 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
|
||||
|
||||
jobManager := job.NewJobManager(nil, store, peersManager)
|
||||
|
||||
ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, settingsManagerMock, eventStore)
|
||||
cacheStore, err := nbcache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, settingsManagerMock, eventStore, cacheStore)
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
require.NoError(t, err)
|
||||
@@ -320,7 +326,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
|
||||
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
|
||||
peersUpdateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
networkMapController := controller.NewController(context.Background(), store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
|
||||
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
||||
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false, cacheStore)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
@@ -53,7 +53,7 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
s := New(ctx, "console", "", false, false)
|
||||
s := New(ctx, "console", "", false, false, false)
|
||||
|
||||
rosenpassEnabled := true
|
||||
rosenpassPermissive := true
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/ssh/config"
|
||||
)
|
||||
|
||||
// registerStates registers all states that need crash recovery cleanup.
|
||||
func registerStates(mgr *statemanager.Manager) {
|
||||
mgr.RegisterState(&dns.ShutdownState{})
|
||||
mgr.RegisterState(&systemops.ShutdownState{})
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/ssh/config"
|
||||
)
|
||||
|
||||
// registerStates registers all states that need crash recovery cleanup.
|
||||
func registerStates(mgr *statemanager.Manager) {
|
||||
mgr.RegisterState(&dns.ShutdownState{})
|
||||
mgr.RegisterState(&systemops.ShutdownState{})
|
||||
|
||||
@@ -141,7 +141,7 @@ func (p *SSHProxy) runProxySSHServer(jwtToken string) error {
|
||||
|
||||
func (p *SSHProxy) handleSSHSession(session ssh.Session) {
|
||||
ptyReq, winCh, isPty := session.Pty()
|
||||
hasCommand := len(session.Command()) > 0
|
||||
hasCommand := session.RawCommand() != ""
|
||||
|
||||
sshClient, err := p.getOrCreateBackendClient(session.Context(), session.User())
|
||||
if err != nil {
|
||||
@@ -180,7 +180,7 @@ func (p *SSHProxy) handleSSHSession(session ssh.Session) {
|
||||
}
|
||||
|
||||
if hasCommand {
|
||||
if err := serverSession.Run(strings.Join(session.Command(), " ")); err != nil {
|
||||
if err := serverSession.Run(session.RawCommand()); err != nil {
|
||||
log.Debugf("run command: %v", err)
|
||||
p.handleProxyExitCode(session, err)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
@@ -245,6 +246,191 @@ func TestSSHProxy_Connect(t *testing.T) {
|
||||
cancel()
|
||||
}
|
||||
|
||||
// TestSSHProxy_CommandQuoting verifies that the proxy preserves shell quoting
|
||||
// when forwarding commands to the backend. This is critical for tools like
|
||||
// Ansible that send commands such as:
|
||||
//
|
||||
// /bin/sh -c '( umask 77 && mkdir -p ... ) && sleep 0'
|
||||
//
|
||||
// The single quotes must be preserved so the backend shell receives the
|
||||
// subshell expression as a single argument to -c.
|
||||
func TestSSHProxy_CommandQuoting(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping integration test in short mode")
|
||||
}
|
||||
|
||||
sshClient, cleanup := setupProxySSHClient(t)
|
||||
defer cleanup()
|
||||
|
||||
// These commands simulate what the SSH protocol delivers as exec payloads.
|
||||
// When a user types: ssh host '/bin/sh -c "( echo hello )"'
|
||||
// the local shell strips the outer single quotes, and the SSH exec request
|
||||
// contains the raw string: /bin/sh -c "( echo hello )"
|
||||
//
|
||||
// The proxy must forward this string verbatim. Using session.Command()
|
||||
// (shlex.Split + strings.Join) strips the inner double quotes, breaking
|
||||
// the command on the backend.
|
||||
tests := []struct {
|
||||
name string
|
||||
command string
|
||||
expect string
|
||||
}{
|
||||
{
|
||||
name: "subshell_in_double_quotes",
|
||||
command: `/bin/sh -c "( echo from-subshell ) && echo outer"`,
|
||||
expect: "from-subshell\nouter\n",
|
||||
},
|
||||
{
|
||||
name: "printf_with_special_chars",
|
||||
command: `/bin/sh -c "printf '%s\n' 'hello world'"`,
|
||||
expect: "hello world\n",
|
||||
},
|
||||
{
|
||||
name: "nested_command_substitution",
|
||||
command: `/bin/sh -c "echo $(echo nested)"`,
|
||||
expect: "nested\n",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
session, err := sshClient.NewSession()
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = session.Close() }()
|
||||
|
||||
var stderrBuf bytes.Buffer
|
||||
session.Stderr = &stderrBuf
|
||||
|
||||
outputCh := make(chan []byte, 1)
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
output, err := session.Output(tc.command)
|
||||
outputCh <- output
|
||||
errCh <- err
|
||||
}()
|
||||
|
||||
select {
|
||||
case output := <-outputCh:
|
||||
err := <-errCh
|
||||
if stderrBuf.Len() > 0 {
|
||||
t.Logf("stderr: %s", stderrBuf.String())
|
||||
}
|
||||
require.NoError(t, err, "command should succeed: %s", tc.command)
|
||||
assert.Equal(t, tc.expect, string(output), "output mismatch for: %s", tc.command)
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatalf("command timed out: %s", tc.command)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// setupProxySSHClient creates a full proxy test environment and returns
|
||||
// an SSH client connected through the proxy to a backend NetBird SSH server.
|
||||
func setupProxySSHClient(t *testing.T) (*cryptossh.Client, func()) {
|
||||
t.Helper()
|
||||
|
||||
const (
|
||||
issuer = "https://test-issuer.example.com"
|
||||
audience = "test-audience"
|
||||
)
|
||||
|
||||
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
|
||||
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
hostPubKey, err := nbssh.GeneratePublicKey(hostKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
serverConfig := &server.Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: &server.JWTConfig{
|
||||
Issuer: issuer,
|
||||
Audiences: []string{audience},
|
||||
KeysLocation: jwksURL,
|
||||
},
|
||||
}
|
||||
sshServer := server.New(serverConfig)
|
||||
sshServer.SetAllowRootLogin(true)
|
||||
|
||||
testUsername := testutil.GetTestUsername(t)
|
||||
testJWTUser := "test-username"
|
||||
testUserHash, err := sshuserhash.HashUserID(testJWTUser)
|
||||
require.NoError(t, err)
|
||||
|
||||
authConfig := &sshauth.Config{
|
||||
UserIDClaim: sshauth.DefaultUserIDClaim,
|
||||
AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash},
|
||||
MachineUsers: map[string][]uint32{
|
||||
testUsername: {0},
|
||||
},
|
||||
}
|
||||
sshServer.UpdateSSHAuth(authConfig)
|
||||
|
||||
sshServerAddr := server.StartTestServer(t, sshServer)
|
||||
|
||||
mockDaemon := startMockDaemon(t)
|
||||
|
||||
host, portStr, err := net.SplitHostPort(sshServerAddr)
|
||||
require.NoError(t, err)
|
||||
port, err := strconv.Atoi(portStr)
|
||||
require.NoError(t, err)
|
||||
|
||||
mockDaemon.setHostKey(host, hostPubKey)
|
||||
|
||||
validToken := generateValidJWT(t, privateKey, issuer, audience, testJWTUser)
|
||||
mockDaemon.setJWTToken(validToken)
|
||||
|
||||
proxyInstance, err := New(mockDaemon.addr, host, port, io.Discard, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
origStdin := os.Stdin
|
||||
origStdout := os.Stdout
|
||||
|
||||
stdinReader, stdinWriter, err := os.Pipe()
|
||||
require.NoError(t, err)
|
||||
stdoutReader, stdoutWriter, err := os.Pipe()
|
||||
require.NoError(t, err)
|
||||
|
||||
os.Stdin = stdinReader
|
||||
os.Stdout = stdoutWriter
|
||||
|
||||
clientConn, proxyConn := net.Pipe()
|
||||
|
||||
go func() { _, _ = io.Copy(stdinWriter, proxyConn) }()
|
||||
go func() { _, _ = io.Copy(proxyConn, stdoutReader) }()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
|
||||
go func() {
|
||||
_ = proxyInstance.Connect(ctx)
|
||||
}()
|
||||
|
||||
sshConfig := &cryptossh.ClientConfig{
|
||||
User: testutil.GetTestUsername(t),
|
||||
Auth: []cryptossh.AuthMethod{},
|
||||
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
sshClientConn, chans, reqs, err := cryptossh.NewClientConn(clientConn, "test", sshConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
client := cryptossh.NewClient(sshClientConn, chans, reqs)
|
||||
|
||||
cleanupFn := func() {
|
||||
_ = client.Close()
|
||||
_ = clientConn.Close()
|
||||
cancel()
|
||||
os.Stdin = origStdin
|
||||
os.Stdout = origStdout
|
||||
_ = sshServer.Stop()
|
||||
mockDaemon.stop()
|
||||
jwksServer.Close()
|
||||
}
|
||||
|
||||
return client, cleanupFn
|
||||
}
|
||||
|
||||
type mockDaemonServer struct {
|
||||
proto.UnimplementedDaemonServiceServer
|
||||
hostKeys map[string][]byte
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user