mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-02 23:53:55 -04:00
Compare commits
116 Commits
feature/st
...
test-ldfla
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
619f1588b3 | ||
|
|
60d2a2c7df | ||
|
|
ff9585735b | ||
|
|
194951c88d | ||
|
|
d63f2e5196 | ||
|
|
14b5637555 | ||
|
|
3d3b05c157 | ||
|
|
f36e206238 | ||
|
|
9a808244d7 | ||
|
|
2dec76f8ea | ||
|
|
224bd8ff22 | ||
|
|
fe88a5662e | ||
|
|
f9f6409f94 | ||
|
|
b03154dce5 | ||
|
|
c57364596a | ||
|
|
60f4d5f9b0 | ||
|
|
4eeb2d8deb | ||
|
|
2765bcfb89 | ||
|
|
d71a82769c | ||
|
|
fa6151b849 | ||
|
|
a939c1767c | ||
|
|
938554fb0f | ||
|
|
0d79301141 | ||
|
|
39bec2dd74 | ||
|
|
554c9bcf4b | ||
|
|
f3639675e7 | ||
|
|
a1457f541b | ||
|
|
9cdfb0d78c | ||
|
|
22d796097e | ||
|
|
aa39a5d528 | ||
|
|
e4b41d0ad7 | ||
|
|
9cc9462cd5 | ||
|
|
3176b53968 | ||
|
|
27957036c9 | ||
|
|
6fb568728f | ||
|
|
cc97cffff1 | ||
|
|
1d2a5371ce | ||
|
|
6898e57686 | ||
|
|
c8bc865f2f | ||
|
|
06bb8658b1 | ||
|
|
8fc4fed3a0 | ||
|
|
df14f1399f | ||
|
|
6d6f090764 | ||
|
|
c28275611b | ||
|
|
56f169eede | ||
|
|
07cf9d5895 | ||
|
|
7df49e249d | ||
|
|
dbfc8a52c9 | ||
|
|
98ddac07bf | ||
|
|
48475ddc05 | ||
|
|
6aa4ba7af4 | ||
|
|
2e16c9914a | ||
|
|
5c29d395b2 | ||
|
|
229e0038ee | ||
|
|
75327d9519 | ||
|
|
c92e6c1b5f | ||
|
|
641eb5140b | ||
|
|
45c25dca84 | ||
|
|
679c58ce47 | ||
|
|
13febbbfca | ||
|
|
49d36b7e7e | ||
|
|
719283c792 | ||
|
|
a2313a5ba4 | ||
|
|
8c108ccad3 | ||
|
|
86eff0d750 | ||
|
|
976787dbf1 | ||
|
|
43c9a51913 | ||
|
|
c530db1455 | ||
|
|
1ee575befe | ||
|
|
d3a34adcc9 | ||
|
|
d7321c130b | ||
|
|
404cab90ba | ||
|
|
4545ab9a52 | ||
|
|
7f08983207 | ||
|
|
eddea14521 | ||
|
|
b9ef214ea5 | ||
|
|
709e24eb6f | ||
|
|
6654e2dbf7 | ||
|
|
d80d47a469 | ||
|
|
96f71ff1e1 | ||
|
|
2fe2af38d2 | ||
|
|
536b0003ab | ||
|
|
cd9a867ad0 | ||
|
|
0f9bfeff7c | ||
|
|
f5301230bf | ||
|
|
429d7d6585 | ||
|
|
3cdb10cde7 | ||
|
|
af95aabb03 | ||
|
|
3abae0bd17 | ||
|
|
8252ff41db | ||
|
|
277aa2b7cc | ||
|
|
bb37dc89ce | ||
|
|
000e99e7f3 | ||
|
|
0d2e67983a | ||
|
|
0e9438d658 | ||
|
|
e570570fe5 | ||
|
|
23f9dd04b8 | ||
|
|
5151f19d29 | ||
|
|
7a95bf5652 | ||
|
|
bedd3cabc9 | ||
|
|
d35a845dbd | ||
|
|
4e03f708a4 | ||
|
|
654aa9581d | ||
|
|
9021bb512b | ||
|
|
768332820e | ||
|
|
229c65ffa1 | ||
|
|
4d33567888 | ||
|
|
88467883fc | ||
|
|
954f40991f | ||
|
|
34341d95a9 | ||
|
|
60c5782905 | ||
|
|
5a12c5d345 | ||
|
|
bdae55ab79 | ||
|
|
d01c3d5011 | ||
|
|
17a2af96ea | ||
|
|
cc595da1ad |
119
.github/workflows/check-license-dependencies.yml
vendored
119
.github/workflows/check-license-dependencies.yml
vendored
@@ -3,39 +3,108 @@ name: Check License Dependencies
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
paths:
|
||||
- 'go.mod'
|
||||
- 'go.sum'
|
||||
- '.github/workflows/check-license-dependencies.yml'
|
||||
pull_request:
|
||||
paths:
|
||||
- 'go.mod'
|
||||
- 'go.sum'
|
||||
- '.github/workflows/check-license-dependencies.yml'
|
||||
|
||||
jobs:
|
||||
check-dependencies:
|
||||
check-internal-dependencies:
|
||||
name: Check Internal AGPL Dependencies
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Check for problematic license dependencies
|
||||
run: |
|
||||
echo "Checking for dependencies on management/, signal/, and relay/ packages..."
|
||||
echo ""
|
||||
|
||||
# Find all directories except the problematic ones and system dirs
|
||||
FOUND_ISSUES=0
|
||||
while IFS= read -r dir; do
|
||||
echo "=== Checking $dir ==="
|
||||
# Search for problematic imports, excluding test files
|
||||
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
|
||||
if [ -n "$RESULTS" ]; then
|
||||
echo "❌ Found problematic dependencies:"
|
||||
echo "$RESULTS"
|
||||
FOUND_ISSUES=1
|
||||
else
|
||||
echo "✓ No problematic dependencies found"
|
||||
fi
|
||||
done < <(find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name ".git*" | sort)
|
||||
|
||||
echo ""
|
||||
if [ $FOUND_ISSUES -eq 1 ]; then
|
||||
echo "❌ Found dependencies on management/, signal/, or relay/ packages"
|
||||
echo "These packages are licensed under AGPLv3 and must not be imported by BSD-licensed code"
|
||||
exit 1
|
||||
else
|
||||
echo ""
|
||||
echo "✅ All internal license dependencies are clean"
|
||||
fi
|
||||
|
||||
check-external-licenses:
|
||||
name: Check External GPL/AGPL Licenses
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Check for problematic license dependencies
|
||||
run: |
|
||||
echo "Checking for dependencies on management/, signal/, and relay/ packages..."
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version-file: 'go.mod'
|
||||
cache: true
|
||||
|
||||
# Find all directories except the problematic ones and system dirs
|
||||
FOUND_ISSUES=0
|
||||
find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name ".git*" | sort | while read dir; do
|
||||
echo "=== Checking $dir ==="
|
||||
# Search for problematic imports, excluding test files
|
||||
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\)" "$dir" --include="*.go" | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
|
||||
if [ ! -z "$RESULTS" ]; then
|
||||
echo "❌ Found problematic dependencies:"
|
||||
echo "$RESULTS"
|
||||
FOUND_ISSUES=1
|
||||
else
|
||||
echo "✓ No problematic dependencies found"
|
||||
- name: Install go-licenses
|
||||
run: go install github.com/google/go-licenses@v1.6.0
|
||||
|
||||
- name: Check for GPL/AGPL licensed dependencies
|
||||
run: |
|
||||
echo "Checking for GPL/AGPL/LGPL licensed dependencies..."
|
||||
echo ""
|
||||
|
||||
# Check all Go packages for copyleft licenses, excluding internal netbird packages
|
||||
COPYLEFT_DEPS=$(go-licenses report ./... 2>/dev/null | grep -E 'GPL|AGPL|LGPL' | grep -v 'github.com/netbirdio/netbird/' || true)
|
||||
|
||||
if [ -n "$COPYLEFT_DEPS" ]; then
|
||||
echo "Found copyleft licensed dependencies:"
|
||||
echo "$COPYLEFT_DEPS"
|
||||
echo ""
|
||||
|
||||
# Filter out dependencies that are only pulled in by internal AGPL packages
|
||||
INCOMPATIBLE=""
|
||||
while IFS=',' read -r package url license; do
|
||||
if echo "$license" | grep -qE 'GPL-[0-9]|AGPL-[0-9]|LGPL-[0-9]'; then
|
||||
# Find ALL packages that import this GPL package using go list
|
||||
IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath")
|
||||
|
||||
# Check if any importer is NOT in management/signal/relay
|
||||
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\)" | head -1)
|
||||
|
||||
if [ -n "$BSD_IMPORTER" ]; then
|
||||
echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER"
|
||||
INCOMPATIBLE="${INCOMPATIBLE}${package},${url},${license}\n"
|
||||
else
|
||||
echo "✓ $package ($license) is only used by internal AGPL packages - OK"
|
||||
fi
|
||||
fi
|
||||
done <<< "$COPYLEFT_DEPS"
|
||||
|
||||
if [ -n "$INCOMPATIBLE" ]; then
|
||||
echo ""
|
||||
echo "❌ INCOMPATIBLE licenses found that are used by BSD-licensed code:"
|
||||
echo -e "$INCOMPATIBLE"
|
||||
exit 1
|
||||
fi
|
||||
done
|
||||
if [ $FOUND_ISSUES -eq 1 ]; then
|
||||
echo ""
|
||||
echo "❌ Found dependencies on management/, signal/, or relay/ packages"
|
||||
echo "These packages will change license and should not be imported by client or shared code"
|
||||
exit 1
|
||||
else
|
||||
echo ""
|
||||
echo "✅ All license dependencies are clean"
|
||||
fi
|
||||
|
||||
echo "✅ All external license dependencies are compatible with BSD-3-Clause"
|
||||
|
||||
2
.github/workflows/wasm-build-validation.yml
vendored
2
.github/workflows/wasm-build-validation.yml
vendored
@@ -47,7 +47,7 @@ jobs:
|
||||
with:
|
||||
go-version: "1.23.x"
|
||||
- name: Build Wasm client
|
||||
run: GOOS=js GOARCH=wasm go build -o netbird.wasm ./client/wasm/cmd
|
||||
run: GOOS=js GOARCH=wasm go build -o netbird.wasm -ldflags="-s -w" ./client/wasm/cmd
|
||||
env:
|
||||
CGO_ENABLED: 0
|
||||
- name: Check Wasm build size
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
# sudo podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client .
|
||||
# sudo podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest
|
||||
|
||||
FROM alpine:3.22.0
|
||||
FROM alpine:3.22.2
|
||||
# iproute2: busybox doesn't display ip rules properly
|
||||
RUN apk add --no-cache \
|
||||
bash \
|
||||
|
||||
@@ -17,9 +17,9 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/client/net"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/netbirdio/netbird/formatter"
|
||||
"github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
// ConnectionListener export internal Listener for mobile
|
||||
|
||||
@@ -200,7 +200,7 @@ func (a *Auth) login(urlOpener URLOpener) error {
|
||||
}
|
||||
|
||||
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*auth.TokenInfo, error) {
|
||||
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false)
|
||||
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -201,6 +201,94 @@ func (p *Preferences) SetServerSSHAllowed(allowed bool) {
|
||||
p.configInput.ServerSSHAllowed = &allowed
|
||||
}
|
||||
|
||||
// GetEnableSSHRoot reads SSH root login setting from config file
|
||||
func (p *Preferences) GetEnableSSHRoot() (bool, error) {
|
||||
if p.configInput.EnableSSHRoot != nil {
|
||||
return *p.configInput.EnableSSHRoot, nil
|
||||
}
|
||||
|
||||
cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if cfg.EnableSSHRoot == nil {
|
||||
// Default to false for security on Android
|
||||
return false, nil
|
||||
}
|
||||
return *cfg.EnableSSHRoot, err
|
||||
}
|
||||
|
||||
// SetEnableSSHRoot stores the given value and waits for commit
|
||||
func (p *Preferences) SetEnableSSHRoot(enabled bool) {
|
||||
p.configInput.EnableSSHRoot = &enabled
|
||||
}
|
||||
|
||||
// GetEnableSSHSFTP reads SSH SFTP setting from config file
|
||||
func (p *Preferences) GetEnableSSHSFTP() (bool, error) {
|
||||
if p.configInput.EnableSSHSFTP != nil {
|
||||
return *p.configInput.EnableSSHSFTP, nil
|
||||
}
|
||||
|
||||
cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if cfg.EnableSSHSFTP == nil {
|
||||
// Default to false for security on Android
|
||||
return false, nil
|
||||
}
|
||||
return *cfg.EnableSSHSFTP, err
|
||||
}
|
||||
|
||||
// SetEnableSSHSFTP stores the given value and waits for commit
|
||||
func (p *Preferences) SetEnableSSHSFTP(enabled bool) {
|
||||
p.configInput.EnableSSHSFTP = &enabled
|
||||
}
|
||||
|
||||
// GetEnableSSHLocalPortForwarding reads SSH local port forwarding setting from config file
|
||||
func (p *Preferences) GetEnableSSHLocalPortForwarding() (bool, error) {
|
||||
if p.configInput.EnableSSHLocalPortForwarding != nil {
|
||||
return *p.configInput.EnableSSHLocalPortForwarding, nil
|
||||
}
|
||||
|
||||
cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if cfg.EnableSSHLocalPortForwarding == nil {
|
||||
// Default to false for security on Android
|
||||
return false, nil
|
||||
}
|
||||
return *cfg.EnableSSHLocalPortForwarding, err
|
||||
}
|
||||
|
||||
// SetEnableSSHLocalPortForwarding stores the given value and waits for commit
|
||||
func (p *Preferences) SetEnableSSHLocalPortForwarding(enabled bool) {
|
||||
p.configInput.EnableSSHLocalPortForwarding = &enabled
|
||||
}
|
||||
|
||||
// GetEnableSSHRemotePortForwarding reads SSH remote port forwarding setting from config file
|
||||
func (p *Preferences) GetEnableSSHRemotePortForwarding() (bool, error) {
|
||||
if p.configInput.EnableSSHRemotePortForwarding != nil {
|
||||
return *p.configInput.EnableSSHRemotePortForwarding, nil
|
||||
}
|
||||
|
||||
cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if cfg.EnableSSHRemotePortForwarding == nil {
|
||||
// Default to false for security on Android
|
||||
return false, nil
|
||||
}
|
||||
return *cfg.EnableSSHRemotePortForwarding, err
|
||||
}
|
||||
|
||||
// SetEnableSSHRemotePortForwarding stores the given value and waits for commit
|
||||
func (p *Preferences) SetEnableSSHRemotePortForwarding(enabled bool) {
|
||||
p.configInput.EnableSSHRemotePortForwarding = &enabled
|
||||
}
|
||||
|
||||
// GetBlockInbound reads block inbound setting from config file
|
||||
func (p *Preferences) GetBlockInbound() (bool, error) {
|
||||
if p.configInput.BlockInbound != nil {
|
||||
|
||||
@@ -16,7 +16,6 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/client/server"
|
||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/upload-server/types"
|
||||
)
|
||||
@@ -98,7 +97,6 @@ func debugBundle(cmd *cobra.Command, _ []string) error {
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
request := &proto.DebugBundleRequest{
|
||||
Anonymize: anonymizeFlag,
|
||||
Status: getStatusOutput(cmd, anonymizeFlag),
|
||||
SystemInfo: systemInfoFlag,
|
||||
LogFileCount: logFileCount,
|
||||
}
|
||||
@@ -168,7 +166,7 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
|
||||
stat, err := client.Status(cmd.Context(), &proto.StatusRequest{})
|
||||
stat, err := client.Status(cmd.Context(), &proto.StatusRequest{ShouldRunProbes: true})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get status: %v", status.Convert(err).Message())
|
||||
}
|
||||
@@ -220,9 +218,6 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339))
|
||||
statusOutput := fmt.Sprintf("%s\n%s", headerPostUp, getStatusOutput(cmd, anonymizeFlag))
|
||||
|
||||
if waitErr := waitForDurationOrCancel(cmd.Context(), duration, cmd); waitErr != nil {
|
||||
return waitErr
|
||||
}
|
||||
@@ -230,11 +225,8 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
|
||||
cmd.Println("Creating debug bundle...")
|
||||
|
||||
headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration)
|
||||
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd, anonymizeFlag))
|
||||
request := &proto.DebugBundleRequest{
|
||||
Anonymize: anonymizeFlag,
|
||||
Status: statusOutput,
|
||||
SystemInfo: systemInfoFlag,
|
||||
LogFileCount: logFileCount,
|
||||
}
|
||||
@@ -301,19 +293,6 @@ func setSyncResponsePersistence(cmd *cobra.Command, args []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func getStatusOutput(cmd *cobra.Command, anon bool) string {
|
||||
var statusOutputString string
|
||||
statusResp, err := getStatus(cmd.Context())
|
||||
if err != nil {
|
||||
cmd.PrintErrf("Failed to get status: %v\n", err)
|
||||
} else {
|
||||
statusOutputString = nbstatus.ParseToFullDetailSummary(
|
||||
nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", ""),
|
||||
)
|
||||
}
|
||||
return statusOutputString
|
||||
}
|
||||
|
||||
func waitForDurationOrCancel(ctx context.Context, duration time.Duration, cmd *cobra.Command) error {
|
||||
ticker := time.NewTicker(1 * time.Second)
|
||||
defer ticker.Stop()
|
||||
@@ -372,7 +351,7 @@ func generateDebugBundle(config *profilemanager.Config, recorder *peer.Status, c
|
||||
InternalConfig: config,
|
||||
StatusRecorder: recorder,
|
||||
SyncResponse: syncResponse,
|
||||
LogFile: logFilePath,
|
||||
LogPath: logFilePath,
|
||||
},
|
||||
debug.BundleConfig{
|
||||
IncludeSystemInfo: true,
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"runtime"
|
||||
"strings"
|
||||
@@ -105,6 +106,13 @@ func doDaemonLogin(ctx context.Context, cmd *cobra.Command, providedSetupKey str
|
||||
Username: &username,
|
||||
}
|
||||
|
||||
profileState, err := pm.GetProfileState(activeProf.Name)
|
||||
if err != nil {
|
||||
log.Debugf("failed to get profile state for login hint: %v", err)
|
||||
} else if profileState.Email != "" {
|
||||
loginRequest.Hint = &profileState.Email
|
||||
}
|
||||
|
||||
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
|
||||
loginRequest.OptionalPreSharedKey = &preSharedKey
|
||||
}
|
||||
@@ -240,7 +248,7 @@ func doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string,
|
||||
return fmt.Errorf("read config file %s: %v", configFilePath, err)
|
||||
}
|
||||
|
||||
err = foregroundLogin(ctx, cmd, config, setupKey)
|
||||
err = foregroundLogin(ctx, cmd, config, setupKey, activeProf.Name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("foreground login failed: %v", err)
|
||||
}
|
||||
@@ -268,7 +276,7 @@ func handleSSOLogin(ctx context.Context, cmd *cobra.Command, loginResp *proto.Lo
|
||||
return nil
|
||||
}
|
||||
|
||||
func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey string) error {
|
||||
func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey, profileName string) error {
|
||||
needsLogin := false
|
||||
|
||||
err := WithBackOff(func() error {
|
||||
@@ -285,7 +293,7 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman
|
||||
|
||||
jwtToken := ""
|
||||
if setupKey == "" && needsLogin {
|
||||
tokenInfo, err := foregroundGetTokenInfo(ctx, cmd, config)
|
||||
tokenInfo, err := foregroundGetTokenInfo(ctx, cmd, config, profileName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("interactive sso login failed: %v", err)
|
||||
}
|
||||
@@ -314,8 +322,17 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman
|
||||
return nil
|
||||
}
|
||||
|
||||
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config) (*auth.TokenInfo, error) {
|
||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop())
|
||||
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, profileName string) (*auth.TokenInfo, error) {
|
||||
hint := ""
|
||||
pm := profilemanager.NewProfileManager()
|
||||
profileState, err := pm.GetProfileState(profileName)
|
||||
if err != nil {
|
||||
log.Debugf("failed to get profile state for login hint: %v", err)
|
||||
} else if profileState.Email != "" {
|
||||
hint = profileState.Email
|
||||
}
|
||||
|
||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop(), hint)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -356,13 +373,21 @@ func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBro
|
||||
cmd.Println("")
|
||||
|
||||
if !noBrowser {
|
||||
if err := open.Run(verificationURIComplete); err != nil {
|
||||
if err := openBrowser(verificationURIComplete); err != nil {
|
||||
cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" +
|
||||
"https://docs.netbird.io/how-to/register-machines-using-setup-keys")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// openBrowser opens the URL in a browser, respecting the BROWSER environment variable.
|
||||
func openBrowser(url string) error {
|
||||
if browser := os.Getenv("BROWSER"); browser != "" {
|
||||
return exec.Command(browser, url).Start()
|
||||
}
|
||||
return open.Run(url)
|
||||
}
|
||||
|
||||
// isUnixRunningDesktop checks if a Linux OS is running desktop environment
|
||||
func isUnixRunningDesktop() bool {
|
||||
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
|
||||
|
||||
@@ -35,7 +35,6 @@ const (
|
||||
wireguardPortFlag = "wireguard-port"
|
||||
networkMonitorFlag = "network-monitor"
|
||||
disableAutoConnectFlag = "disable-auto-connect"
|
||||
serverSSHAllowedFlag = "allow-server-ssh"
|
||||
extraIFaceBlackListFlag = "extra-iface-blacklist"
|
||||
dnsRouteIntervalFlag = "dns-router-interval"
|
||||
enableLazyConnectionFlag = "enable-lazy-connection"
|
||||
@@ -64,7 +63,6 @@ var (
|
||||
customDNSAddress string
|
||||
rosenpassEnabled bool
|
||||
rosenpassPermissive bool
|
||||
serverSSHAllowed bool
|
||||
interfaceName string
|
||||
wireguardPort uint16
|
||||
networkMonitor bool
|
||||
@@ -176,7 +174,6 @@ func init() {
|
||||
)
|
||||
upCmd.PersistentFlags().BoolVar(&rosenpassEnabled, enableRosenpassFlag, false, "[Experimental] Enable Rosenpass feature. If enabled, the connection will be post-quantum secured via Rosenpass.")
|
||||
upCmd.PersistentFlags().BoolVar(&rosenpassPermissive, rosenpassPermissiveFlag, false, "[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.")
|
||||
upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer. If enabled, the SSH server will be permitted")
|
||||
upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.")
|
||||
upCmd.PersistentFlags().BoolVar(&lazyConnEnabled, enableLazyConnectionFlag, false, "[Experimental] Enable the lazy connection feature. If enabled, the client will establish connections on-demand. Note: this setting may be overridden by management configuration.")
|
||||
|
||||
|
||||
@@ -10,6 +10,8 @@ import (
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
@@ -81,6 +83,10 @@ func configurePlatformSpecificSettings(svcConfig *service.Config) error {
|
||||
svcConfig.Option["LogDirectory"] = dir
|
||||
}
|
||||
}
|
||||
|
||||
if err := configureSystemdNetworkd(); err != nil {
|
||||
log.Warnf("failed to configure systemd-networkd: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
@@ -160,6 +166,12 @@ var uninstallCmd = &cobra.Command{
|
||||
return fmt.Errorf("uninstall service: %w", err)
|
||||
}
|
||||
|
||||
if runtime.GOOS == "linux" {
|
||||
if err := cleanupSystemdNetworkd(); err != nil {
|
||||
log.Warnf("failed to cleanup systemd-networkd configuration: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
cmd.Println("NetBird service has been uninstalled")
|
||||
return nil
|
||||
},
|
||||
@@ -245,3 +257,50 @@ func isServiceRunning() (bool, error) {
|
||||
|
||||
return status == service.StatusRunning, nil
|
||||
}
|
||||
|
||||
const (
|
||||
networkdConf = "/etc/systemd/networkd.conf"
|
||||
networkdConfDir = "/etc/systemd/networkd.conf.d"
|
||||
networkdConfFile = "/etc/systemd/networkd.conf.d/99-netbird.conf"
|
||||
networkdConfContent = `# Created by NetBird to prevent systemd-networkd from removing
|
||||
# routes and policy rules managed by NetBird.
|
||||
|
||||
[Network]
|
||||
ManageForeignRoutes=no
|
||||
ManageForeignRoutingPolicyRules=no
|
||||
`
|
||||
)
|
||||
|
||||
// configureSystemdNetworkd creates a drop-in configuration file to prevent
|
||||
// systemd-networkd from removing NetBird's routes and policy rules.
|
||||
func configureSystemdNetworkd() error {
|
||||
if _, err := os.Stat(networkdConf); os.IsNotExist(err) {
|
||||
log.Debug("systemd-networkd not in use, skipping configuration")
|
||||
return nil
|
||||
}
|
||||
|
||||
// nolint:gosec // standard networkd permissions
|
||||
if err := os.MkdirAll(networkdConfDir, 0755); err != nil {
|
||||
return fmt.Errorf("create networkd.conf.d directory: %w", err)
|
||||
}
|
||||
|
||||
// nolint:gosec // standard networkd permissions
|
||||
if err := os.WriteFile(networkdConfFile, []byte(networkdConfContent), 0644); err != nil {
|
||||
return fmt.Errorf("write networkd configuration: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// cleanupSystemdNetworkd removes the NetBird systemd-networkd configuration file.
|
||||
func cleanupSystemdNetworkd() error {
|
||||
if _, err := os.Stat(networkdConfFile); os.IsNotExist(err) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := os.Remove(networkdConfFile); err != nil {
|
||||
return fmt.Errorf("remove networkd configuration: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -3,125 +3,809 @@ package cmd
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"os/signal"
|
||||
"os/user"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
sshclient "github.com/netbirdio/netbird/client/ssh/client"
|
||||
"github.com/netbirdio/netbird/client/ssh/detection"
|
||||
sshproxy "github.com/netbirdio/netbird/client/ssh/proxy"
|
||||
sshserver "github.com/netbirdio/netbird/client/ssh/server"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
var (
|
||||
port int
|
||||
userName = "root"
|
||||
host string
|
||||
const (
|
||||
sshUsernameDesc = "SSH username"
|
||||
hostArgumentRequired = "host argument required"
|
||||
|
||||
serverSSHAllowedFlag = "allow-server-ssh"
|
||||
enableSSHRootFlag = "enable-ssh-root"
|
||||
enableSSHSFTPFlag = "enable-ssh-sftp"
|
||||
enableSSHLocalPortForwardFlag = "enable-ssh-local-port-forwarding"
|
||||
enableSSHRemotePortForwardFlag = "enable-ssh-remote-port-forwarding"
|
||||
disableSSHAuthFlag = "disable-ssh-auth"
|
||||
sshJWTCacheTTLFlag = "ssh-jwt-cache-ttl"
|
||||
)
|
||||
|
||||
var sshCmd = &cobra.Command{
|
||||
Use: "ssh [user@]host",
|
||||
Args: func(cmd *cobra.Command, args []string) error {
|
||||
if len(args) < 1 {
|
||||
return errors.New("requires a host argument")
|
||||
}
|
||||
var (
|
||||
port int
|
||||
username string
|
||||
host string
|
||||
command string
|
||||
localForwards []string
|
||||
remoteForwards []string
|
||||
strictHostKeyChecking bool
|
||||
knownHostsFile string
|
||||
identityFile string
|
||||
skipCachedToken bool
|
||||
requestPTY bool
|
||||
)
|
||||
|
||||
split := strings.Split(args[0], "@")
|
||||
if len(split) == 2 {
|
||||
userName = split[0]
|
||||
host = split[1]
|
||||
} else {
|
||||
host = args[0]
|
||||
}
|
||||
var (
|
||||
serverSSHAllowed bool
|
||||
enableSSHRoot bool
|
||||
enableSSHSFTP bool
|
||||
enableSSHLocalPortForward bool
|
||||
enableSSHRemotePortForward bool
|
||||
disableSSHAuth bool
|
||||
sshJWTCacheTTL int
|
||||
)
|
||||
|
||||
return nil
|
||||
},
|
||||
Short: "Connect to a remote SSH server",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
SetFlagsFromEnvVars(rootCmd)
|
||||
SetFlagsFromEnvVars(cmd)
|
||||
func init() {
|
||||
upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer")
|
||||
upCmd.PersistentFlags().BoolVar(&enableSSHRoot, enableSSHRootFlag, false, "Enable root login for SSH server")
|
||||
upCmd.PersistentFlags().BoolVar(&enableSSHSFTP, enableSSHSFTPFlag, false, "Enable SFTP subsystem for SSH server")
|
||||
upCmd.PersistentFlags().BoolVar(&enableSSHLocalPortForward, enableSSHLocalPortForwardFlag, false, "Enable local port forwarding for SSH server")
|
||||
upCmd.PersistentFlags().BoolVar(&enableSSHRemotePortForward, enableSSHRemotePortForwardFlag, false, "Enable remote port forwarding for SSH server")
|
||||
upCmd.PersistentFlags().BoolVar(&disableSSHAuth, disableSSHAuthFlag, false, "Disable SSH authentication")
|
||||
upCmd.PersistentFlags().IntVar(&sshJWTCacheTTL, sshJWTCacheTTLFlag, 0, "SSH JWT token cache TTL in seconds (0=disabled)")
|
||||
|
||||
cmd.SetOut(cmd.OutOrStdout())
|
||||
sshCmd.PersistentFlags().IntVarP(&port, "port", "p", sshserver.DefaultSSHPort, "Remote SSH port")
|
||||
sshCmd.PersistentFlags().StringVarP(&username, "user", "u", "", sshUsernameDesc)
|
||||
sshCmd.PersistentFlags().StringVar(&username, "login", "", sshUsernameDesc+" (alias for --user)")
|
||||
sshCmd.PersistentFlags().BoolVarP(&requestPTY, "tty", "t", false, "Force pseudo-terminal allocation")
|
||||
sshCmd.PersistentFlags().BoolVar(&strictHostKeyChecking, "strict-host-key-checking", true, "Enable strict host key checking (default: true)")
|
||||
sshCmd.PersistentFlags().StringVarP(&knownHostsFile, "known-hosts", "o", "", "Path to known_hosts file (default: ~/.ssh/known_hosts)")
|
||||
sshCmd.PersistentFlags().StringVarP(&identityFile, "identity", "i", "", "Path to SSH private key file (deprecated)")
|
||||
_ = sshCmd.PersistentFlags().MarkDeprecated("identity", "this flag is no longer used")
|
||||
sshCmd.PersistentFlags().BoolVar(&skipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication")
|
||||
|
||||
err := util.InitLog(logLevel, util.LogConsole)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed initializing log %v", err)
|
||||
}
|
||||
sshCmd.PersistentFlags().StringArrayP("L", "L", []string{}, "Local port forwarding [bind_address:]port:host:hostport")
|
||||
sshCmd.PersistentFlags().StringArrayP("R", "R", []string{}, "Remote port forwarding [bind_address:]port:host:hostport")
|
||||
|
||||
if !util.IsAdmin() {
|
||||
cmd.Printf("error: you must have Administrator privileges to run this command\n")
|
||||
return nil
|
||||
}
|
||||
|
||||
ctx := internal.CtxInitState(cmd.Context())
|
||||
|
||||
sm := profilemanager.NewServiceManager(configPath)
|
||||
activeProf, err := sm.GetActiveProfileState()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get active profile: %v", err)
|
||||
}
|
||||
profPath, err := activeProf.FilePath()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get active profile path: %v", err)
|
||||
}
|
||||
|
||||
config, err := profilemanager.ReadConfig(profPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read profile config: %v", err)
|
||||
}
|
||||
|
||||
sig := make(chan os.Signal, 1)
|
||||
signal.Notify(sig, syscall.SIGTERM, syscall.SIGINT)
|
||||
sshctx, cancel := context.WithCancel(ctx)
|
||||
|
||||
go func() {
|
||||
// blocking
|
||||
if err := runSSH(sshctx, host, []byte(config.SSHKey), cmd); err != nil {
|
||||
cmd.Printf("Error: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
cancel()
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-sig:
|
||||
cancel()
|
||||
case <-sshctx.Done():
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
sshCmd.AddCommand(sshSftpCmd)
|
||||
sshCmd.AddCommand(sshProxyCmd)
|
||||
sshCmd.AddCommand(sshDetectCmd)
|
||||
}
|
||||
|
||||
func runSSH(ctx context.Context, addr string, pemKey []byte, cmd *cobra.Command) error {
|
||||
c, err := nbssh.DialWithKey(fmt.Sprintf("%s:%d", addr, port), userName, pemKey)
|
||||
if err != nil {
|
||||
cmd.Printf("Error: %v\n", err)
|
||||
cmd.Printf("Couldn't connect. Please check the connection status or if the ssh server is enabled on the other peer" +
|
||||
"\nYou can verify the connection by running:\n\n" +
|
||||
" netbird status\n\n")
|
||||
return err
|
||||
}
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
err = c.Close()
|
||||
if err != nil {
|
||||
return
|
||||
var sshCmd = &cobra.Command{
|
||||
Use: "ssh [flags] [user@]host [command]",
|
||||
Short: "Connect to a NetBird peer via SSH",
|
||||
Long: `Connect to a NetBird peer using SSH with support for port forwarding.
|
||||
|
||||
Port Forwarding:
|
||||
-L [bind_address:]port:host:hostport Local port forwarding
|
||||
-L [bind_address:]port:/path/to/socket Local port forwarding to Unix socket
|
||||
-R [bind_address:]port:host:hostport Remote port forwarding
|
||||
-R [bind_address:]port:/path/to/socket Remote port forwarding to Unix socket
|
||||
|
||||
SSH Options:
|
||||
-p, --port int Remote SSH port (default 22)
|
||||
-u, --user string SSH username
|
||||
--login string SSH username (alias for --user)
|
||||
-t, --tty Force pseudo-terminal allocation
|
||||
--strict-host-key-checking Enable strict host key checking (default: true)
|
||||
-o, --known-hosts string Path to known_hosts file
|
||||
|
||||
Examples:
|
||||
netbird ssh peer-hostname
|
||||
netbird ssh root@peer-hostname
|
||||
netbird ssh --login root peer-hostname
|
||||
netbird ssh peer-hostname ls -la
|
||||
netbird ssh peer-hostname whoami
|
||||
netbird ssh -t peer-hostname tmux # Force PTY for tmux/screen
|
||||
netbird ssh -t peer-hostname sudo -i # Force PTY for interactive sudo
|
||||
netbird ssh -L 8080:localhost:80 peer-hostname # Local port forwarding
|
||||
netbird ssh -R 9090:localhost:3000 peer-hostname # Remote port forwarding
|
||||
netbird ssh -L "*:8080:localhost:80" peer-hostname # Bind to all interfaces
|
||||
netbird ssh -L 8080:/tmp/socket peer-hostname # Unix socket forwarding`,
|
||||
DisableFlagParsing: true,
|
||||
Args: validateSSHArgsWithoutFlagParsing,
|
||||
RunE: sshFn,
|
||||
Aliases: []string{"ssh"},
|
||||
}
|
||||
|
||||
func sshFn(cmd *cobra.Command, args []string) error {
|
||||
for _, arg := range args {
|
||||
if arg == "-h" || arg == "--help" {
|
||||
return cmd.Help()
|
||||
}
|
||||
}
|
||||
|
||||
SetFlagsFromEnvVars(rootCmd)
|
||||
SetFlagsFromEnvVars(cmd)
|
||||
|
||||
cmd.SetOut(cmd.OutOrStdout())
|
||||
|
||||
logOutput := "console"
|
||||
if firstLogFile := util.FindFirstLogPath(logFiles); firstLogFile != "" && firstLogFile != defaultLogFile {
|
||||
logOutput = firstLogFile
|
||||
}
|
||||
if err := util.InitLog(logLevel, logOutput); err != nil {
|
||||
return fmt.Errorf("init log: %w", err)
|
||||
}
|
||||
|
||||
ctx := internal.CtxInitState(cmd.Context())
|
||||
|
||||
sig := make(chan os.Signal, 1)
|
||||
signal.Notify(sig, syscall.SIGTERM, syscall.SIGINT)
|
||||
sshctx, cancel := context.WithCancel(ctx)
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
if err := runSSH(sshctx, host, cmd); err != nil {
|
||||
errCh <- err
|
||||
}
|
||||
cancel()
|
||||
}()
|
||||
|
||||
err = c.OpenTerminal()
|
||||
if err != nil {
|
||||
select {
|
||||
case <-sig:
|
||||
cancel()
|
||||
<-sshctx.Done()
|
||||
return nil
|
||||
case err := <-errCh:
|
||||
return err
|
||||
case <-sshctx.Done():
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
sshCmd.PersistentFlags().IntVarP(&port, "port", "p", nbssh.DefaultSSHPort, "Sets remote SSH port. Defaults to "+fmt.Sprint(nbssh.DefaultSSHPort))
|
||||
// getEnvOrDefault checks for environment variables with WT_ and NB_ prefixes
|
||||
func getEnvOrDefault(flagName, defaultValue string) string {
|
||||
if envValue := os.Getenv("WT_" + flagName); envValue != "" {
|
||||
return envValue
|
||||
}
|
||||
if envValue := os.Getenv("NB_" + flagName); envValue != "" {
|
||||
return envValue
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
// resetSSHGlobals sets SSH globals to their default values
|
||||
func resetSSHGlobals() {
|
||||
port = sshserver.DefaultSSHPort
|
||||
username = ""
|
||||
host = ""
|
||||
command = ""
|
||||
localForwards = nil
|
||||
remoteForwards = nil
|
||||
strictHostKeyChecking = true
|
||||
knownHostsFile = ""
|
||||
identityFile = ""
|
||||
}
|
||||
|
||||
// parseCustomSSHFlags extracts -L, -R flags and returns filtered args
|
||||
func parseCustomSSHFlags(args []string) ([]string, []string, []string) {
|
||||
var localForwardFlags []string
|
||||
var remoteForwardFlags []string
|
||||
var filteredArgs []string
|
||||
|
||||
for i := 0; i < len(args); i++ {
|
||||
arg := args[i]
|
||||
switch {
|
||||
case strings.HasPrefix(arg, "-L"):
|
||||
localForwardFlags, i = parseForwardFlag(arg, args, i, localForwardFlags)
|
||||
case strings.HasPrefix(arg, "-R"):
|
||||
remoteForwardFlags, i = parseForwardFlag(arg, args, i, remoteForwardFlags)
|
||||
default:
|
||||
filteredArgs = append(filteredArgs, arg)
|
||||
}
|
||||
}
|
||||
|
||||
return filteredArgs, localForwardFlags, remoteForwardFlags
|
||||
}
|
||||
|
||||
func parseForwardFlag(arg string, args []string, i int, flags []string) ([]string, int) {
|
||||
if arg == "-L" || arg == "-R" {
|
||||
if i+1 < len(args) {
|
||||
flags = append(flags, args[i+1])
|
||||
i++
|
||||
}
|
||||
} else if len(arg) > 2 {
|
||||
flags = append(flags, arg[2:])
|
||||
}
|
||||
return flags, i
|
||||
}
|
||||
|
||||
// extractGlobalFlags parses global flags that were passed before 'ssh' command
|
||||
func extractGlobalFlags(args []string) {
|
||||
sshPos := findSSHCommandPosition(args)
|
||||
if sshPos == -1 {
|
||||
return
|
||||
}
|
||||
|
||||
globalArgs := args[:sshPos]
|
||||
parseGlobalArgs(globalArgs)
|
||||
}
|
||||
|
||||
// findSSHCommandPosition locates the 'ssh' command in the argument list
|
||||
func findSSHCommandPosition(args []string) int {
|
||||
for i, arg := range args {
|
||||
if arg == "ssh" {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
const (
|
||||
configFlag = "config"
|
||||
logLevelFlag = "log-level"
|
||||
logFileFlag = "log-file"
|
||||
)
|
||||
|
||||
// parseGlobalArgs processes the global arguments and sets the corresponding variables
|
||||
func parseGlobalArgs(globalArgs []string) {
|
||||
flagHandlers := map[string]func(string){
|
||||
configFlag: func(value string) { configPath = value },
|
||||
logLevelFlag: func(value string) { logLevel = value },
|
||||
logFileFlag: func(value string) {
|
||||
if !slices.Contains(logFiles, value) {
|
||||
logFiles = append(logFiles, value)
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
shortFlags := map[string]string{
|
||||
"c": configFlag,
|
||||
"l": logLevelFlag,
|
||||
}
|
||||
|
||||
for i := 0; i < len(globalArgs); i++ {
|
||||
arg := globalArgs[i]
|
||||
|
||||
if handled, nextIndex := parseFlag(arg, globalArgs, i, flagHandlers, shortFlags); handled {
|
||||
i = nextIndex
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// parseFlag handles generic flag parsing for both long and short forms
|
||||
func parseFlag(arg string, args []string, currentIndex int, flagHandlers map[string]func(string), shortFlags map[string]string) (bool, int) {
|
||||
if parsedValue, found := parseEqualsFormat(arg, flagHandlers, shortFlags); found {
|
||||
flagHandlers[parsedValue.flagName](parsedValue.value)
|
||||
return true, currentIndex
|
||||
}
|
||||
|
||||
if parsedValue, found := parseSpacedFormat(arg, args, currentIndex, flagHandlers, shortFlags); found {
|
||||
flagHandlers[parsedValue.flagName](parsedValue.value)
|
||||
return true, currentIndex + 1
|
||||
}
|
||||
|
||||
return false, currentIndex
|
||||
}
|
||||
|
||||
type parsedFlag struct {
|
||||
flagName string
|
||||
value string
|
||||
}
|
||||
|
||||
// parseEqualsFormat handles --flag=value and -f=value formats
|
||||
func parseEqualsFormat(arg string, flagHandlers map[string]func(string), shortFlags map[string]string) (parsedFlag, bool) {
|
||||
if !strings.Contains(arg, "=") {
|
||||
return parsedFlag{}, false
|
||||
}
|
||||
|
||||
parts := strings.SplitN(arg, "=", 2)
|
||||
if len(parts) != 2 {
|
||||
return parsedFlag{}, false
|
||||
}
|
||||
|
||||
if strings.HasPrefix(parts[0], "--") {
|
||||
flagName := strings.TrimPrefix(parts[0], "--")
|
||||
if _, exists := flagHandlers[flagName]; exists {
|
||||
return parsedFlag{flagName: flagName, value: parts[1]}, true
|
||||
}
|
||||
}
|
||||
|
||||
if strings.HasPrefix(parts[0], "-") && len(parts[0]) == 2 {
|
||||
shortFlag := strings.TrimPrefix(parts[0], "-")
|
||||
if longFlag, exists := shortFlags[shortFlag]; exists {
|
||||
if _, exists := flagHandlers[longFlag]; exists {
|
||||
return parsedFlag{flagName: longFlag, value: parts[1]}, true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return parsedFlag{}, false
|
||||
}
|
||||
|
||||
// parseSpacedFormat handles --flag value and -f value formats
|
||||
func parseSpacedFormat(arg string, args []string, currentIndex int, flagHandlers map[string]func(string), shortFlags map[string]string) (parsedFlag, bool) {
|
||||
if currentIndex+1 >= len(args) {
|
||||
return parsedFlag{}, false
|
||||
}
|
||||
|
||||
if strings.HasPrefix(arg, "--") {
|
||||
flagName := strings.TrimPrefix(arg, "--")
|
||||
if _, exists := flagHandlers[flagName]; exists {
|
||||
return parsedFlag{flagName: flagName, value: args[currentIndex+1]}, true
|
||||
}
|
||||
}
|
||||
|
||||
if strings.HasPrefix(arg, "-") && len(arg) == 2 {
|
||||
shortFlag := strings.TrimPrefix(arg, "-")
|
||||
if longFlag, exists := shortFlags[shortFlag]; exists {
|
||||
if _, exists := flagHandlers[longFlag]; exists {
|
||||
return parsedFlag{flagName: longFlag, value: args[currentIndex+1]}, true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return parsedFlag{}, false
|
||||
}
|
||||
|
||||
// createSSHFlagSet creates and configures the flag set for SSH command parsing
|
||||
// sshFlags contains all SSH-related flags and parameters
|
||||
type sshFlags struct {
|
||||
Port int
|
||||
Username string
|
||||
Login string
|
||||
RequestPTY bool
|
||||
StrictHostKeyChecking bool
|
||||
KnownHostsFile string
|
||||
IdentityFile string
|
||||
SkipCachedToken bool
|
||||
ConfigPath string
|
||||
LogLevel string
|
||||
LocalForwards []string
|
||||
RemoteForwards []string
|
||||
Host string
|
||||
Command string
|
||||
}
|
||||
|
||||
func createSSHFlagSet() (*flag.FlagSet, *sshFlags) {
|
||||
defaultConfigPath := getEnvOrDefault("CONFIG", configPath)
|
||||
defaultLogLevel := getEnvOrDefault("LOG_LEVEL", logLevel)
|
||||
|
||||
fs := flag.NewFlagSet("ssh-flags", flag.ContinueOnError)
|
||||
fs.SetOutput(nil)
|
||||
|
||||
flags := &sshFlags{}
|
||||
|
||||
fs.IntVar(&flags.Port, "p", sshserver.DefaultSSHPort, "SSH port")
|
||||
fs.IntVar(&flags.Port, "port", sshserver.DefaultSSHPort, "SSH port")
|
||||
fs.StringVar(&flags.Username, "u", "", sshUsernameDesc)
|
||||
fs.StringVar(&flags.Username, "user", "", sshUsernameDesc)
|
||||
fs.StringVar(&flags.Login, "login", "", sshUsernameDesc+" (alias for --user)")
|
||||
fs.BoolVar(&flags.RequestPTY, "t", false, "Force pseudo-terminal allocation")
|
||||
fs.BoolVar(&flags.RequestPTY, "tty", false, "Force pseudo-terminal allocation")
|
||||
|
||||
fs.BoolVar(&flags.StrictHostKeyChecking, "strict-host-key-checking", true, "Enable strict host key checking")
|
||||
fs.StringVar(&flags.KnownHostsFile, "o", "", "Path to known_hosts file")
|
||||
fs.StringVar(&flags.KnownHostsFile, "known-hosts", "", "Path to known_hosts file")
|
||||
fs.StringVar(&flags.IdentityFile, "i", "", "Path to SSH private key file")
|
||||
fs.StringVar(&flags.IdentityFile, "identity", "", "Path to SSH private key file")
|
||||
fs.BoolVar(&flags.SkipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication")
|
||||
|
||||
fs.StringVar(&flags.ConfigPath, "c", defaultConfigPath, "Netbird config file location")
|
||||
fs.StringVar(&flags.ConfigPath, "config", defaultConfigPath, "Netbird config file location")
|
||||
fs.StringVar(&flags.LogLevel, "l", defaultLogLevel, "sets Netbird log level")
|
||||
fs.StringVar(&flags.LogLevel, "log-level", defaultLogLevel, "sets Netbird log level")
|
||||
|
||||
return fs, flags
|
||||
}
|
||||
|
||||
func validateSSHArgsWithoutFlagParsing(_ *cobra.Command, args []string) error {
|
||||
if len(args) < 1 {
|
||||
return errors.New(hostArgumentRequired)
|
||||
}
|
||||
|
||||
resetSSHGlobals()
|
||||
|
||||
if len(os.Args) > 2 {
|
||||
extractGlobalFlags(os.Args[1:])
|
||||
}
|
||||
|
||||
filteredArgs, localForwardFlags, remoteForwardFlags := parseCustomSSHFlags(args)
|
||||
|
||||
fs, flags := createSSHFlagSet()
|
||||
|
||||
if err := fs.Parse(filteredArgs); err != nil {
|
||||
if errors.Is(err, flag.ErrHelp) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
remaining := fs.Args()
|
||||
if len(remaining) < 1 {
|
||||
return errors.New(hostArgumentRequired)
|
||||
}
|
||||
|
||||
port = flags.Port
|
||||
if flags.Username != "" {
|
||||
username = flags.Username
|
||||
} else if flags.Login != "" {
|
||||
username = flags.Login
|
||||
}
|
||||
|
||||
requestPTY = flags.RequestPTY
|
||||
strictHostKeyChecking = flags.StrictHostKeyChecking
|
||||
knownHostsFile = flags.KnownHostsFile
|
||||
identityFile = flags.IdentityFile
|
||||
skipCachedToken = flags.SkipCachedToken
|
||||
|
||||
if flags.ConfigPath != getEnvOrDefault("CONFIG", configPath) {
|
||||
configPath = flags.ConfigPath
|
||||
}
|
||||
if flags.LogLevel != getEnvOrDefault("LOG_LEVEL", logLevel) {
|
||||
logLevel = flags.LogLevel
|
||||
}
|
||||
|
||||
localForwards = localForwardFlags
|
||||
remoteForwards = remoteForwardFlags
|
||||
|
||||
return parseHostnameAndCommand(remaining)
|
||||
}
|
||||
|
||||
func parseHostnameAndCommand(args []string) error {
|
||||
if len(args) < 1 {
|
||||
return errors.New(hostArgumentRequired)
|
||||
}
|
||||
|
||||
arg := args[0]
|
||||
if strings.Contains(arg, "@") {
|
||||
parts := strings.SplitN(arg, "@", 2)
|
||||
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
|
||||
return errors.New("invalid user@host format")
|
||||
}
|
||||
if username == "" {
|
||||
username = parts[0]
|
||||
}
|
||||
host = parts[1]
|
||||
} else {
|
||||
host = arg
|
||||
}
|
||||
|
||||
if username == "" {
|
||||
if sudoUser := os.Getenv("SUDO_USER"); sudoUser != "" {
|
||||
username = sudoUser
|
||||
} else if currentUser, err := user.Current(); err == nil {
|
||||
username = currentUser.Username
|
||||
} else {
|
||||
username = "root"
|
||||
}
|
||||
}
|
||||
|
||||
// Everything after hostname becomes the command
|
||||
if len(args) > 1 {
|
||||
command = strings.Join(args[1:], " ")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func runSSH(ctx context.Context, addr string, cmd *cobra.Command) error {
|
||||
target := fmt.Sprintf("%s:%d", addr, port)
|
||||
c, err := sshclient.Dial(ctx, target, username, sshclient.DialOptions{
|
||||
KnownHostsFile: knownHostsFile,
|
||||
IdentityFile: identityFile,
|
||||
DaemonAddr: daemonAddr,
|
||||
SkipCachedToken: skipCachedToken,
|
||||
InsecureSkipVerify: !strictHostKeyChecking,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
cmd.Printf("Failed to connect to %s@%s\n", username, target)
|
||||
cmd.Printf("\nTroubleshooting steps:\n")
|
||||
cmd.Printf(" 1. Check peer connectivity: netbird status -d\n")
|
||||
cmd.Printf(" 2. Verify SSH server is enabled on the peer\n")
|
||||
cmd.Printf(" 3. Ensure correct hostname/IP is used\n")
|
||||
return fmt.Errorf("dial %s: %w", target, err)
|
||||
}
|
||||
|
||||
sshCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
<-sshCtx.Done()
|
||||
if err := c.Close(); err != nil {
|
||||
cmd.Printf("Error closing SSH connection: %v\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := startPortForwarding(sshCtx, c, cmd); err != nil {
|
||||
return fmt.Errorf("start port forwarding: %w", err)
|
||||
}
|
||||
|
||||
if command != "" {
|
||||
return executeSSHCommand(sshCtx, c, command)
|
||||
}
|
||||
return openSSHTerminal(sshCtx, c)
|
||||
}
|
||||
|
||||
// executeSSHCommand executes a command over SSH.
|
||||
func executeSSHCommand(ctx context.Context, c *sshclient.Client, command string) error {
|
||||
var err error
|
||||
if requestPTY {
|
||||
err = c.ExecuteCommandWithPTY(ctx, command)
|
||||
} else {
|
||||
err = c.ExecuteCommandWithIO(ctx, command)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
return nil
|
||||
}
|
||||
|
||||
var exitErr *ssh.ExitError
|
||||
if errors.As(err, &exitErr) {
|
||||
os.Exit(exitErr.ExitStatus())
|
||||
}
|
||||
|
||||
var exitMissingErr *ssh.ExitMissingError
|
||||
if errors.As(err, &exitMissingErr) {
|
||||
log.Debugf("Remote command exited without exit status: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("execute command: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// openSSHTerminal opens an interactive SSH terminal.
|
||||
func openSSHTerminal(ctx context.Context, c *sshclient.Client) error {
|
||||
if err := c.OpenTerminal(ctx); err != nil {
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
return nil
|
||||
}
|
||||
|
||||
var exitMissingErr *ssh.ExitMissingError
|
||||
if errors.As(err, &exitMissingErr) {
|
||||
log.Debugf("Remote terminal exited without exit status: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("open terminal: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// startPortForwarding starts local and remote port forwarding based on command line flags
|
||||
func startPortForwarding(ctx context.Context, c *sshclient.Client, cmd *cobra.Command) error {
|
||||
for _, forward := range localForwards {
|
||||
if err := parseAndStartLocalForward(ctx, c, forward, cmd); err != nil {
|
||||
return fmt.Errorf("local port forward %s: %w", forward, err)
|
||||
}
|
||||
}
|
||||
|
||||
for _, forward := range remoteForwards {
|
||||
if err := parseAndStartRemoteForward(ctx, c, forward, cmd); err != nil {
|
||||
return fmt.Errorf("remote port forward %s: %w", forward, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseAndStartLocalForward parses and starts a local port forward (-L)
|
||||
func parseAndStartLocalForward(ctx context.Context, c *sshclient.Client, forward string, cmd *cobra.Command) error {
|
||||
localAddr, remoteAddr, err := parsePortForwardSpec(forward)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.Printf("Local port forwarding: %s -> %s\n", localAddr, remoteAddr)
|
||||
|
||||
go func() {
|
||||
if err := c.LocalPortForward(ctx, localAddr, remoteAddr); err != nil && !errors.Is(err, context.Canceled) {
|
||||
cmd.Printf("Local port forward error: %v\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseAndStartRemoteForward parses and starts a remote port forward (-R)
|
||||
func parseAndStartRemoteForward(ctx context.Context, c *sshclient.Client, forward string, cmd *cobra.Command) error {
|
||||
remoteAddr, localAddr, err := parsePortForwardSpec(forward)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.Printf("Remote port forwarding: %s -> %s\n", remoteAddr, localAddr)
|
||||
|
||||
go func() {
|
||||
if err := c.RemotePortForward(ctx, remoteAddr, localAddr); err != nil && !errors.Is(err, context.Canceled) {
|
||||
cmd.Printf("Remote port forward error: %v\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// parsePortForwardSpec parses port forward specifications like "8080:localhost:80" or "[::1]:8080:localhost:80".
|
||||
// Also supports Unix sockets like "8080:/tmp/socket" or "127.0.0.1:8080:/tmp/socket".
|
||||
func parsePortForwardSpec(spec string) (string, string, error) {
|
||||
// Support formats:
|
||||
// port:host:hostport -> localhost:port -> host:hostport
|
||||
// host:port:host:hostport -> host:port -> host:hostport
|
||||
// [host]:port:host:hostport -> [host]:port -> host:hostport
|
||||
// port:unix_socket_path -> localhost:port -> unix_socket_path
|
||||
// host:port:unix_socket_path -> host:port -> unix_socket_path
|
||||
|
||||
if strings.HasPrefix(spec, "[") && strings.Contains(spec, "]:") {
|
||||
return parseIPv6ForwardSpec(spec)
|
||||
}
|
||||
|
||||
parts := strings.Split(spec, ":")
|
||||
if len(parts) < 2 {
|
||||
return "", "", fmt.Errorf("invalid port forward specification: %s (expected format: [local_host:]local_port:remote_target)", spec)
|
||||
}
|
||||
|
||||
switch len(parts) {
|
||||
case 2:
|
||||
return parseTwoPartForwardSpec(parts, spec)
|
||||
case 3:
|
||||
return parseThreePartForwardSpec(parts)
|
||||
case 4:
|
||||
return parseFourPartForwardSpec(parts)
|
||||
default:
|
||||
return "", "", fmt.Errorf("invalid port forward specification: %s", spec)
|
||||
}
|
||||
}
|
||||
|
||||
// parseTwoPartForwardSpec handles "port:unix_socket" format.
|
||||
func parseTwoPartForwardSpec(parts []string, spec string) (string, string, error) {
|
||||
if isUnixSocket(parts[1]) {
|
||||
localAddr := "localhost:" + parts[0]
|
||||
remoteAddr := parts[1]
|
||||
return localAddr, remoteAddr, nil
|
||||
}
|
||||
return "", "", fmt.Errorf("invalid port forward specification: %s (expected format: [local_host:]local_port:remote_host:remote_port or [local_host:]local_port:unix_socket)", spec)
|
||||
}
|
||||
|
||||
// parseThreePartForwardSpec handles "port:host:hostport" or "host:port:unix_socket" formats.
|
||||
func parseThreePartForwardSpec(parts []string) (string, string, error) {
|
||||
if isUnixSocket(parts[2]) {
|
||||
localHost := normalizeLocalHost(parts[0])
|
||||
localAddr := localHost + ":" + parts[1]
|
||||
remoteAddr := parts[2]
|
||||
return localAddr, remoteAddr, nil
|
||||
}
|
||||
localAddr := "localhost:" + parts[0]
|
||||
remoteAddr := parts[1] + ":" + parts[2]
|
||||
return localAddr, remoteAddr, nil
|
||||
}
|
||||
|
||||
// parseFourPartForwardSpec handles "host:port:host:hostport" format.
|
||||
func parseFourPartForwardSpec(parts []string) (string, string, error) {
|
||||
localHost := normalizeLocalHost(parts[0])
|
||||
localAddr := localHost + ":" + parts[1]
|
||||
remoteAddr := parts[2] + ":" + parts[3]
|
||||
return localAddr, remoteAddr, nil
|
||||
}
|
||||
|
||||
// parseIPv6ForwardSpec handles "[host]:port:host:hostport" format.
|
||||
func parseIPv6ForwardSpec(spec string) (string, string, error) {
|
||||
idx := strings.Index(spec, "]:")
|
||||
if idx == -1 {
|
||||
return "", "", fmt.Errorf("invalid IPv6 port forward specification: %s", spec)
|
||||
}
|
||||
|
||||
ipv6Host := spec[:idx+1]
|
||||
remaining := spec[idx+2:]
|
||||
|
||||
parts := strings.Split(remaining, ":")
|
||||
if len(parts) != 3 {
|
||||
return "", "", fmt.Errorf("invalid IPv6 port forward specification: %s (expected [ipv6]:port:host:hostport)", spec)
|
||||
}
|
||||
|
||||
localAddr := ipv6Host + ":" + parts[0]
|
||||
remoteAddr := parts[1] + ":" + parts[2]
|
||||
return localAddr, remoteAddr, nil
|
||||
}
|
||||
|
||||
// isUnixSocket checks if a path is a Unix socket path.
|
||||
func isUnixSocket(path string) bool {
|
||||
return strings.HasPrefix(path, "/") || strings.HasPrefix(path, "./")
|
||||
}
|
||||
|
||||
// normalizeLocalHost converts "*" to "0.0.0.0" for binding to all interfaces.
|
||||
func normalizeLocalHost(host string) string {
|
||||
if host == "*" {
|
||||
return "0.0.0.0"
|
||||
}
|
||||
return host
|
||||
}
|
||||
|
||||
var sshProxyCmd = &cobra.Command{
|
||||
Use: "proxy <host> <port>",
|
||||
Short: "Internal SSH proxy for native SSH client integration",
|
||||
Long: "Internal command used by SSH ProxyCommand to handle JWT authentication",
|
||||
Hidden: true,
|
||||
Args: cobra.ExactArgs(2),
|
||||
RunE: sshProxyFn,
|
||||
}
|
||||
|
||||
func sshProxyFn(cmd *cobra.Command, args []string) error {
|
||||
logOutput := "console"
|
||||
if firstLogFile := util.FindFirstLogPath(logFiles); firstLogFile != "" && firstLogFile != defaultLogFile {
|
||||
logOutput = firstLogFile
|
||||
}
|
||||
if err := util.InitLog(logLevel, logOutput); err != nil {
|
||||
return fmt.Errorf("init log: %w", err)
|
||||
}
|
||||
|
||||
host := args[0]
|
||||
portStr := args[1]
|
||||
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid port: %s", portStr)
|
||||
}
|
||||
|
||||
proxy, err := sshproxy.New(daemonAddr, host, port, cmd.ErrOrStderr())
|
||||
if err != nil {
|
||||
return fmt.Errorf("create SSH proxy: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := proxy.Close(); err != nil {
|
||||
log.Debugf("close SSH proxy: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := proxy.Connect(cmd.Context()); err != nil {
|
||||
return fmt.Errorf("SSH proxy: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var sshDetectCmd = &cobra.Command{
|
||||
Use: "detect <host> <port>",
|
||||
Short: "Detect if a host is running NetBird SSH",
|
||||
Long: "Internal command used by SSH Match exec to detect NetBird SSH servers. Exit codes: 0=JWT, 1=no-JWT, 2=regular SSH",
|
||||
Hidden: true,
|
||||
Args: cobra.ExactArgs(2),
|
||||
RunE: sshDetectFn,
|
||||
}
|
||||
|
||||
func sshDetectFn(cmd *cobra.Command, args []string) error {
|
||||
if err := util.InitLog(logLevel, "console"); err != nil {
|
||||
os.Exit(detection.ServerTypeRegular.ExitCode())
|
||||
}
|
||||
|
||||
host := args[0]
|
||||
portStr := args[1]
|
||||
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
os.Exit(detection.ServerTypeRegular.ExitCode())
|
||||
}
|
||||
|
||||
dialer := &net.Dialer{Timeout: detection.Timeout}
|
||||
serverType, err := detection.DetectSSHServerType(cmd.Context(), dialer, host, port)
|
||||
if err != nil {
|
||||
os.Exit(detection.ServerTypeRegular.ExitCode())
|
||||
}
|
||||
|
||||
os.Exit(serverType.ExitCode())
|
||||
return nil
|
||||
}
|
||||
|
||||
74
client/cmd/ssh_exec_unix.go
Normal file
74
client/cmd/ssh_exec_unix.go
Normal file
@@ -0,0 +1,74 @@
|
||||
//go:build unix
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
sshserver "github.com/netbirdio/netbird/client/ssh/server"
|
||||
)
|
||||
|
||||
var (
|
||||
sshExecUID uint32
|
||||
sshExecGID uint32
|
||||
sshExecGroups []uint
|
||||
sshExecWorkingDir string
|
||||
sshExecShell string
|
||||
sshExecCommand string
|
||||
sshExecPTY bool
|
||||
)
|
||||
|
||||
// sshExecCmd represents the hidden ssh exec subcommand for privilege dropping
|
||||
var sshExecCmd = &cobra.Command{
|
||||
Use: "exec",
|
||||
Short: "Internal SSH execution with privilege dropping (hidden)",
|
||||
Hidden: true,
|
||||
RunE: runSSHExec,
|
||||
}
|
||||
|
||||
func init() {
|
||||
sshExecCmd.Flags().Uint32Var(&sshExecUID, "uid", 0, "Target user ID")
|
||||
sshExecCmd.Flags().Uint32Var(&sshExecGID, "gid", 0, "Target group ID")
|
||||
sshExecCmd.Flags().UintSliceVar(&sshExecGroups, "groups", nil, "Supplementary group IDs (can be repeated)")
|
||||
sshExecCmd.Flags().StringVar(&sshExecWorkingDir, "working-dir", "", "Working directory")
|
||||
sshExecCmd.Flags().StringVar(&sshExecShell, "shell", "/bin/sh", "Shell to execute")
|
||||
sshExecCmd.Flags().BoolVar(&sshExecPTY, "pty", false, "Request PTY (will fail as executor doesn't support PTY)")
|
||||
sshExecCmd.Flags().StringVar(&sshExecCommand, "cmd", "", "Command to execute")
|
||||
|
||||
if err := sshExecCmd.MarkFlagRequired("uid"); err != nil {
|
||||
_, _ = fmt.Fprintf(os.Stderr, "failed to mark uid flag as required: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
if err := sshExecCmd.MarkFlagRequired("gid"); err != nil {
|
||||
_, _ = fmt.Fprintf(os.Stderr, "failed to mark gid flag as required: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
sshCmd.AddCommand(sshExecCmd)
|
||||
}
|
||||
|
||||
// runSSHExec handles the SSH exec subcommand execution.
|
||||
func runSSHExec(cmd *cobra.Command, _ []string) error {
|
||||
privilegeDropper := sshserver.NewPrivilegeDropper()
|
||||
|
||||
var groups []uint32
|
||||
for _, groupInt := range sshExecGroups {
|
||||
groups = append(groups, uint32(groupInt))
|
||||
}
|
||||
|
||||
config := sshserver.ExecutorConfig{
|
||||
UID: sshExecUID,
|
||||
GID: sshExecGID,
|
||||
Groups: groups,
|
||||
WorkingDir: sshExecWorkingDir,
|
||||
Shell: sshExecShell,
|
||||
Command: sshExecCommand,
|
||||
PTY: sshExecPTY,
|
||||
}
|
||||
|
||||
privilegeDropper.ExecuteWithPrivilegeDrop(cmd.Context(), config)
|
||||
return nil
|
||||
}
|
||||
94
client/cmd/ssh_sftp_unix.go
Normal file
94
client/cmd/ssh_sftp_unix.go
Normal file
@@ -0,0 +1,94 @@
|
||||
//go:build unix
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"os"
|
||||
|
||||
"github.com/pkg/sftp"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
sshserver "github.com/netbirdio/netbird/client/ssh/server"
|
||||
)
|
||||
|
||||
var (
|
||||
sftpUID uint32
|
||||
sftpGID uint32
|
||||
sftpGroupsInt []uint
|
||||
sftpWorkingDir string
|
||||
)
|
||||
|
||||
var sshSftpCmd = &cobra.Command{
|
||||
Use: "sftp",
|
||||
Short: "SFTP server with privilege dropping (internal use)",
|
||||
Hidden: true,
|
||||
RunE: sftpMain,
|
||||
}
|
||||
|
||||
func init() {
|
||||
sshSftpCmd.Flags().Uint32Var(&sftpUID, "uid", 0, "Target user ID")
|
||||
sshSftpCmd.Flags().Uint32Var(&sftpGID, "gid", 0, "Target group ID")
|
||||
sshSftpCmd.Flags().UintSliceVar(&sftpGroupsInt, "groups", nil, "Supplementary group IDs (can be repeated)")
|
||||
sshSftpCmd.Flags().StringVar(&sftpWorkingDir, "working-dir", "", "Working directory")
|
||||
}
|
||||
|
||||
func sftpMain(cmd *cobra.Command, _ []string) error {
|
||||
privilegeDropper := sshserver.NewPrivilegeDropper()
|
||||
|
||||
var groups []uint32
|
||||
for _, groupInt := range sftpGroupsInt {
|
||||
groups = append(groups, uint32(groupInt))
|
||||
}
|
||||
|
||||
config := sshserver.ExecutorConfig{
|
||||
UID: sftpUID,
|
||||
GID: sftpGID,
|
||||
Groups: groups,
|
||||
WorkingDir: sftpWorkingDir,
|
||||
Shell: "",
|
||||
Command: "",
|
||||
}
|
||||
|
||||
log.Tracef("dropping privileges for SFTP to UID=%d, GID=%d, groups=%v", config.UID, config.GID, config.Groups)
|
||||
|
||||
if err := privilegeDropper.DropPrivileges(config.UID, config.GID, config.Groups); err != nil {
|
||||
cmd.PrintErrf("privilege drop failed: %v\n", err)
|
||||
os.Exit(sshserver.ExitCodePrivilegeDropFail)
|
||||
}
|
||||
|
||||
if config.WorkingDir != "" {
|
||||
if err := os.Chdir(config.WorkingDir); err != nil {
|
||||
cmd.PrintErrf("failed to change to working directory %s: %v\n", config.WorkingDir, err)
|
||||
}
|
||||
}
|
||||
|
||||
sftpServer, err := sftp.NewServer(struct {
|
||||
io.Reader
|
||||
io.WriteCloser
|
||||
}{
|
||||
Reader: os.Stdin,
|
||||
WriteCloser: os.Stdout,
|
||||
})
|
||||
if err != nil {
|
||||
cmd.PrintErrf("SFTP server creation failed: %v\n", err)
|
||||
os.Exit(sshserver.ExitCodeShellExecFail)
|
||||
}
|
||||
|
||||
log.Tracef("starting SFTP server with dropped privileges")
|
||||
if err := sftpServer.Serve(); err != nil && !errors.Is(err, io.EOF) {
|
||||
cmd.PrintErrf("SFTP server error: %v\n", err)
|
||||
if closeErr := sftpServer.Close(); closeErr != nil {
|
||||
cmd.PrintErrf("SFTP server close error: %v\n", closeErr)
|
||||
}
|
||||
os.Exit(sshserver.ExitCodeShellExecFail)
|
||||
}
|
||||
|
||||
if closeErr := sftpServer.Close(); closeErr != nil {
|
||||
cmd.PrintErrf("SFTP server close error: %v\n", closeErr)
|
||||
}
|
||||
os.Exit(sshserver.ExitCodeSuccess)
|
||||
return nil
|
||||
}
|
||||
94
client/cmd/ssh_sftp_windows.go
Normal file
94
client/cmd/ssh_sftp_windows.go
Normal file
@@ -0,0 +1,94 @@
|
||||
//go:build windows
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/user"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/sftp"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
sshserver "github.com/netbirdio/netbird/client/ssh/server"
|
||||
)
|
||||
|
||||
var (
|
||||
sftpWorkingDir string
|
||||
windowsUsername string
|
||||
windowsDomain string
|
||||
)
|
||||
|
||||
var sshSftpCmd = &cobra.Command{
|
||||
Use: "sftp",
|
||||
Short: "SFTP server with user switching for Windows (internal use)",
|
||||
Hidden: true,
|
||||
RunE: sftpMain,
|
||||
}
|
||||
|
||||
func init() {
|
||||
sshSftpCmd.Flags().StringVar(&sftpWorkingDir, "working-dir", "", "Working directory")
|
||||
sshSftpCmd.Flags().StringVar(&windowsUsername, "windows-username", "", "Windows username for user switching")
|
||||
sshSftpCmd.Flags().StringVar(&windowsDomain, "windows-domain", "", "Windows domain for user switching")
|
||||
}
|
||||
|
||||
func sftpMain(cmd *cobra.Command, _ []string) error {
|
||||
return sftpMainDirect(cmd)
|
||||
}
|
||||
|
||||
func sftpMainDirect(cmd *cobra.Command) error {
|
||||
currentUser, err := user.Current()
|
||||
if err != nil {
|
||||
cmd.PrintErrf("failed to get current user: %v\n", err)
|
||||
os.Exit(sshserver.ExitCodeValidationFail)
|
||||
}
|
||||
|
||||
if windowsUsername != "" {
|
||||
expectedUsername := windowsUsername
|
||||
if windowsDomain != "" {
|
||||
expectedUsername = fmt.Sprintf(`%s\%s`, windowsDomain, windowsUsername)
|
||||
}
|
||||
if !strings.EqualFold(currentUser.Username, expectedUsername) && !strings.EqualFold(currentUser.Username, windowsUsername) {
|
||||
cmd.PrintErrf("user switching failed\n")
|
||||
os.Exit(sshserver.ExitCodeValidationFail)
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("SFTP process running as: %s (UID: %s, Name: %s)", currentUser.Username, currentUser.Uid, currentUser.Name)
|
||||
|
||||
if sftpWorkingDir != "" {
|
||||
if err := os.Chdir(sftpWorkingDir); err != nil {
|
||||
cmd.PrintErrf("failed to change to working directory %s: %v\n", sftpWorkingDir, err)
|
||||
}
|
||||
}
|
||||
|
||||
sftpServer, err := sftp.NewServer(struct {
|
||||
io.Reader
|
||||
io.WriteCloser
|
||||
}{
|
||||
Reader: os.Stdin,
|
||||
WriteCloser: os.Stdout,
|
||||
})
|
||||
if err != nil {
|
||||
cmd.PrintErrf("SFTP server creation failed: %v\n", err)
|
||||
os.Exit(sshserver.ExitCodeShellExecFail)
|
||||
}
|
||||
|
||||
log.Debugf("starting SFTP server")
|
||||
exitCode := sshserver.ExitCodeSuccess
|
||||
if err := sftpServer.Serve(); err != nil && !errors.Is(err, io.EOF) {
|
||||
cmd.PrintErrf("SFTP server error: %v\n", err)
|
||||
exitCode = sshserver.ExitCodeShellExecFail
|
||||
}
|
||||
|
||||
if err := sftpServer.Close(); err != nil {
|
||||
log.Debugf("SFTP server close error: %v", err)
|
||||
}
|
||||
|
||||
os.Exit(exitCode)
|
||||
return nil
|
||||
}
|
||||
717
client/cmd/ssh_test.go
Normal file
717
client/cmd/ssh_test.go
Normal file
@@ -0,0 +1,717 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSSHCommand_FlagParsing(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
expectedHost string
|
||||
expectedUser string
|
||||
expectedPort int
|
||||
expectedCmd string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "basic host",
|
||||
args: []string{"hostname"},
|
||||
expectedHost: "hostname",
|
||||
expectedUser: "",
|
||||
expectedPort: 22,
|
||||
expectedCmd: "",
|
||||
},
|
||||
{
|
||||
name: "user@host format",
|
||||
args: []string{"user@hostname"},
|
||||
expectedHost: "hostname",
|
||||
expectedUser: "user",
|
||||
expectedPort: 22,
|
||||
expectedCmd: "",
|
||||
},
|
||||
{
|
||||
name: "host with command",
|
||||
args: []string{"hostname", "echo", "hello"},
|
||||
expectedHost: "hostname",
|
||||
expectedUser: "",
|
||||
expectedPort: 22,
|
||||
expectedCmd: "echo hello",
|
||||
},
|
||||
{
|
||||
name: "command with flags should be preserved",
|
||||
args: []string{"hostname", "ls", "-la", "/tmp"},
|
||||
expectedHost: "hostname",
|
||||
expectedUser: "",
|
||||
expectedPort: 22,
|
||||
expectedCmd: "ls -la /tmp",
|
||||
},
|
||||
{
|
||||
name: "double dash separator",
|
||||
args: []string{"hostname", "--", "ls", "-la"},
|
||||
expectedHost: "hostname",
|
||||
expectedUser: "",
|
||||
expectedPort: 22,
|
||||
expectedCmd: "-- ls -la",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Reset global variables
|
||||
host = ""
|
||||
username = ""
|
||||
port = 22
|
||||
command = ""
|
||||
|
||||
// Mock command for testing
|
||||
cmd := sshCmd
|
||||
cmd.SetArgs(tt.args)
|
||||
|
||||
err := validateSSHArgsWithoutFlagParsing(cmd, tt.args)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err, "SSH args validation should succeed for valid input")
|
||||
assert.Equal(t, tt.expectedHost, host, "host mismatch")
|
||||
if tt.expectedUser != "" {
|
||||
assert.Equal(t, tt.expectedUser, username, "username mismatch")
|
||||
}
|
||||
assert.Equal(t, tt.expectedPort, port, "port mismatch")
|
||||
assert.Equal(t, tt.expectedCmd, command, "command mismatch")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHCommand_FlagConflictPrevention(t *testing.T) {
|
||||
// Test that SSH flags don't conflict with command flags
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
expectedCmd string
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "ls with -la flags",
|
||||
args: []string{"hostname", "ls", "-la"},
|
||||
expectedCmd: "ls -la",
|
||||
description: "ls flags should be passed to remote command",
|
||||
},
|
||||
{
|
||||
name: "grep with -r flag",
|
||||
args: []string{"hostname", "grep", "-r", "pattern", "/path"},
|
||||
expectedCmd: "grep -r pattern /path",
|
||||
description: "grep flags should be passed to remote command",
|
||||
},
|
||||
{
|
||||
name: "ps with aux flags",
|
||||
args: []string{"hostname", "ps", "aux"},
|
||||
expectedCmd: "ps aux",
|
||||
description: "ps flags should be passed to remote command",
|
||||
},
|
||||
{
|
||||
name: "command with double dash",
|
||||
args: []string{"hostname", "--", "ls", "-la"},
|
||||
expectedCmd: "-- ls -la",
|
||||
description: "double dash should be preserved in command",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Reset global variables
|
||||
host = ""
|
||||
username = ""
|
||||
port = 22
|
||||
command = ""
|
||||
|
||||
cmd := sshCmd
|
||||
err := validateSSHArgsWithoutFlagParsing(cmd, tt.args)
|
||||
require.NoError(t, err, "SSH args validation should succeed for valid input")
|
||||
|
||||
assert.Equal(t, tt.expectedCmd, command, tt.description)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHCommand_NonInteractiveExecution(t *testing.T) {
|
||||
// Test that commands with arguments should execute the command and exit,
|
||||
// not drop to an interactive shell
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
expectedCmd string
|
||||
shouldExit bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "ls command should execute and exit",
|
||||
args: []string{"hostname", "ls"},
|
||||
expectedCmd: "ls",
|
||||
shouldExit: true,
|
||||
description: "ls command should execute and exit, not drop to shell",
|
||||
},
|
||||
{
|
||||
name: "ls with flags should execute and exit",
|
||||
args: []string{"hostname", "ls", "-la"},
|
||||
expectedCmd: "ls -la",
|
||||
shouldExit: true,
|
||||
description: "ls with flags should execute and exit, not drop to shell",
|
||||
},
|
||||
{
|
||||
name: "pwd command should execute and exit",
|
||||
args: []string{"hostname", "pwd"},
|
||||
expectedCmd: "pwd",
|
||||
shouldExit: true,
|
||||
description: "pwd command should execute and exit, not drop to shell",
|
||||
},
|
||||
{
|
||||
name: "echo command should execute and exit",
|
||||
args: []string{"hostname", "echo", "hello"},
|
||||
expectedCmd: "echo hello",
|
||||
shouldExit: true,
|
||||
description: "echo command should execute and exit, not drop to shell",
|
||||
},
|
||||
{
|
||||
name: "no command should open shell",
|
||||
args: []string{"hostname"},
|
||||
expectedCmd: "",
|
||||
shouldExit: false,
|
||||
description: "no command should open interactive shell",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Reset global variables
|
||||
host = ""
|
||||
username = ""
|
||||
port = 22
|
||||
command = ""
|
||||
|
||||
cmd := sshCmd
|
||||
err := validateSSHArgsWithoutFlagParsing(cmd, tt.args)
|
||||
require.NoError(t, err, "SSH args validation should succeed for valid input")
|
||||
|
||||
assert.Equal(t, tt.expectedCmd, command, tt.description)
|
||||
|
||||
// When command is present, it should execute the command and exit
|
||||
// When command is empty, it should open interactive shell
|
||||
hasCommand := command != ""
|
||||
assert.Equal(t, tt.shouldExit, hasCommand, "Command presence should match expected behavior")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHCommand_FlagHandling(t *testing.T) {
|
||||
// Test that flags after hostname are not parsed by netbird but passed to SSH command
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
expectedHost string
|
||||
expectedCmd string
|
||||
expectError bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "ls with -la flag should not be parsed by netbird",
|
||||
args: []string{"debian2", "ls", "-la"},
|
||||
expectedHost: "debian2",
|
||||
expectedCmd: "ls -la",
|
||||
expectError: false,
|
||||
description: "ls -la should be passed as SSH command, not parsed as netbird flags",
|
||||
},
|
||||
{
|
||||
name: "command with netbird-like flags should be passed through",
|
||||
args: []string{"hostname", "echo", "--help"},
|
||||
expectedHost: "hostname",
|
||||
expectedCmd: "echo --help",
|
||||
expectError: false,
|
||||
description: "--help should be passed to echo, not parsed by netbird",
|
||||
},
|
||||
{
|
||||
name: "command with -p flag should not conflict with SSH port flag",
|
||||
args: []string{"hostname", "ps", "-p", "1234"},
|
||||
expectedHost: "hostname",
|
||||
expectedCmd: "ps -p 1234",
|
||||
expectError: false,
|
||||
description: "ps -p should be passed to ps command, not parsed as port",
|
||||
},
|
||||
{
|
||||
name: "tar with flags should be passed through",
|
||||
args: []string{"hostname", "tar", "-czf", "backup.tar.gz", "/home"},
|
||||
expectedHost: "hostname",
|
||||
expectedCmd: "tar -czf backup.tar.gz /home",
|
||||
expectError: false,
|
||||
description: "tar flags should be passed to tar command",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Reset global variables
|
||||
host = ""
|
||||
username = ""
|
||||
port = 22
|
||||
command = ""
|
||||
|
||||
cmd := sshCmd
|
||||
err := validateSSHArgsWithoutFlagParsing(cmd, tt.args)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err, "SSH args validation should succeed for valid input")
|
||||
assert.Equal(t, tt.expectedHost, host, "host mismatch")
|
||||
assert.Equal(t, tt.expectedCmd, command, tt.description)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHCommand_RegressionFlagParsing(t *testing.T) {
|
||||
// Regression test for the specific issue: "sudo ./netbird ssh debian2 ls -la"
|
||||
// should not parse -la as netbird flags but pass them to the SSH command
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
expectedHost string
|
||||
expectedCmd string
|
||||
expectError bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "original issue: ls -la should be preserved",
|
||||
args: []string{"debian2", "ls", "-la"},
|
||||
expectedHost: "debian2",
|
||||
expectedCmd: "ls -la",
|
||||
expectError: false,
|
||||
description: "The original failing case should now work",
|
||||
},
|
||||
{
|
||||
name: "ls -l should be preserved",
|
||||
args: []string{"hostname", "ls", "-l"},
|
||||
expectedHost: "hostname",
|
||||
expectedCmd: "ls -l",
|
||||
expectError: false,
|
||||
description: "Single letter flags should be preserved",
|
||||
},
|
||||
{
|
||||
name: "SSH port flag should work",
|
||||
args: []string{"-p", "2222", "hostname", "ls", "-la"},
|
||||
expectedHost: "hostname",
|
||||
expectedCmd: "ls -la",
|
||||
expectError: false,
|
||||
description: "SSH -p flag should be parsed, command flags preserved",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Reset global variables
|
||||
host = ""
|
||||
username = ""
|
||||
port = 22
|
||||
command = ""
|
||||
|
||||
cmd := sshCmd
|
||||
err := validateSSHArgsWithoutFlagParsing(cmd, tt.args)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err, "SSH args validation should succeed for valid input")
|
||||
assert.Equal(t, tt.expectedHost, host, "host mismatch")
|
||||
assert.Equal(t, tt.expectedCmd, command, tt.description)
|
||||
|
||||
// Check port for the test case with -p flag
|
||||
if len(tt.args) > 0 && tt.args[0] == "-p" {
|
||||
assert.Equal(t, 2222, port, "port should be parsed from -p flag")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHCommand_PortForwardingFlagParsing(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
expectedHost string
|
||||
expectedLocal []string
|
||||
expectedRemote []string
|
||||
expectError bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "local port forwarding -L",
|
||||
args: []string{"-L", "8080:localhost:80", "hostname"},
|
||||
expectedHost: "hostname",
|
||||
expectedLocal: []string{"8080:localhost:80"},
|
||||
expectedRemote: []string{},
|
||||
expectError: false,
|
||||
description: "Single -L flag should be parsed correctly",
|
||||
},
|
||||
{
|
||||
name: "remote port forwarding -R",
|
||||
args: []string{"-R", "8080:localhost:80", "hostname"},
|
||||
expectedHost: "hostname",
|
||||
expectedLocal: []string{},
|
||||
expectedRemote: []string{"8080:localhost:80"},
|
||||
expectError: false,
|
||||
description: "Single -R flag should be parsed correctly",
|
||||
},
|
||||
{
|
||||
name: "multiple local port forwards",
|
||||
args: []string{"-L", "8080:localhost:80", "-L", "9090:localhost:443", "hostname"},
|
||||
expectedHost: "hostname",
|
||||
expectedLocal: []string{"8080:localhost:80", "9090:localhost:443"},
|
||||
expectedRemote: []string{},
|
||||
expectError: false,
|
||||
description: "Multiple -L flags should be parsed correctly",
|
||||
},
|
||||
{
|
||||
name: "multiple remote port forwards",
|
||||
args: []string{"-R", "8080:localhost:80", "-R", "9090:localhost:443", "hostname"},
|
||||
expectedHost: "hostname",
|
||||
expectedLocal: []string{},
|
||||
expectedRemote: []string{"8080:localhost:80", "9090:localhost:443"},
|
||||
expectError: false,
|
||||
description: "Multiple -R flags should be parsed correctly",
|
||||
},
|
||||
{
|
||||
name: "mixed local and remote forwards",
|
||||
args: []string{"-L", "8080:localhost:80", "-R", "9090:localhost:443", "hostname"},
|
||||
expectedHost: "hostname",
|
||||
expectedLocal: []string{"8080:localhost:80"},
|
||||
expectedRemote: []string{"9090:localhost:443"},
|
||||
expectError: false,
|
||||
description: "Mixed -L and -R flags should be parsed correctly",
|
||||
},
|
||||
{
|
||||
name: "port forwarding with bind address",
|
||||
args: []string{"-L", "127.0.0.1:8080:localhost:80", "hostname"},
|
||||
expectedHost: "hostname",
|
||||
expectedLocal: []string{"127.0.0.1:8080:localhost:80"},
|
||||
expectedRemote: []string{},
|
||||
expectError: false,
|
||||
description: "Port forwarding with bind address should work",
|
||||
},
|
||||
{
|
||||
name: "port forwarding with command",
|
||||
args: []string{"-L", "8080:localhost:80", "hostname", "ls", "-la"},
|
||||
expectedHost: "hostname",
|
||||
expectedLocal: []string{"8080:localhost:80"},
|
||||
expectedRemote: []string{},
|
||||
expectError: false,
|
||||
description: "Port forwarding with command should work",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Reset global variables
|
||||
host = ""
|
||||
username = ""
|
||||
port = 22
|
||||
command = ""
|
||||
localForwards = nil
|
||||
remoteForwards = nil
|
||||
|
||||
cmd := sshCmd
|
||||
err := validateSSHArgsWithoutFlagParsing(cmd, tt.args)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err, "SSH args validation should succeed for valid input")
|
||||
assert.Equal(t, tt.expectedHost, host, "host mismatch")
|
||||
// Handle nil vs empty slice comparison
|
||||
if len(tt.expectedLocal) == 0 {
|
||||
assert.True(t, len(localForwards) == 0, tt.description+" - local forwards should be empty")
|
||||
} else {
|
||||
assert.Equal(t, tt.expectedLocal, localForwards, tt.description+" - local forwards")
|
||||
}
|
||||
if len(tt.expectedRemote) == 0 {
|
||||
assert.True(t, len(remoteForwards) == 0, tt.description+" - remote forwards should be empty")
|
||||
} else {
|
||||
assert.Equal(t, tt.expectedRemote, remoteForwards, tt.description+" - remote forwards")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParsePortForward(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
spec string
|
||||
expectedLocal string
|
||||
expectedRemote string
|
||||
expectError bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "simple port forward",
|
||||
spec: "8080:localhost:80",
|
||||
expectedLocal: "localhost:8080",
|
||||
expectedRemote: "localhost:80",
|
||||
expectError: false,
|
||||
description: "Simple port:host:port format should work",
|
||||
},
|
||||
{
|
||||
name: "port forward with bind address",
|
||||
spec: "127.0.0.1:8080:localhost:80",
|
||||
expectedLocal: "127.0.0.1:8080",
|
||||
expectedRemote: "localhost:80",
|
||||
expectError: false,
|
||||
description: "bind_address:port:host:port format should work",
|
||||
},
|
||||
{
|
||||
name: "port forward to different host",
|
||||
spec: "8080:example.com:443",
|
||||
expectedLocal: "localhost:8080",
|
||||
expectedRemote: "example.com:443",
|
||||
expectError: false,
|
||||
description: "Forwarding to different host should work",
|
||||
},
|
||||
{
|
||||
name: "port forward with IPv6 (needs bracket support)",
|
||||
spec: "::1:8080:localhost:80",
|
||||
expectError: true,
|
||||
description: "IPv6 without brackets fails as expected (feature to implement)",
|
||||
},
|
||||
{
|
||||
name: "invalid format - too few parts",
|
||||
spec: "8080:localhost",
|
||||
expectError: true,
|
||||
description: "Invalid format with too few parts should fail",
|
||||
},
|
||||
{
|
||||
name: "invalid format - too many parts",
|
||||
spec: "127.0.0.1:8080:localhost:80:extra",
|
||||
expectError: true,
|
||||
description: "Invalid format with too many parts should fail",
|
||||
},
|
||||
{
|
||||
name: "empty spec",
|
||||
spec: "",
|
||||
expectError: true,
|
||||
description: "Empty spec should fail",
|
||||
},
|
||||
{
|
||||
name: "unix socket local forward",
|
||||
spec: "8080:/tmp/socket",
|
||||
expectedLocal: "localhost:8080",
|
||||
expectedRemote: "/tmp/socket",
|
||||
expectError: false,
|
||||
description: "Unix socket forwarding should work",
|
||||
},
|
||||
{
|
||||
name: "unix socket with bind address",
|
||||
spec: "127.0.0.1:8080:/tmp/socket",
|
||||
expectedLocal: "127.0.0.1:8080",
|
||||
expectedRemote: "/tmp/socket",
|
||||
expectError: false,
|
||||
description: "Unix socket with bind address should work",
|
||||
},
|
||||
{
|
||||
name: "wildcard bind all interfaces",
|
||||
spec: "*:8080:localhost:80",
|
||||
expectedLocal: "0.0.0.0:8080",
|
||||
expectedRemote: "localhost:80",
|
||||
expectError: false,
|
||||
description: "Wildcard * should bind to all interfaces (0.0.0.0)",
|
||||
},
|
||||
{
|
||||
name: "wildcard for port only",
|
||||
spec: "8080:*:80",
|
||||
expectedLocal: "localhost:8080",
|
||||
expectedRemote: "*:80",
|
||||
expectError: false,
|
||||
description: "Wildcard in remote host should be preserved",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
localAddr, remoteAddr, err := parsePortForwardSpec(tt.spec)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err, tt.description)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err, tt.description)
|
||||
assert.Equal(t, tt.expectedLocal, localAddr, tt.description+" - local address")
|
||||
assert.Equal(t, tt.expectedRemote, remoteAddr, tt.description+" - remote address")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHCommand_IntegrationPortForwarding(t *testing.T) {
|
||||
// Integration test for port forwarding with the actual SSH command implementation
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
expectedHost string
|
||||
expectedLocal []string
|
||||
expectedRemote []string
|
||||
expectedCmd string
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "local forward with command",
|
||||
args: []string{"-L", "8080:localhost:80", "hostname", "echo", "test"},
|
||||
expectedHost: "hostname",
|
||||
expectedLocal: []string{"8080:localhost:80"},
|
||||
expectedRemote: []string{},
|
||||
expectedCmd: "echo test",
|
||||
description: "Local forwarding should work with commands",
|
||||
},
|
||||
{
|
||||
name: "remote forward with command",
|
||||
args: []string{"-R", "8080:localhost:80", "hostname", "ls", "-la"},
|
||||
expectedHost: "hostname",
|
||||
expectedLocal: []string{},
|
||||
expectedRemote: []string{"8080:localhost:80"},
|
||||
expectedCmd: "ls -la",
|
||||
description: "Remote forwarding should work with commands",
|
||||
},
|
||||
{
|
||||
name: "multiple forwards with user and command",
|
||||
args: []string{"-L", "8080:localhost:80", "-R", "9090:localhost:443", "user@hostname", "ps", "aux"},
|
||||
expectedHost: "hostname",
|
||||
expectedLocal: []string{"8080:localhost:80"},
|
||||
expectedRemote: []string{"9090:localhost:443"},
|
||||
expectedCmd: "ps aux",
|
||||
description: "Complex case with multiple forwards, user, and command",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Reset global variables
|
||||
host = ""
|
||||
username = ""
|
||||
port = 22
|
||||
command = ""
|
||||
localForwards = nil
|
||||
remoteForwards = nil
|
||||
|
||||
cmd := sshCmd
|
||||
err := validateSSHArgsWithoutFlagParsing(cmd, tt.args)
|
||||
require.NoError(t, err, "SSH args validation should succeed for valid input")
|
||||
|
||||
assert.Equal(t, tt.expectedHost, host, "host mismatch")
|
||||
// Handle nil vs empty slice comparison
|
||||
if len(tt.expectedLocal) == 0 {
|
||||
assert.True(t, len(localForwards) == 0, tt.description+" - local forwards should be empty")
|
||||
} else {
|
||||
assert.Equal(t, tt.expectedLocal, localForwards, tt.description+" - local forwards")
|
||||
}
|
||||
if len(tt.expectedRemote) == 0 {
|
||||
assert.True(t, len(remoteForwards) == 0, tt.description+" - remote forwards should be empty")
|
||||
} else {
|
||||
assert.Equal(t, tt.expectedRemote, remoteForwards, tt.description+" - remote forwards")
|
||||
}
|
||||
assert.Equal(t, tt.expectedCmd, command, tt.description+" - command")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHCommand_ParameterIsolation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
expectedCmd string
|
||||
}{
|
||||
{
|
||||
name: "cmd flag passed as command",
|
||||
args: []string{"hostname", "--cmd", "echo test"},
|
||||
expectedCmd: "--cmd echo test",
|
||||
},
|
||||
{
|
||||
name: "uid flag passed as command",
|
||||
args: []string{"hostname", "--uid", "1000"},
|
||||
expectedCmd: "--uid 1000",
|
||||
},
|
||||
{
|
||||
name: "shell flag passed as command",
|
||||
args: []string{"hostname", "--shell", "/bin/bash"},
|
||||
expectedCmd: "--shell /bin/bash",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
host = ""
|
||||
username = ""
|
||||
port = 22
|
||||
command = ""
|
||||
|
||||
err := validateSSHArgsWithoutFlagParsing(sshCmd, tt.args)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "hostname", host)
|
||||
assert.Equal(t, tt.expectedCmd, command)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHCommand_InvalidFlagRejection(t *testing.T) {
|
||||
// Test that invalid flags are properly rejected and not misinterpreted as hostnames
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "invalid long flag before hostname",
|
||||
args: []string{"--invalid-flag", "hostname"},
|
||||
description: "Invalid flag should return parse error, not treat flag as hostname",
|
||||
},
|
||||
{
|
||||
name: "invalid short flag before hostname",
|
||||
args: []string{"-x", "hostname"},
|
||||
description: "Invalid short flag should return parse error",
|
||||
},
|
||||
{
|
||||
name: "invalid flag with value before hostname",
|
||||
args: []string{"--invalid-option=value", "hostname"},
|
||||
description: "Invalid flag with value should return parse error",
|
||||
},
|
||||
{
|
||||
name: "typo in known flag",
|
||||
args: []string{"--por", "2222", "hostname"},
|
||||
description: "Typo in flag name should return parse error (not silently ignored)",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Reset global variables
|
||||
host = ""
|
||||
username = ""
|
||||
port = 22
|
||||
command = ""
|
||||
|
||||
err := validateSSHArgsWithoutFlagParsing(sshCmd, tt.args)
|
||||
|
||||
// Should return an error for invalid flags
|
||||
assert.Error(t, err, tt.description)
|
||||
|
||||
// Should not have set host to the invalid flag
|
||||
assert.NotEqual(t, tt.args[0], host, "Invalid flag should not be interpreted as hostname")
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -68,7 +68,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
||||
|
||||
ctx := internal.CtxInitState(cmd.Context())
|
||||
|
||||
resp, err := getStatus(ctx)
|
||||
resp, err := getStatus(ctx, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -99,7 +99,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
||||
profName = activeProf.Name
|
||||
}
|
||||
|
||||
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp, anonymizeFlag, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap, connectionTypeFilter, profName)
|
||||
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp.GetFullStatus(), anonymizeFlag, resp.GetDaemonVersion(), statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap, connectionTypeFilter, profName)
|
||||
var statusOutputString string
|
||||
switch {
|
||||
case detailFlag:
|
||||
@@ -109,7 +109,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
||||
case yamlFlag:
|
||||
statusOutputString, err = nbstatus.ParseToYAML(outputInformationHolder)
|
||||
default:
|
||||
statusOutputString = nbstatus.ParseGeneralSummary(outputInformationHolder, false, false, false)
|
||||
statusOutputString = nbstatus.ParseGeneralSummary(outputInformationHolder, false, false, false, false)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
@@ -121,7 +121,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func getStatus(ctx context.Context) (*proto.StatusResponse, error) {
|
||||
func getStatus(ctx context.Context, shouldRunProbes bool) (*proto.StatusResponse, error) {
|
||||
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to daemon error: %v\n"+
|
||||
@@ -130,7 +130,7 @@ func getStatus(ctx context.Context) (*proto.StatusResponse, error) {
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true, ShouldRunProbes: true})
|
||||
resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true, ShouldRunProbes: shouldRunProbes})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("status failed: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
@@ -13,6 +13,11 @@ import (
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
"github.com/netbirdio/netbird/management/server/job"
|
||||
|
||||
clientProto "github.com/netbirdio/netbird/client/proto"
|
||||
client "github.com/netbirdio/netbird/client/server"
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
@@ -84,7 +89,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
|
||||
}
|
||||
t.Cleanup(cleanUp)
|
||||
|
||||
peersUpdateManager := mgmt.NewPeersUpdateManager(nil)
|
||||
jobManager := job.NewJobManager(nil, store)
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
if err != nil {
|
||||
return nil, nil
|
||||
@@ -110,13 +115,18 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
|
||||
Return(&types.Settings{}, nil).
|
||||
AnyTimes()
|
||||
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
||||
ctx := context.Background()
|
||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store)
|
||||
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock())
|
||||
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &mgmt.MockIntegratedValidator{})
|
||||
secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, updateManager, jobManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &mgmt.MockIntegratedValidator{}, networkMapController)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -185,7 +185,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr
|
||||
|
||||
_, _ = profilemanager.UpdateOldManagementURL(ctx, config, configFilePath)
|
||||
|
||||
err = foregroundLogin(ctx, cmd, config, providedSetupKey)
|
||||
err = foregroundLogin(ctx, cmd, config, providedSetupKey, activeProf.Name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("foreground login failed: %v", err)
|
||||
}
|
||||
@@ -200,7 +200,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr
|
||||
connectClient := internal.NewConnectClient(ctx, config, r)
|
||||
SetupDebugHandler(ctx, config, r, connectClient, "")
|
||||
|
||||
return connectClient.Run(nil)
|
||||
return connectClient.Run(nil, util.FindFirstLogPath(logFiles))
|
||||
}
|
||||
|
||||
func runInDaemonMode(ctx context.Context, cmd *cobra.Command, pm *profilemanager.ProfileManager, activeProf *profilemanager.Profile, profileSwitched bool) error {
|
||||
@@ -286,6 +286,13 @@ func doDaemonUp(ctx context.Context, cmd *cobra.Command, client proto.DaemonServ
|
||||
loginRequest.ProfileName = &activeProf.Name
|
||||
loginRequest.Username = &username
|
||||
|
||||
profileState, err := pm.GetProfileState(activeProf.Name)
|
||||
if err != nil {
|
||||
log.Debugf("failed to get profile state for login hint: %v", err)
|
||||
} else if profileState.Email != "" {
|
||||
loginRequest.Hint = &profileState.Email
|
||||
}
|
||||
|
||||
var loginErr error
|
||||
var loginResp *proto.LoginResponse
|
||||
|
||||
@@ -348,6 +355,25 @@ func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, pro
|
||||
if cmd.Flag(serverSSHAllowedFlag).Changed {
|
||||
req.ServerSSHAllowed = &serverSSHAllowed
|
||||
}
|
||||
if cmd.Flag(enableSSHRootFlag).Changed {
|
||||
req.EnableSSHRoot = &enableSSHRoot
|
||||
}
|
||||
if cmd.Flag(enableSSHSFTPFlag).Changed {
|
||||
req.EnableSSHSFTP = &enableSSHSFTP
|
||||
}
|
||||
if cmd.Flag(enableSSHLocalPortForwardFlag).Changed {
|
||||
req.EnableSSHLocalPortForwarding = &enableSSHLocalPortForward
|
||||
}
|
||||
if cmd.Flag(enableSSHRemotePortForwardFlag).Changed {
|
||||
req.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward
|
||||
}
|
||||
if cmd.Flag(disableSSHAuthFlag).Changed {
|
||||
req.DisableSSHAuth = &disableSSHAuth
|
||||
}
|
||||
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
|
||||
sshJWTCacheTTL32 := int32(sshJWTCacheTTL)
|
||||
req.SshJWTCacheTTL = &sshJWTCacheTTL32
|
||||
}
|
||||
if cmd.Flag(interfaceNameFlag).Changed {
|
||||
if err := parseInterfaceName(interfaceName); err != nil {
|
||||
log.Errorf("parse interface name: %v", err)
|
||||
@@ -432,6 +458,30 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil
|
||||
ic.ServerSSHAllowed = &serverSSHAllowed
|
||||
}
|
||||
|
||||
if cmd.Flag(enableSSHRootFlag).Changed {
|
||||
ic.EnableSSHRoot = &enableSSHRoot
|
||||
}
|
||||
|
||||
if cmd.Flag(enableSSHSFTPFlag).Changed {
|
||||
ic.EnableSSHSFTP = &enableSSHSFTP
|
||||
}
|
||||
|
||||
if cmd.Flag(enableSSHLocalPortForwardFlag).Changed {
|
||||
ic.EnableSSHLocalPortForwarding = &enableSSHLocalPortForward
|
||||
}
|
||||
|
||||
if cmd.Flag(enableSSHRemotePortForwardFlag).Changed {
|
||||
ic.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward
|
||||
}
|
||||
|
||||
if cmd.Flag(disableSSHAuthFlag).Changed {
|
||||
ic.DisableSSHAuth = &disableSSHAuth
|
||||
}
|
||||
|
||||
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
|
||||
ic.SSHJWTCacheTTL = &sshJWTCacheTTL
|
||||
}
|
||||
|
||||
if cmd.Flag(interfaceNameFlag).Changed {
|
||||
if err := parseInterfaceName(interfaceName); err != nil {
|
||||
return nil, err
|
||||
@@ -532,6 +582,31 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
|
||||
loginRequest.ServerSSHAllowed = &serverSSHAllowed
|
||||
}
|
||||
|
||||
if cmd.Flag(enableSSHRootFlag).Changed {
|
||||
loginRequest.EnableSSHRoot = &enableSSHRoot
|
||||
}
|
||||
|
||||
if cmd.Flag(enableSSHSFTPFlag).Changed {
|
||||
loginRequest.EnableSSHSFTP = &enableSSHSFTP
|
||||
}
|
||||
|
||||
if cmd.Flag(enableSSHLocalPortForwardFlag).Changed {
|
||||
loginRequest.EnableSSHLocalPortForwarding = &enableSSHLocalPortForward
|
||||
}
|
||||
|
||||
if cmd.Flag(enableSSHRemotePortForwardFlag).Changed {
|
||||
loginRequest.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward
|
||||
}
|
||||
|
||||
if cmd.Flag(disableSSHAuthFlag).Changed {
|
||||
loginRequest.DisableSSHAuth = &disableSSHAuth
|
||||
}
|
||||
|
||||
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
|
||||
sshJWTCacheTTL32 := int32(sshJWTCacheTTL)
|
||||
loginRequest.SshJWTCacheTTL = &sshJWTCacheTTL32
|
||||
}
|
||||
|
||||
if cmd.Flag(disableAutoConnectFlag).Changed {
|
||||
loginRequest.DisableAutoConnect = &autoConnectDisabled
|
||||
}
|
||||
|
||||
@@ -18,12 +18,16 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
sshcommon "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
)
|
||||
|
||||
var ErrClientAlreadyStarted = errors.New("client already started")
|
||||
var ErrClientNotStarted = errors.New("client not started")
|
||||
var ErrConfigNotInitialized = errors.New("config not initialized")
|
||||
var (
|
||||
ErrClientAlreadyStarted = errors.New("client already started")
|
||||
ErrClientNotStarted = errors.New("client not started")
|
||||
ErrEngineNotStarted = errors.New("engine not started")
|
||||
ErrConfigNotInitialized = errors.New("config not initialized")
|
||||
)
|
||||
|
||||
// Client manages a netbird embedded client instance.
|
||||
type Client struct {
|
||||
@@ -169,6 +173,7 @@ func (c *Client) Start(startCtx context.Context) error {
|
||||
}
|
||||
|
||||
recorder := peer.NewRecorder(c.config.ManagementURL.String())
|
||||
|
||||
client := internal.NewConnectClient(ctx, c.config, recorder)
|
||||
|
||||
// either startup error (permanent backoff err) or nil err (successful engine up)
|
||||
@@ -176,7 +181,7 @@ func (c *Client) Start(startCtx context.Context) error {
|
||||
run := make(chan struct{})
|
||||
clientErr := make(chan error, 1)
|
||||
go func() {
|
||||
if err := client.Run(run); err != nil {
|
||||
if err := client.Run(run, ""); err != nil {
|
||||
clientErr <- err
|
||||
}
|
||||
}()
|
||||
@@ -238,17 +243,9 @@ func (c *Client) GetConfig() (profilemanager.Config, error) {
|
||||
// Dial dials a network address in the netbird network.
|
||||
// Not applicable if the userspace networking mode is disabled.
|
||||
func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
c.mu.Lock()
|
||||
connect := c.connect
|
||||
if connect == nil {
|
||||
c.mu.Unlock()
|
||||
return nil, ErrClientNotStarted
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
engine := connect.Engine()
|
||||
if engine == nil {
|
||||
return nil, errors.New("engine not started")
|
||||
engine, err := c.getEngine()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nsnet, err := engine.GetNet()
|
||||
@@ -259,6 +256,11 @@ func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, e
|
||||
return nsnet.DialContext(ctx, network, address)
|
||||
}
|
||||
|
||||
// DialContext dials a network address in the netbird network with context
|
||||
func (c *Client) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return c.Dial(ctx, network, address)
|
||||
}
|
||||
|
||||
// ListenTCP listens on the given address in the netbird network.
|
||||
// Not applicable if the userspace networking mode is disabled.
|
||||
func (c *Client) ListenTCP(address string) (net.Listener, error) {
|
||||
@@ -314,18 +316,47 @@ func (c *Client) NewHTTPClient() *http.Client {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) getNet() (*wgnetstack.Net, netip.Addr, error) {
|
||||
// VerifySSHHostKey verifies an SSH host key against stored peer keys.
|
||||
// Returns nil if the key matches, ErrPeerNotFound if peer is not in network,
|
||||
// ErrNoStoredKey if peer has no stored key, or an error for verification failures.
|
||||
func (c *Client) VerifySSHHostKey(peerAddress string, key []byte) error {
|
||||
engine, err := c.getEngine()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
storedKey, found := engine.GetPeerSSHKey(peerAddress)
|
||||
if !found {
|
||||
return sshcommon.ErrPeerNotFound
|
||||
}
|
||||
|
||||
return sshcommon.VerifyHostKey(storedKey, key, peerAddress)
|
||||
}
|
||||
|
||||
// getEngine safely retrieves the engine from the client with proper locking.
|
||||
// Returns ErrClientNotStarted if the client is not started.
|
||||
// Returns ErrEngineNotStarted if the engine is not available.
|
||||
func (c *Client) getEngine() (*internal.Engine, error) {
|
||||
c.mu.Lock()
|
||||
connect := c.connect
|
||||
if connect == nil {
|
||||
c.mu.Unlock()
|
||||
return nil, netip.Addr{}, errors.New("client not started")
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
if connect == nil {
|
||||
return nil, ErrClientNotStarted
|
||||
}
|
||||
|
||||
engine := connect.Engine()
|
||||
if engine == nil {
|
||||
return nil, netip.Addr{}, errors.New("engine not started")
|
||||
return nil, ErrEngineNotStarted
|
||||
}
|
||||
|
||||
return engine, nil
|
||||
}
|
||||
|
||||
func (c *Client) getNet() (*wgnetstack.Net, netip.Addr, error) {
|
||||
engine, err := c.getEngine()
|
||||
if err != nil {
|
||||
return nil, netip.Addr{}, err
|
||||
}
|
||||
|
||||
addr, err := engine.Address()
|
||||
|
||||
@@ -15,13 +15,13 @@ import (
|
||||
)
|
||||
|
||||
// NewFirewall creates a firewall manager instance
|
||||
func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool) (firewall.Manager, error) {
|
||||
func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool, mtu uint16) (firewall.Manager, error) {
|
||||
if !iface.IsUserspaceBind() {
|
||||
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
|
||||
}
|
||||
|
||||
// use userspace packet filtering firewall
|
||||
fm, err := uspfilter.Create(iface, disableServerRoutes, flowLogger)
|
||||
fm, err := uspfilter.Create(iface, disableServerRoutes, flowLogger, mtu)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -34,12 +34,12 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
|
||||
// FWType is the type for the firewall type
|
||||
type FWType int
|
||||
|
||||
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool) (firewall.Manager, error) {
|
||||
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool, mtu uint16) (firewall.Manager, error) {
|
||||
// on the linux system we try to user nftables or iptables
|
||||
// in any case, because we need to allow netbird interface traffic
|
||||
// so we use AllowNetbird traffic from these firewall managers
|
||||
// for the userspace packet filtering firewall
|
||||
fm, err := createNativeFirewall(iface, stateManager, disableServerRoutes)
|
||||
fm, err := createNativeFirewall(iface, stateManager, disableServerRoutes, mtu)
|
||||
|
||||
if !iface.IsUserspaceBind() {
|
||||
return fm, err
|
||||
@@ -48,11 +48,11 @@ func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogg
|
||||
if err != nil {
|
||||
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
|
||||
}
|
||||
return createUserspaceFirewall(iface, fm, disableServerRoutes, flowLogger)
|
||||
return createUserspaceFirewall(iface, fm, disableServerRoutes, flowLogger, mtu)
|
||||
}
|
||||
|
||||
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool) (firewall.Manager, error) {
|
||||
fm, err := createFW(iface)
|
||||
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool, mtu uint16) (firewall.Manager, error) {
|
||||
fm, err := createFW(iface, mtu)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create firewall: %s", err)
|
||||
}
|
||||
@@ -64,26 +64,26 @@ func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager,
|
||||
return fm, nil
|
||||
}
|
||||
|
||||
func createFW(iface IFaceMapper) (firewall.Manager, error) {
|
||||
func createFW(iface IFaceMapper, mtu uint16) (firewall.Manager, error) {
|
||||
switch check() {
|
||||
case IPTABLES:
|
||||
log.Info("creating an iptables firewall manager")
|
||||
return nbiptables.Create(iface)
|
||||
return nbiptables.Create(iface, mtu)
|
||||
case NFTABLES:
|
||||
log.Info("creating an nftables firewall manager")
|
||||
return nbnftables.Create(iface)
|
||||
return nbnftables.Create(iface, mtu)
|
||||
default:
|
||||
log.Info("no firewall manager found, trying to use userspace packet filtering firewall")
|
||||
return nil, errors.New("no firewall manager found")
|
||||
}
|
||||
}
|
||||
|
||||
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (firewall.Manager, error) {
|
||||
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (firewall.Manager, error) {
|
||||
var errUsp error
|
||||
if fm != nil {
|
||||
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes, flowLogger)
|
||||
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes, flowLogger, mtu)
|
||||
} else {
|
||||
fm, errUsp = uspfilter.Create(iface, disableServerRoutes, flowLogger)
|
||||
fm, errUsp = uspfilter.Create(iface, disableServerRoutes, flowLogger, mtu)
|
||||
}
|
||||
|
||||
if errUsp != nil {
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"slices"
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
"github.com/google/uuid"
|
||||
"github.com/nadoo/ipset"
|
||||
ipset "github.com/lrh3321/ipset-go"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
@@ -40,19 +41,13 @@ type aclManager struct {
|
||||
}
|
||||
|
||||
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*aclManager, error) {
|
||||
m := &aclManager{
|
||||
return &aclManager{
|
||||
iptablesClient: iptablesClient,
|
||||
wgIface: wgIface,
|
||||
entries: make(map[string][][]string),
|
||||
optionalEntries: make(map[string][]entry),
|
||||
ipsetStore: newIpsetStore(),
|
||||
}
|
||||
|
||||
if err := ipset.Init(); err != nil {
|
||||
return nil, fmt.Errorf("init ipset: %w", err)
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *aclManager) init(stateManager *statemanager.Manager) error {
|
||||
@@ -98,8 +93,8 @@ func (m *aclManager) AddPeerFiltering(
|
||||
specs = append(specs, "-j", actionToStr(action))
|
||||
if ipsetName != "" {
|
||||
if ipList, ipsetExists := m.ipsetStore.ipset(ipsetName); ipsetExists {
|
||||
if err := ipset.Add(ipsetName, ip.String()); err != nil {
|
||||
return nil, fmt.Errorf("failed to add IP to ipset: %w", err)
|
||||
if err := m.addToIPSet(ipsetName, ip); err != nil {
|
||||
return nil, fmt.Errorf("add IP to ipset: %w", err)
|
||||
}
|
||||
// if ruleset already exists it means we already have the firewall rule
|
||||
// so we need to update IPs in the ruleset and return new fw.Rule object for ACL manager.
|
||||
@@ -113,14 +108,18 @@ func (m *aclManager) AddPeerFiltering(
|
||||
}}, nil
|
||||
}
|
||||
|
||||
if err := ipset.Flush(ipsetName); err != nil {
|
||||
log.Errorf("flush ipset %s before use it: %s", ipsetName, err)
|
||||
if err := m.flushIPSet(ipsetName); err != nil {
|
||||
if errors.Is(err, ipset.ErrSetNotExist) {
|
||||
log.Debugf("flush ipset %s before use: %v", ipsetName, err)
|
||||
} else {
|
||||
log.Errorf("flush ipset %s before use: %v", ipsetName, err)
|
||||
}
|
||||
}
|
||||
if err := ipset.Create(ipsetName); err != nil {
|
||||
return nil, fmt.Errorf("failed to create ipset: %w", err)
|
||||
if err := m.createIPSet(ipsetName); err != nil {
|
||||
return nil, fmt.Errorf("create ipset: %w", err)
|
||||
}
|
||||
if err := ipset.Add(ipsetName, ip.String()); err != nil {
|
||||
return nil, fmt.Errorf("failed to add IP to ipset: %w", err)
|
||||
if err := m.addToIPSet(ipsetName, ip); err != nil {
|
||||
return nil, fmt.Errorf("add IP to ipset: %w", err)
|
||||
}
|
||||
|
||||
ipList := newIpList(ip.String())
|
||||
@@ -172,11 +171,16 @@ func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
|
||||
return fmt.Errorf("invalid rule type")
|
||||
}
|
||||
|
||||
shouldDestroyIpset := false
|
||||
if ipsetList, ok := m.ipsetStore.ipset(r.ipsetName); ok {
|
||||
// delete IP from ruleset IPs list and ipset
|
||||
if _, ok := ipsetList.ips[r.ip]; ok {
|
||||
if err := ipset.Del(r.ipsetName, r.ip); err != nil {
|
||||
return fmt.Errorf("failed to delete ip from ipset: %w", err)
|
||||
ip := net.ParseIP(r.ip)
|
||||
if ip == nil {
|
||||
return fmt.Errorf("parse IP %s", r.ip)
|
||||
}
|
||||
if err := m.delFromIPSet(r.ipsetName, ip); err != nil {
|
||||
return fmt.Errorf("delete ip from ipset: %w", err)
|
||||
}
|
||||
delete(ipsetList.ips, r.ip)
|
||||
}
|
||||
@@ -190,10 +194,7 @@ func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
|
||||
// we delete last IP from the set, that means we need to delete
|
||||
// set itself and associated firewall rule too
|
||||
m.ipsetStore.deleteIpset(r.ipsetName)
|
||||
|
||||
if err := ipset.Destroy(r.ipsetName); err != nil {
|
||||
log.Errorf("delete empty ipset: %v", err)
|
||||
}
|
||||
shouldDestroyIpset = true
|
||||
}
|
||||
|
||||
if err := m.iptablesClient.Delete(tableName, r.chain, r.specs...); err != nil {
|
||||
@@ -206,6 +207,16 @@ func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
|
||||
}
|
||||
}
|
||||
|
||||
if shouldDestroyIpset {
|
||||
if err := m.destroyIPSet(r.ipsetName); err != nil {
|
||||
if errors.Is(err, ipset.ErrBusy) || errors.Is(err, ipset.ErrSetNotExist) {
|
||||
log.Debugf("destroy empty ipset: %v", err)
|
||||
} else {
|
||||
log.Errorf("destroy empty ipset: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
m.updateState()
|
||||
|
||||
return nil
|
||||
@@ -264,11 +275,19 @@ func (m *aclManager) cleanChains() error {
|
||||
}
|
||||
|
||||
for _, ipsetName := range m.ipsetStore.ipsetNames() {
|
||||
if err := ipset.Flush(ipsetName); err != nil {
|
||||
log.Errorf("flush ipset %q during reset: %v", ipsetName, err)
|
||||
if err := m.flushIPSet(ipsetName); err != nil {
|
||||
if errors.Is(err, ipset.ErrSetNotExist) {
|
||||
log.Debugf("flush ipset %q during reset: %v", ipsetName, err)
|
||||
} else {
|
||||
log.Errorf("flush ipset %q during reset: %v", ipsetName, err)
|
||||
}
|
||||
}
|
||||
if err := ipset.Destroy(ipsetName); err != nil {
|
||||
log.Errorf("delete ipset %q during reset: %v", ipsetName, err)
|
||||
if err := m.destroyIPSet(ipsetName); err != nil {
|
||||
if errors.Is(err, ipset.ErrBusy) || errors.Is(err, ipset.ErrSetNotExist) {
|
||||
log.Debugf("destroy ipset %q during reset: %v", ipsetName, err)
|
||||
} else {
|
||||
log.Errorf("destroy ipset %q during reset: %v", ipsetName, err)
|
||||
}
|
||||
}
|
||||
m.ipsetStore.deleteIpset(ipsetName)
|
||||
}
|
||||
@@ -368,8 +387,8 @@ func (m *aclManager) updateState() {
|
||||
// filterRuleSpecs returns the specs of a filtering rule
|
||||
func filterRuleSpecs(ip net.IP, protocol string, sPort, dPort *firewall.Port, action firewall.Action, ipsetName string) (specs []string) {
|
||||
matchByIP := true
|
||||
// don't use IP matching if IP is ip 0.0.0.0
|
||||
if ip.String() == "0.0.0.0" {
|
||||
// don't use IP matching if IP is 0.0.0.0
|
||||
if ip.IsUnspecified() {
|
||||
matchByIP = false
|
||||
}
|
||||
|
||||
@@ -400,7 +419,6 @@ func transformIPsetName(ipsetName string, sPort, dPort *firewall.Port, action fi
|
||||
return ""
|
||||
}
|
||||
|
||||
// Include action in the ipset name to prevent squashing rules with different actions
|
||||
actionSuffix := ""
|
||||
if action == firewall.ActionDrop {
|
||||
actionSuffix = "-drop"
|
||||
@@ -417,3 +435,61 @@ func transformIPsetName(ipsetName string, sPort, dPort *firewall.Port, action fi
|
||||
return ipsetName + actionSuffix
|
||||
}
|
||||
}
|
||||
|
||||
func (m *aclManager) createIPSet(name string) error {
|
||||
opts := ipset.CreateOptions{
|
||||
Replace: true,
|
||||
}
|
||||
|
||||
if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil {
|
||||
return fmt.Errorf("create ipset %s: %w", name, err)
|
||||
}
|
||||
|
||||
log.Debugf("created ipset %s with type hash:net", name)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *aclManager) addToIPSet(name string, ip net.IP) error {
|
||||
cidr := uint8(32)
|
||||
if ip.To4() == nil {
|
||||
cidr = 128
|
||||
}
|
||||
|
||||
entry := &ipset.Entry{
|
||||
IP: ip,
|
||||
CIDR: cidr,
|
||||
Replace: true,
|
||||
}
|
||||
|
||||
if err := ipset.Add(name, entry); err != nil {
|
||||
return fmt.Errorf("add IP to ipset %s: %w", name, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *aclManager) delFromIPSet(name string, ip net.IP) error {
|
||||
cidr := uint8(32)
|
||||
if ip.To4() == nil {
|
||||
cidr = 128
|
||||
}
|
||||
|
||||
entry := &ipset.Entry{
|
||||
IP: ip,
|
||||
CIDR: cidr,
|
||||
}
|
||||
|
||||
if err := ipset.Del(name, entry); err != nil {
|
||||
return fmt.Errorf("delete IP from ipset %s: %w", name, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *aclManager) flushIPSet(name string) error {
|
||||
return ipset.Flush(name)
|
||||
}
|
||||
|
||||
func (m *aclManager) destroyIPSet(name string) error {
|
||||
return ipset.Destroy(name)
|
||||
}
|
||||
|
||||
@@ -36,7 +36,7 @@ type iFaceMapper interface {
|
||||
}
|
||||
|
||||
// Create iptables firewall manager
|
||||
func Create(wgIface iFaceMapper) (*Manager, error) {
|
||||
func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
|
||||
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init iptables: %w", err)
|
||||
@@ -47,7 +47,7 @@ func Create(wgIface iFaceMapper) (*Manager, error) {
|
||||
ipv4Client: iptablesClient,
|
||||
}
|
||||
|
||||
m.router, err = newRouter(iptablesClient, wgIface)
|
||||
m.router, err = newRouter(iptablesClient, wgIface, mtu)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create router: %w", err)
|
||||
}
|
||||
@@ -66,6 +66,7 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||
NameStr: m.wgIface.Name(),
|
||||
WGAddress: m.wgIface.Address(),
|
||||
UserspaceBind: m.wgIface.IsUserspaceBind(),
|
||||
MTU: m.router.mtu,
|
||||
},
|
||||
}
|
||||
stateManager.RegisterState(state)
|
||||
@@ -260,6 +261,22 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||
return m.router.UpdateSet(set, prefixes)
|
||||
}
|
||||
|
||||
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
|
||||
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
// RemoveInboundDNAT removes an inbound DNAT rule.
|
||||
func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
func getConntrackEstablished() []string {
|
||||
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
@@ -53,7 +54,7 @@ func TestIptablesManager(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// just check on the local interface
|
||||
manager, err := Create(ifaceMock)
|
||||
manager, err := Create(ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, manager.Init(nil))
|
||||
|
||||
@@ -114,7 +115,7 @@ func TestIptablesManagerDenyRules(t *testing.T) {
|
||||
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
require.NoError(t, err)
|
||||
|
||||
manager, err := Create(ifaceMock)
|
||||
manager, err := Create(ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, manager.Init(nil))
|
||||
|
||||
@@ -198,7 +199,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
||||
}
|
||||
|
||||
// just check on the local interface
|
||||
manager, err := Create(mock)
|
||||
manager, err := Create(mock, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, manager.Init(nil))
|
||||
|
||||
@@ -264,7 +265,7 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
||||
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
|
||||
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
|
||||
// just check on the local interface
|
||||
manager, err := Create(mock)
|
||||
manager, err := Create(mock, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, manager.Init(nil))
|
||||
time.Sleep(time.Second)
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/nadoo/ipset"
|
||||
ipset "github.com/lrh3321/ipset-go"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
@@ -30,17 +30,20 @@ const (
|
||||
|
||||
chainPOSTROUTING = "POSTROUTING"
|
||||
chainPREROUTING = "PREROUTING"
|
||||
chainFORWARD = "FORWARD"
|
||||
chainRTNAT = "NETBIRD-RT-NAT"
|
||||
chainRTFWDIN = "NETBIRD-RT-FWD-IN"
|
||||
chainRTFWDOUT = "NETBIRD-RT-FWD-OUT"
|
||||
chainRTPRE = "NETBIRD-RT-PRE"
|
||||
chainRTRDR = "NETBIRD-RT-RDR"
|
||||
chainRTMSSCLAMP = "NETBIRD-RT-MSSCLAMP"
|
||||
routingFinalForwardJump = "ACCEPT"
|
||||
routingFinalNatJump = "MASQUERADE"
|
||||
|
||||
jumpManglePre = "jump-mangle-pre"
|
||||
jumpNatPre = "jump-nat-pre"
|
||||
jumpNatPost = "jump-nat-post"
|
||||
jumpMSSClamp = "jump-mss-clamp"
|
||||
markManglePre = "mark-mangle-pre"
|
||||
markManglePost = "mark-mangle-post"
|
||||
matchSet = "--match-set"
|
||||
@@ -48,6 +51,9 @@ const (
|
||||
dnatSuffix = "_dnat"
|
||||
snatSuffix = "_snat"
|
||||
fwdSuffix = "_fwd"
|
||||
|
||||
// ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation
|
||||
ipTCPHeaderMinSize = 40
|
||||
)
|
||||
|
||||
type ruleInfo struct {
|
||||
@@ -77,16 +83,18 @@ type router struct {
|
||||
ipsetCounter *ipsetCounter
|
||||
wgIface iFaceMapper
|
||||
legacyManagement bool
|
||||
mtu uint16
|
||||
|
||||
stateManager *statemanager.Manager
|
||||
ipFwdState *ipfwdstate.IPForwardingState
|
||||
}
|
||||
|
||||
func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router, error) {
|
||||
func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper, mtu uint16) (*router, error) {
|
||||
r := &router{
|
||||
iptablesClient: iptablesClient,
|
||||
rules: make(map[string][]string),
|
||||
wgIface: wgIface,
|
||||
mtu: mtu,
|
||||
ipFwdState: ipfwdstate.NewIPForwardingState(),
|
||||
}
|
||||
|
||||
@@ -99,10 +107,6 @@ func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router,
|
||||
},
|
||||
)
|
||||
|
||||
if err := ipset.Init(); err != nil {
|
||||
return nil, fmt.Errorf("init ipset: %w", err)
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
@@ -224,12 +228,12 @@ func (r *router) findSets(rule []string) []string {
|
||||
}
|
||||
|
||||
func (r *router) createIpSet(setName string, sources []netip.Prefix) error {
|
||||
if err := ipset.Create(setName, ipset.OptTimeout(0)); err != nil {
|
||||
if err := r.createIPSet(setName); err != nil {
|
||||
return fmt.Errorf("create set %s: %w", setName, err)
|
||||
}
|
||||
|
||||
for _, prefix := range sources {
|
||||
if err := ipset.AddPrefix(setName, prefix); err != nil {
|
||||
if err := r.addPrefixToIPSet(setName, prefix); err != nil {
|
||||
return fmt.Errorf("add element to set %s: %w", setName, err)
|
||||
}
|
||||
}
|
||||
@@ -238,7 +242,7 @@ func (r *router) createIpSet(setName string, sources []netip.Prefix) error {
|
||||
}
|
||||
|
||||
func (r *router) deleteIpSet(setName string) error {
|
||||
if err := ipset.Destroy(setName); err != nil {
|
||||
if err := r.destroyIPSet(setName); err != nil {
|
||||
return fmt.Errorf("destroy set %s: %w", setName, err)
|
||||
}
|
||||
|
||||
@@ -392,6 +396,7 @@ func (r *router) cleanUpDefaultForwardRules() error {
|
||||
{chainRTPRE, tableMangle},
|
||||
{chainRTNAT, tableNat},
|
||||
{chainRTRDR, tableNat},
|
||||
{chainRTMSSCLAMP, tableMangle},
|
||||
} {
|
||||
ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain)
|
||||
if err != nil {
|
||||
@@ -416,6 +421,7 @@ func (r *router) createContainers() error {
|
||||
{chainRTPRE, tableMangle},
|
||||
{chainRTNAT, tableNat},
|
||||
{chainRTRDR, tableNat},
|
||||
{chainRTMSSCLAMP, tableMangle},
|
||||
} {
|
||||
if err := r.iptablesClient.NewChain(chainInfo.table, chainInfo.chain); err != nil {
|
||||
return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
|
||||
@@ -438,6 +444,10 @@ func (r *router) createContainers() error {
|
||||
return fmt.Errorf("add jump rules: %w", err)
|
||||
}
|
||||
|
||||
if err := r.addMSSClampingRules(); err != nil {
|
||||
log.Errorf("failed to add MSS clamping rules: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -518,6 +528,35 @@ func (r *router) addPostroutingRules() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic.
|
||||
// TODO: Add IPv6 support
|
||||
func (r *router) addMSSClampingRules() error {
|
||||
mss := r.mtu - ipTCPHeaderMinSize
|
||||
|
||||
// Add jump rule from FORWARD chain in mangle table to our custom chain
|
||||
jumpRule := []string{
|
||||
"-j", chainRTMSSCLAMP,
|
||||
}
|
||||
if err := r.iptablesClient.Insert(tableMangle, chainFORWARD, 1, jumpRule...); err != nil {
|
||||
return fmt.Errorf("add jump to MSS clamp chain: %w", err)
|
||||
}
|
||||
r.rules[jumpMSSClamp] = jumpRule
|
||||
|
||||
ruleOut := []string{
|
||||
"-o", r.wgIface.Name(),
|
||||
"-p", "tcp",
|
||||
"--tcp-flags", "SYN,RST", "SYN",
|
||||
"-j", "TCPMSS",
|
||||
"--set-mss", fmt.Sprintf("%d", mss),
|
||||
}
|
||||
if err := r.iptablesClient.Append(tableMangle, chainRTMSSCLAMP, ruleOut...); err != nil {
|
||||
return fmt.Errorf("add outbound MSS clamp rule: %w", err)
|
||||
}
|
||||
r.rules["mss-clamp-out"] = ruleOut
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) insertEstablishedRule(chain string) error {
|
||||
establishedRule := getConntrackEstablished()
|
||||
|
||||
@@ -558,7 +597,7 @@ func (r *router) addJumpRules() error {
|
||||
}
|
||||
|
||||
func (r *router) cleanJumpRules() error {
|
||||
for _, ruleKey := range []string{jumpNatPost, jumpManglePre, jumpNatPre} {
|
||||
for _, ruleKey := range []string{jumpNatPost, jumpManglePre, jumpNatPre, jumpMSSClamp} {
|
||||
if rule, exists := r.rules[ruleKey]; exists {
|
||||
var table, chain string
|
||||
switch ruleKey {
|
||||
@@ -571,6 +610,9 @@ func (r *router) cleanJumpRules() error {
|
||||
case jumpNatPre:
|
||||
table = tableNat
|
||||
chain = chainPREROUTING
|
||||
case jumpMSSClamp:
|
||||
table = tableMangle
|
||||
chain = chainFORWARD
|
||||
default:
|
||||
return fmt.Errorf("unknown jump rule: %s", ruleKey)
|
||||
}
|
||||
@@ -869,8 +911,8 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||
log.Tracef("skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix)
|
||||
continue
|
||||
}
|
||||
if err := ipset.AddPrefix(set.HashedName(), prefix); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("increment ipset counter: %w", err))
|
||||
if err := r.addPrefixToIPSet(set.HashedName(), prefix); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("add prefix to ipset: %w", err))
|
||||
}
|
||||
}
|
||||
if merr == nil {
|
||||
@@ -880,6 +922,54 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
|
||||
func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
||||
|
||||
if _, exists := r.rules[ruleID]; exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
dnatRule := []string{
|
||||
"-i", r.wgIface.Name(),
|
||||
"-p", strings.ToLower(string(protocol)),
|
||||
"--dport", strconv.Itoa(int(sourcePort)),
|
||||
"-d", localAddr.String(),
|
||||
"-m", "addrtype", "--dst-type", "LOCAL",
|
||||
"-j", "DNAT",
|
||||
"--to-destination", ":" + strconv.Itoa(int(targetPort)),
|
||||
}
|
||||
|
||||
ruleInfo := ruleInfo{
|
||||
table: tableNat,
|
||||
chain: chainRTRDR,
|
||||
rule: dnatRule,
|
||||
}
|
||||
|
||||
if err := r.iptablesClient.Append(ruleInfo.table, ruleInfo.chain, ruleInfo.rule...); err != nil {
|
||||
return fmt.Errorf("add inbound DNAT rule: %w", err)
|
||||
}
|
||||
r.rules[ruleID] = ruleInfo.rule
|
||||
|
||||
r.updateState()
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveInboundDNAT removes an inbound DNAT rule.
|
||||
func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
||||
|
||||
if dnatRule, exists := r.rules[ruleID]; exists {
|
||||
if err := r.iptablesClient.Delete(tableNat, chainRTRDR, dnatRule...); err != nil {
|
||||
return fmt.Errorf("delete inbound DNAT rule: %w", err)
|
||||
}
|
||||
delete(r.rules, ruleID)
|
||||
}
|
||||
|
||||
r.updateState()
|
||||
return nil
|
||||
}
|
||||
|
||||
func applyPort(flag string, port *firewall.Port) []string {
|
||||
if port == nil {
|
||||
return nil
|
||||
@@ -899,3 +989,37 @@ func applyPort(flag string, port *firewall.Port) []string {
|
||||
|
||||
return []string{flag, strconv.Itoa(int(port.Values[0]))}
|
||||
}
|
||||
|
||||
func (r *router) createIPSet(name string) error {
|
||||
opts := ipset.CreateOptions{
|
||||
Replace: true,
|
||||
}
|
||||
|
||||
if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil {
|
||||
return fmt.Errorf("create ipset %s: %w", name, err)
|
||||
}
|
||||
|
||||
log.Debugf("created ipset %s with type hash:net", name)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) addPrefixToIPSet(name string, prefix netip.Prefix) error {
|
||||
addr := prefix.Addr()
|
||||
ip := addr.AsSlice()
|
||||
|
||||
entry := &ipset.Entry{
|
||||
IP: ip,
|
||||
CIDR: uint8(prefix.Bits()),
|
||||
Replace: true,
|
||||
}
|
||||
|
||||
if err := ipset.Add(name, entry); err != nil {
|
||||
return fmt.Errorf("add prefix to ipset %s: %w", name, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) destroyIPSet(name string) error {
|
||||
return ipset.Destroy(name)
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/test"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
@@ -30,7 +31,7 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
||||
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
require.NoError(t, err, "failed to init iptables client")
|
||||
|
||||
manager, err := newRouter(iptablesClient, ifaceMock)
|
||||
manager, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err, "should return a valid iptables manager")
|
||||
require.NoError(t, manager.init(nil))
|
||||
|
||||
@@ -38,7 +39,6 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
||||
assert.NoError(t, manager.Reset(), "shouldn't return error")
|
||||
}()
|
||||
|
||||
// Now 5 rules:
|
||||
// 1. established rule forward in
|
||||
// 2. estbalished rule forward out
|
||||
// 3. jump rule to POST nat chain
|
||||
@@ -48,7 +48,9 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
||||
// 7. static return masquerade rule
|
||||
// 8. mangle prerouting mark rule
|
||||
// 9. mangle postrouting mark rule
|
||||
require.Len(t, manager.rules, 9, "should have created rules map")
|
||||
// 10. jump rule to MSS clamping chain
|
||||
// 11. MSS clamping rule for outbound traffic
|
||||
require.Len(t, manager.rules, 11, "should have created rules map")
|
||||
|
||||
exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, "-j", chainRTNAT)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING)
|
||||
@@ -82,7 +84,7 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
|
||||
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
require.NoError(t, err, "failed to init iptables client")
|
||||
|
||||
manager, err := newRouter(iptablesClient, ifaceMock)
|
||||
manager, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
require.NoError(t, manager.init(nil))
|
||||
|
||||
@@ -155,7 +157,7 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
|
||||
t.Run(testCase.Name, func(t *testing.T) {
|
||||
iptablesClient, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
|
||||
manager, err := newRouter(iptablesClient, ifaceMock)
|
||||
manager, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
require.NoError(t, manager.init(nil))
|
||||
defer func() {
|
||||
@@ -217,7 +219,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
||||
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
require.NoError(t, err, "Failed to create iptables client")
|
||||
|
||||
r, err := newRouter(iptablesClient, ifaceMock)
|
||||
r, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err, "Failed to create router manager")
|
||||
require.NoError(t, r.init(nil))
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
@@ -11,6 +12,7 @@ type InterfaceState struct {
|
||||
NameStr string `json:"name"`
|
||||
WGAddress wgaddr.Address `json:"wg_address"`
|
||||
UserspaceBind bool `json:"userspace_bind"`
|
||||
MTU uint16 `json:"mtu"`
|
||||
}
|
||||
|
||||
func (i *InterfaceState) Name() string {
|
||||
@@ -42,7 +44,11 @@ func (s *ShutdownState) Name() string {
|
||||
}
|
||||
|
||||
func (s *ShutdownState) Cleanup() error {
|
||||
ipt, err := Create(s.InterfaceState)
|
||||
mtu := s.InterfaceState.MTU
|
||||
if mtu == 0 {
|
||||
mtu = iface.DefaultMTU
|
||||
}
|
||||
ipt, err := Create(s.InterfaceState, mtu)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create iptables manager: %w", err)
|
||||
}
|
||||
|
||||
@@ -100,6 +100,9 @@ type Manager interface {
|
||||
//
|
||||
// If comment argument is empty firewall manager should set
|
||||
// rule ID as comment for the rule
|
||||
//
|
||||
// Note: Callers should call Flush() after adding rules to ensure
|
||||
// they are applied to the kernel and rule handles are refreshed.
|
||||
AddPeerFiltering(
|
||||
id []byte,
|
||||
ip net.IP,
|
||||
@@ -151,14 +154,20 @@ type Manager interface {
|
||||
|
||||
DisableRouting() error
|
||||
|
||||
// AddDNATRule adds a DNAT rule
|
||||
// AddDNATRule adds outbound DNAT rule for forwarding external traffic to the NetBird network.
|
||||
AddDNATRule(ForwardRule) (Rule, error)
|
||||
|
||||
// DeleteDNATRule deletes a DNAT rule
|
||||
// DeleteDNATRule deletes the outbound DNAT rule.
|
||||
DeleteDNATRule(Rule) error
|
||||
|
||||
// UpdateSet updates the set with the given prefixes
|
||||
UpdateSet(hash Set, prefixes []netip.Prefix) error
|
||||
|
||||
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services
|
||||
AddInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
|
||||
|
||||
// RemoveInboundDNAT removes inbound DNAT rule
|
||||
RemoveInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
|
||||
}
|
||||
|
||||
func GenKey(format string, pair RouterPair) string {
|
||||
|
||||
@@ -29,8 +29,6 @@ const (
|
||||
chainNameForwardFilter = "netbird-acl-forward-filter"
|
||||
chainNameManglePrerouting = "netbird-mangle-prerouting"
|
||||
chainNameManglePostrouting = "netbird-mangle-postrouting"
|
||||
|
||||
allowNetbirdInputRuleID = "allow Netbird incoming traffic"
|
||||
)
|
||||
|
||||
const flushError = "flush: %w"
|
||||
@@ -195,25 +193,6 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
|
||||
// createDefaultAllowRules creates default allow rules for the input and output chains
|
||||
func (m *AclManager) createDefaultAllowRules() error {
|
||||
expIn := []expr.Any{
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: 12,
|
||||
Len: 4,
|
||||
},
|
||||
// mask
|
||||
&expr.Bitwise{
|
||||
SourceRegister: 1,
|
||||
DestRegister: 1,
|
||||
Len: 4,
|
||||
Mask: []byte{0, 0, 0, 0},
|
||||
Xor: []byte{0, 0, 0, 0},
|
||||
},
|
||||
// net address
|
||||
&expr.Cmp{
|
||||
Register: 1,
|
||||
Data: []byte{0, 0, 0, 0},
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
@@ -258,7 +237,7 @@ func (m *AclManager) addIOFiltering(
|
||||
action firewall.Action,
|
||||
ipset *nftables.Set,
|
||||
) (*Rule, error) {
|
||||
ruleId := generatePeerRuleId(ip, sPort, dPort, action, ipset)
|
||||
ruleId := generatePeerRuleId(ip, proto, sPort, dPort, action, ipset)
|
||||
if r, ok := m.rules[ruleId]; ok {
|
||||
return &Rule{
|
||||
nftRule: r.nftRule,
|
||||
@@ -357,11 +336,12 @@ func (m *AclManager) addIOFiltering(
|
||||
}
|
||||
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
return nil, fmt.Errorf(flushError, err)
|
||||
return nil, fmt.Errorf("flush input rule %s: %v", ruleId, err)
|
||||
}
|
||||
|
||||
ruleStruct := &Rule{
|
||||
nftRule: nftRule,
|
||||
nftRule: nftRule,
|
||||
// best effort mangle rule
|
||||
mangleRule: m.createPreroutingRule(expressions, userData),
|
||||
nftSet: ipset,
|
||||
ruleID: ruleId,
|
||||
@@ -420,12 +400,19 @@ func (m *AclManager) createPreroutingRule(expressions []expr.Any, userData []byt
|
||||
},
|
||||
)
|
||||
|
||||
return m.rConn.AddRule(&nftables.Rule{
|
||||
nfRule := m.rConn.AddRule(&nftables.Rule{
|
||||
Table: m.workTable,
|
||||
Chain: m.chainPrerouting,
|
||||
Exprs: preroutingExprs,
|
||||
UserData: userData,
|
||||
})
|
||||
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
log.Errorf("failed to flush mangle rule %s: %v", string(userData), err)
|
||||
return nil
|
||||
}
|
||||
|
||||
return nfRule
|
||||
}
|
||||
|
||||
func (m *AclManager) createDefaultChains() (err error) {
|
||||
@@ -697,8 +684,8 @@ func (m *AclManager) refreshRuleHandles(chain *nftables.Chain, mangle bool) erro
|
||||
return nil
|
||||
}
|
||||
|
||||
func generatePeerRuleId(ip net.IP, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action, ipset *nftables.Set) string {
|
||||
rulesetID := ":"
|
||||
func generatePeerRuleId(ip net.IP, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action, ipset *nftables.Set) string {
|
||||
rulesetID := ":" + string(proto) + ":"
|
||||
if sPort != nil {
|
||||
rulesetID += sPort.String()
|
||||
}
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
"github.com/google/nftables"
|
||||
@@ -19,13 +19,22 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
// tableNameNetbird is the name of the table that is used for filtering by the Netbird client
|
||||
// tableNameNetbird is the default name of the table that is used for filtering by the Netbird client
|
||||
tableNameNetbird = "netbird"
|
||||
// envTableName is the environment variable to override the table name
|
||||
envTableName = "NB_NFTABLES_TABLE"
|
||||
|
||||
tableNameFilter = "filter"
|
||||
chainNameInput = "INPUT"
|
||||
)
|
||||
|
||||
func getTableName() string {
|
||||
if name := os.Getenv(envTableName); name != "" {
|
||||
return name
|
||||
}
|
||||
return tableNameNetbird
|
||||
}
|
||||
|
||||
// iFaceMapper defines subset methods of interface required for manager
|
||||
type iFaceMapper interface {
|
||||
Name() string
|
||||
@@ -44,16 +53,16 @@ type Manager struct {
|
||||
}
|
||||
|
||||
// Create nftables firewall manager
|
||||
func Create(wgIface iFaceMapper) (*Manager, error) {
|
||||
func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
|
||||
m := &Manager{
|
||||
rConn: &nftables.Conn{},
|
||||
wgIface: wgIface,
|
||||
}
|
||||
|
||||
workTable := &nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4}
|
||||
workTable := &nftables.Table{Name: getTableName(), Family: nftables.TableFamilyIPv4}
|
||||
|
||||
var err error
|
||||
m.router, err = newRouter(workTable, wgIface)
|
||||
m.router, err = newRouter(workTable, wgIface, mtu)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create router: %w", err)
|
||||
}
|
||||
@@ -93,6 +102,7 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||
NameStr: m.wgIface.Name(),
|
||||
WGAddress: m.wgIface.Address(),
|
||||
UserspaceBind: m.wgIface.IsUserspaceBind(),
|
||||
MTU: m.router.mtu,
|
||||
},
|
||||
}); err != nil {
|
||||
log.Errorf("failed to update state: %v", err)
|
||||
@@ -197,44 +207,11 @@ func (m *Manager) AllowNetbird() error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
err := m.aclManager.createDefaultAllowRules()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create default allow rules: %v", err)
|
||||
if err := m.aclManager.createDefaultAllowRules(); err != nil {
|
||||
return fmt.Errorf("create default allow rules: %w", err)
|
||||
}
|
||||
|
||||
chains, err := m.rConn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
return fmt.Errorf("list of chains: %w", err)
|
||||
}
|
||||
|
||||
var chain *nftables.Chain
|
||||
for _, c := range chains {
|
||||
if c.Table.Name == tableNameFilter && c.Name == chainNameInput {
|
||||
chain = c
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if chain == nil {
|
||||
log.Debugf("chain INPUT not found. Skipping add allow netbird rule")
|
||||
return nil
|
||||
}
|
||||
|
||||
rules, err := m.rConn.GetRules(chain.Table, chain)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get rules for the INPUT chain: %v", err)
|
||||
}
|
||||
|
||||
if rule := m.detectAllowNetbirdRule(rules); rule != nil {
|
||||
log.Debugf("allow netbird rule already exists: %v", rule)
|
||||
return nil
|
||||
}
|
||||
|
||||
m.applyAllowNetbirdRules(chain)
|
||||
|
||||
err = m.rConn.Flush()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to flush allow input netbird rules: %v", err)
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
return fmt.Errorf("flush allow input netbird rules: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -250,10 +227,6 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if err := m.resetNetbirdInputRules(); err != nil {
|
||||
return fmt.Errorf("reset netbird input rules: %v", err)
|
||||
}
|
||||
|
||||
if err := m.router.Reset(); err != nil {
|
||||
return fmt.Errorf("reset router: %v", err)
|
||||
}
|
||||
@@ -273,49 +246,15 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) resetNetbirdInputRules() error {
|
||||
chains, err := m.rConn.ListChains()
|
||||
if err != nil {
|
||||
return fmt.Errorf("list chains: %w", err)
|
||||
}
|
||||
|
||||
m.deleteNetbirdInputRules(chains)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) deleteNetbirdInputRules(chains []*nftables.Chain) {
|
||||
for _, c := range chains {
|
||||
if c.Table.Name == tableNameFilter && c.Name == chainNameInput {
|
||||
rules, err := m.rConn.GetRules(c.Table, c)
|
||||
if err != nil {
|
||||
log.Errorf("get rules for chain %q: %v", c.Name, err)
|
||||
continue
|
||||
}
|
||||
|
||||
m.deleteMatchingRules(rules)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) deleteMatchingRules(rules []*nftables.Rule) {
|
||||
for _, r := range rules {
|
||||
if bytes.Equal(r.UserData, []byte(allowNetbirdInputRuleID)) {
|
||||
if err := m.rConn.DelRule(r); err != nil {
|
||||
log.Errorf("delete rule: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) cleanupNetbirdTables() error {
|
||||
tables, err := m.rConn.ListTables()
|
||||
if err != nil {
|
||||
return fmt.Errorf("list tables: %w", err)
|
||||
}
|
||||
|
||||
tableName := getTableName()
|
||||
for _, t := range tables {
|
||||
if t.Name == tableNameNetbird {
|
||||
if t.Name == tableName {
|
||||
m.rConn.DelTable(t)
|
||||
}
|
||||
}
|
||||
@@ -376,61 +315,40 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||
return m.router.UpdateSet(set, prefixes)
|
||||
}
|
||||
|
||||
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
|
||||
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
// RemoveInboundDNAT removes an inbound DNAT rule.
|
||||
func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
func (m *Manager) createWorkTable() (*nftables.Table, error) {
|
||||
tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list of tables: %w", err)
|
||||
}
|
||||
|
||||
tableName := getTableName()
|
||||
for _, t := range tables {
|
||||
if t.Name == tableNameNetbird {
|
||||
if t.Name == tableName {
|
||||
m.rConn.DelTable(t)
|
||||
}
|
||||
}
|
||||
|
||||
table := m.rConn.AddTable(&nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4})
|
||||
table := m.rConn.AddTable(&nftables.Table{Name: getTableName(), Family: nftables.TableFamilyIPv4})
|
||||
err = m.rConn.Flush()
|
||||
return table, err
|
||||
}
|
||||
|
||||
func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) {
|
||||
rule := &nftables.Rule{
|
||||
Table: chain.Table,
|
||||
Chain: chain,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(m.wgIface.Name()),
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
},
|
||||
UserData: []byte(allowNetbirdInputRuleID),
|
||||
}
|
||||
_ = m.rConn.InsertRule(rule)
|
||||
}
|
||||
|
||||
func (m *Manager) detectAllowNetbirdRule(existedRules []*nftables.Rule) *nftables.Rule {
|
||||
ifName := ifname(m.wgIface.Name())
|
||||
for _, rule := range existedRules {
|
||||
if rule.Table.Name == tableNameFilter && rule.Chain.Name == chainNameInput {
|
||||
if len(rule.Exprs) < 4 {
|
||||
if e, ok := rule.Exprs[0].(*expr.Meta); !ok || e.Key != expr.MetaKeyIIFNAME {
|
||||
continue
|
||||
}
|
||||
if e, ok := rule.Exprs[1].(*expr.Cmp); !ok || e.Op != expr.CmpOpEq || !bytes.Equal(e.Data, ifName) {
|
||||
continue
|
||||
}
|
||||
return rule
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func insertReturnTrafficRule(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain) {
|
||||
rule := &nftables.Rule{
|
||||
Table: table,
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
@@ -56,7 +57,7 @@ func (i *iFaceMock) IsUserspaceBind() bool { return false }
|
||||
func TestNftablesManager(t *testing.T) {
|
||||
|
||||
// just check on the local interface
|
||||
manager, err := Create(ifaceMock)
|
||||
manager, err := Create(ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, manager.Init(nil))
|
||||
time.Sleep(time.Second * 3)
|
||||
@@ -168,7 +169,7 @@ func TestNftablesManager(t *testing.T) {
|
||||
func TestNftablesManagerRuleOrder(t *testing.T) {
|
||||
// This test verifies rule insertion order in nftables peer ACLs
|
||||
// We add accept rule first, then deny rule to test ordering behavior
|
||||
manager, err := Create(ifaceMock)
|
||||
manager, err := Create(ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, manager.Init(nil))
|
||||
|
||||
@@ -261,7 +262,7 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
||||
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
|
||||
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
|
||||
// just check on the local interface
|
||||
manager, err := Create(mock)
|
||||
manager, err := Create(mock, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, manager.Init(nil))
|
||||
time.Sleep(time.Second * 3)
|
||||
@@ -345,7 +346,7 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
|
||||
stdout, stderr := runIptablesSave(t)
|
||||
verifyIptablesOutput(t, stdout, stderr)
|
||||
|
||||
manager, err := Create(ifaceMock)
|
||||
manager, err := Create(ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err, "failed to create manager")
|
||||
require.NoError(t, manager.Init(nil))
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/google/nftables/xt"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
@@ -32,12 +33,17 @@ const (
|
||||
chainNameRoutingNat = "netbird-rt-postrouting"
|
||||
chainNameRoutingRdr = "netbird-rt-redirect"
|
||||
chainNameForward = "FORWARD"
|
||||
chainNameMangleForward = "netbird-mangle-forward"
|
||||
|
||||
userDataAcceptForwardRuleIif = "frwacceptiif"
|
||||
userDataAcceptForwardRuleOif = "frwacceptoif"
|
||||
userDataAcceptInputRule = "inputaccept"
|
||||
|
||||
dnatSuffix = "_dnat"
|
||||
snatSuffix = "_snat"
|
||||
|
||||
// ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation
|
||||
ipTCPHeaderMinSize = 40
|
||||
)
|
||||
|
||||
const refreshRulesMapError = "refresh rules map: %w"
|
||||
@@ -63,9 +69,10 @@ type router struct {
|
||||
wgIface iFaceMapper
|
||||
ipFwdState *ipfwdstate.IPForwardingState
|
||||
legacyManagement bool
|
||||
mtu uint16
|
||||
}
|
||||
|
||||
func newRouter(workTable *nftables.Table, wgIface iFaceMapper) (*router, error) {
|
||||
func newRouter(workTable *nftables.Table, wgIface iFaceMapper, mtu uint16) (*router, error) {
|
||||
r := &router{
|
||||
conn: &nftables.Conn{},
|
||||
workTable: workTable,
|
||||
@@ -73,6 +80,7 @@ func newRouter(workTable *nftables.Table, wgIface iFaceMapper) (*router, error)
|
||||
rules: make(map[string]*nftables.Rule),
|
||||
wgIface: wgIface,
|
||||
ipFwdState: ipfwdstate.NewIPForwardingState(),
|
||||
mtu: mtu,
|
||||
}
|
||||
|
||||
r.ipsetCounter = refcounter.New(
|
||||
@@ -96,8 +104,8 @@ func newRouter(workTable *nftables.Table, wgIface iFaceMapper) (*router, error)
|
||||
func (r *router) init(workTable *nftables.Table) error {
|
||||
r.workTable = workTable
|
||||
|
||||
if err := r.removeAcceptForwardRules(); err != nil {
|
||||
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
|
||||
if err := r.removeAcceptFilterRules(); err != nil {
|
||||
log.Errorf("failed to clean up rules from filter table: %s", err)
|
||||
}
|
||||
|
||||
if err := r.createContainers(); err != nil {
|
||||
@@ -111,15 +119,15 @@ func (r *router) init(workTable *nftables.Table) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reset cleans existing nftables default forward rules from the system
|
||||
// Reset cleans existing nftables filter table rules from the system
|
||||
func (r *router) Reset() error {
|
||||
// clear without deleting the ipsets, the nf table will be deleted by the caller
|
||||
r.ipsetCounter.Clear()
|
||||
|
||||
var merr *multierror.Error
|
||||
|
||||
if err := r.removeAcceptForwardRules(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove accept forward rules: %w", err))
|
||||
if err := r.removeAcceptFilterRules(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove accept filter rules: %w", err))
|
||||
}
|
||||
|
||||
if err := r.removeNatPreroutingRules(); err != nil {
|
||||
@@ -220,11 +228,23 @@ func (r *router) createContainers() error {
|
||||
Type: nftables.ChainTypeFilter,
|
||||
})
|
||||
|
||||
r.chains[chainNameMangleForward] = r.conn.AddChain(&nftables.Chain{
|
||||
Name: chainNameMangleForward,
|
||||
Table: r.workTable,
|
||||
Hooknum: nftables.ChainHookForward,
|
||||
Priority: nftables.ChainPriorityMangle,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
})
|
||||
|
||||
// Add the single NAT rule that matches on mark
|
||||
if err := r.addPostroutingRules(); err != nil {
|
||||
return fmt.Errorf("add single nat rule: %v", err)
|
||||
}
|
||||
|
||||
if err := r.addMSSClampingRules(); err != nil {
|
||||
log.Errorf("failed to add MSS clamping rules: %s", err)
|
||||
}
|
||||
|
||||
if err := r.acceptForwardRules(); err != nil {
|
||||
log.Errorf("failed to add accept rules for the forward chain: %s", err)
|
||||
}
|
||||
@@ -745,6 +765,83 @@ func (r *router) addPostroutingRules() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic.
|
||||
// TODO: Add IPv6 support
|
||||
func (r *router) addMSSClampingRules() error {
|
||||
mss := r.mtu - ipTCPHeaderMinSize
|
||||
|
||||
exprsOut := []expr.Any{
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyOIFNAME,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(r.wgIface.Name()),
|
||||
},
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyL4PROTO,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: []byte{unix.IPPROTO_TCP},
|
||||
},
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseTransportHeader,
|
||||
Offset: 13,
|
||||
Len: 1,
|
||||
},
|
||||
&expr.Bitwise{
|
||||
DestRegister: 1,
|
||||
SourceRegister: 1,
|
||||
Len: 1,
|
||||
Mask: []byte{0x02},
|
||||
Xor: []byte{0x00},
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpNeq,
|
||||
Register: 1,
|
||||
Data: []byte{0x00},
|
||||
},
|
||||
&expr.Counter{},
|
||||
&expr.Exthdr{
|
||||
DestRegister: 1,
|
||||
Type: 2,
|
||||
Offset: 2,
|
||||
Len: 2,
|
||||
Op: expr.ExthdrOpTcpopt,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpGt,
|
||||
Register: 1,
|
||||
Data: binaryutil.BigEndian.PutUint16(uint16(mss)),
|
||||
},
|
||||
&expr.Immediate{
|
||||
Register: 1,
|
||||
Data: binaryutil.BigEndian.PutUint16(uint16(mss)),
|
||||
},
|
||||
&expr.Exthdr{
|
||||
SourceRegister: 1,
|
||||
Type: 2,
|
||||
Offset: 2,
|
||||
Len: 2,
|
||||
Op: expr.ExthdrOpTcpopt,
|
||||
},
|
||||
}
|
||||
|
||||
r.conn.AddRule(&nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameMangleForward],
|
||||
Exprs: exprsOut,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls
|
||||
func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
|
||||
sourceExp, err := r.applyNetwork(pair.Source, nil, true)
|
||||
@@ -840,6 +937,7 @@ func (r *router) RemoveAllLegacyRouteRules() error {
|
||||
// that our traffic is not dropped by existing rules there.
|
||||
// The existing FORWARD rules/policies decide outbound traffic towards our interface.
|
||||
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
|
||||
// This method also adds INPUT chain rules to allow traffic to the local interface.
|
||||
func (r *router) acceptForwardRules() error {
|
||||
if r.filterTable == nil {
|
||||
log.Debugf("table 'filter' not found for forward rules, skipping accept rules")
|
||||
@@ -849,7 +947,7 @@ func (r *router) acceptForwardRules() error {
|
||||
fw := "iptables"
|
||||
|
||||
defer func() {
|
||||
log.Debugf("Used %s to add accept forward rules", fw)
|
||||
log.Debugf("Used %s to add accept forward and input rules", fw)
|
||||
}()
|
||||
|
||||
// Try iptables first and fallback to nftables if iptables is not available
|
||||
@@ -859,22 +957,30 @@ func (r *router) acceptForwardRules() error {
|
||||
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
|
||||
|
||||
fw = "nftables"
|
||||
return r.acceptForwardRulesNftables()
|
||||
return r.acceptFilterRulesNftables()
|
||||
}
|
||||
|
||||
return r.acceptForwardRulesIptables(ipt)
|
||||
return r.acceptFilterRulesIptables(ipt)
|
||||
}
|
||||
|
||||
func (r *router) acceptForwardRulesIptables(ipt *iptables.IPTables) error {
|
||||
func (r *router) acceptFilterRulesIptables(ipt *iptables.IPTables) error {
|
||||
var merr *multierror.Error
|
||||
|
||||
for _, rule := range r.getAcceptForwardRules() {
|
||||
if err := ipt.Insert("filter", chainNameForward, 1, rule...); err != nil {
|
||||
merr = multierror.Append(err, fmt.Errorf("add iptables rule: %v", err))
|
||||
merr = multierror.Append(err, fmt.Errorf("add iptables forward rule: %v", err))
|
||||
} else {
|
||||
log.Debugf("added iptables rule: %v", rule)
|
||||
log.Debugf("added iptables forward rule: %v", rule)
|
||||
}
|
||||
}
|
||||
|
||||
inputRule := r.getAcceptInputRule()
|
||||
if err := ipt.Insert("filter", chainNameInput, 1, inputRule...); err != nil {
|
||||
merr = multierror.Append(err, fmt.Errorf("add iptables input rule: %v", err))
|
||||
} else {
|
||||
log.Debugf("added iptables input rule: %v", inputRule)
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
@@ -886,10 +992,13 @@ func (r *router) getAcceptForwardRules() [][]string {
|
||||
}
|
||||
}
|
||||
|
||||
func (r *router) acceptForwardRulesNftables() error {
|
||||
func (r *router) getAcceptInputRule() []string {
|
||||
return []string{"-i", r.wgIface.Name(), "-j", "ACCEPT"}
|
||||
}
|
||||
|
||||
func (r *router) acceptFilterRulesNftables() error {
|
||||
intf := ifname(r.wgIface.Name())
|
||||
|
||||
// Rule for incoming interface (iif) with counter
|
||||
iifRule := &nftables.Rule{
|
||||
Table: r.filterTable,
|
||||
Chain: &nftables.Chain{
|
||||
@@ -922,11 +1031,10 @@ func (r *router) acceptForwardRulesNftables() error {
|
||||
},
|
||||
}
|
||||
|
||||
// Rule for outgoing interface (oif) with counter
|
||||
oifRule := &nftables.Rule{
|
||||
Table: r.filterTable,
|
||||
Chain: &nftables.Chain{
|
||||
Name: "FORWARD",
|
||||
Name: chainNameForward,
|
||||
Table: r.filterTable,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Hooknum: nftables.ChainHookForward,
|
||||
@@ -935,35 +1043,60 @@ func (r *router) acceptForwardRulesNftables() error {
|
||||
Exprs: append(oifExprs, getEstablishedExprs(2)...),
|
||||
UserData: []byte(userDataAcceptForwardRuleOif),
|
||||
}
|
||||
|
||||
r.conn.InsertRule(oifRule)
|
||||
|
||||
inputRule := &nftables.Rule{
|
||||
Table: r.filterTable,
|
||||
Chain: &nftables.Chain{
|
||||
Name: chainNameInput,
|
||||
Table: r.filterTable,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Hooknum: nftables.ChainHookInput,
|
||||
Priority: nftables.ChainPriorityFilter,
|
||||
},
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: intf,
|
||||
},
|
||||
&expr.Counter{},
|
||||
&expr.Verdict{Kind: expr.VerdictAccept},
|
||||
},
|
||||
UserData: []byte(userDataAcceptInputRule),
|
||||
}
|
||||
r.conn.InsertRule(inputRule)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) removeAcceptForwardRules() error {
|
||||
func (r *router) removeAcceptFilterRules() error {
|
||||
if r.filterTable == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Try iptables first and fallback to nftables if iptables is not available
|
||||
ipt, err := iptables.New()
|
||||
if err != nil {
|
||||
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
|
||||
return r.removeAcceptForwardRulesNftables()
|
||||
return r.removeAcceptFilterRulesNftables()
|
||||
}
|
||||
|
||||
return r.removeAcceptForwardRulesIptables(ipt)
|
||||
return r.removeAcceptFilterRulesIptables(ipt)
|
||||
}
|
||||
|
||||
func (r *router) removeAcceptForwardRulesNftables() error {
|
||||
func (r *router) removeAcceptFilterRulesNftables() error {
|
||||
chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
return fmt.Errorf("list chains: %v", err)
|
||||
}
|
||||
|
||||
for _, chain := range chains {
|
||||
if chain.Table.Name != r.filterTable.Name || chain.Name != chainNameForward {
|
||||
if chain.Table.Name != r.filterTable.Name {
|
||||
continue
|
||||
}
|
||||
|
||||
if chain.Name != chainNameForward && chain.Name != chainNameInput {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -974,7 +1107,8 @@ func (r *router) removeAcceptForwardRulesNftables() error {
|
||||
|
||||
for _, rule := range rules {
|
||||
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) ||
|
||||
bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) {
|
||||
bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) ||
|
||||
bytes.Equal(rule.UserData, []byte(userDataAcceptInputRule)) {
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
return fmt.Errorf("delete rule: %v", err)
|
||||
}
|
||||
@@ -989,14 +1123,20 @@ func (r *router) removeAcceptForwardRulesNftables() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) removeAcceptForwardRulesIptables(ipt *iptables.IPTables) error {
|
||||
func (r *router) removeAcceptFilterRulesIptables(ipt *iptables.IPTables) error {
|
||||
var merr *multierror.Error
|
||||
|
||||
for _, rule := range r.getAcceptForwardRules() {
|
||||
if err := ipt.DeleteIfExists("filter", chainNameForward, rule...); err != nil {
|
||||
merr = multierror.Append(err, fmt.Errorf("remove iptables rule: %v", err))
|
||||
merr = multierror.Append(err, fmt.Errorf("remove iptables forward rule: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
inputRule := r.getAcceptInputRule()
|
||||
if err := ipt.DeleteIfExists("filter", chainNameInput, inputRule...); err != nil {
|
||||
merr = multierror.Append(err, fmt.Errorf("remove iptables input rule: %v", err))
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
@@ -1350,6 +1490,103 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
|
||||
func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
||||
|
||||
if _, exists := r.rules[ruleID]; exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
protoNum, err := protoToInt(protocol)
|
||||
if err != nil {
|
||||
return fmt.Errorf("convert protocol to number: %w", err)
|
||||
}
|
||||
|
||||
exprs := []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(r.wgIface.Name()),
|
||||
},
|
||||
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 2},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 2,
|
||||
Data: []byte{protoNum},
|
||||
},
|
||||
&expr.Payload{
|
||||
DestRegister: 3,
|
||||
Base: expr.PayloadBaseTransportHeader,
|
||||
Offset: 2,
|
||||
Len: 2,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 3,
|
||||
Data: binaryutil.BigEndian.PutUint16(sourcePort),
|
||||
},
|
||||
}
|
||||
|
||||
exprs = append(exprs, applyPrefix(netip.PrefixFrom(localAddr, 32), false)...)
|
||||
|
||||
exprs = append(exprs,
|
||||
&expr.Immediate{
|
||||
Register: 1,
|
||||
Data: localAddr.AsSlice(),
|
||||
},
|
||||
&expr.Immediate{
|
||||
Register: 2,
|
||||
Data: binaryutil.BigEndian.PutUint16(targetPort),
|
||||
},
|
||||
&expr.NAT{
|
||||
Type: expr.NATTypeDestNAT,
|
||||
Family: uint32(nftables.TableFamilyIPv4),
|
||||
RegAddrMin: 1,
|
||||
RegProtoMin: 2,
|
||||
RegProtoMax: 0,
|
||||
},
|
||||
)
|
||||
|
||||
dnatRule := &nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameRoutingRdr],
|
||||
Exprs: exprs,
|
||||
UserData: []byte(ruleID),
|
||||
}
|
||||
r.conn.AddRule(dnatRule)
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("add inbound DNAT rule: %w", err)
|
||||
}
|
||||
|
||||
r.rules[ruleID] = dnatRule
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveInboundDNAT removes an inbound DNAT rule.
|
||||
func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
return fmt.Errorf(refreshRulesMapError, err)
|
||||
}
|
||||
|
||||
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
||||
|
||||
if rule, exists := r.rules[ruleID]; exists {
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
return fmt.Errorf("delete inbound DNAT rule %s: %w", ruleID, err)
|
||||
}
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("flush delete inbound DNAT rule: %w", err)
|
||||
}
|
||||
delete(r.rules, ruleID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// applyNetwork generates nftables expressions for networks (CIDR) or sets
|
||||
func (r *router) applyNetwork(
|
||||
network firewall.Network,
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/test"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -36,7 +37,7 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
|
||||
for _, testCase := range test.InsertRuleTestCases {
|
||||
t.Run(testCase.Name, func(t *testing.T) {
|
||||
// need fw manager to init both acl mgr and router for all chains to be present
|
||||
manager, err := Create(ifaceMock)
|
||||
manager, err := Create(ifaceMock, iface.DefaultMTU)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
})
|
||||
@@ -125,7 +126,7 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
|
||||
|
||||
for _, testCase := range test.RemoveRuleTestCases {
|
||||
t.Run(testCase.Name, func(t *testing.T) {
|
||||
manager, err := Create(ifaceMock)
|
||||
manager, err := Create(ifaceMock, iface.DefaultMTU)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
})
|
||||
@@ -197,7 +198,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
||||
|
||||
defer deleteWorkTable()
|
||||
|
||||
r, err := newRouter(workTable, ifaceMock)
|
||||
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err, "Failed to create router")
|
||||
require.NoError(t, r.init(workTable))
|
||||
|
||||
@@ -364,7 +365,7 @@ func TestNftablesCreateIpSet(t *testing.T) {
|
||||
|
||||
defer deleteWorkTable()
|
||||
|
||||
r, err := newRouter(workTable, ifaceMock)
|
||||
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err, "Failed to create router")
|
||||
require.NoError(t, r.init(workTable))
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ package nftables
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
@@ -10,6 +11,7 @@ type InterfaceState struct {
|
||||
NameStr string `json:"name"`
|
||||
WGAddress wgaddr.Address `json:"wg_address"`
|
||||
UserspaceBind bool `json:"userspace_bind"`
|
||||
MTU uint16 `json:"mtu"`
|
||||
}
|
||||
|
||||
func (i *InterfaceState) Name() string {
|
||||
@@ -33,7 +35,11 @@ func (s *ShutdownState) Name() string {
|
||||
}
|
||||
|
||||
func (s *ShutdownState) Cleanup() error {
|
||||
nft, err := Create(s.InterfaceState)
|
||||
mtu := s.InterfaceState.MTU
|
||||
if mtu == 0 {
|
||||
mtu = iface.DefaultMTU
|
||||
}
|
||||
nft, err := Create(s.InterfaceState, mtu)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create nftables manager: %w", err)
|
||||
}
|
||||
|
||||
@@ -22,6 +22,8 @@ type BaseConnTrack struct {
|
||||
PacketsRx atomic.Uint64
|
||||
BytesTx atomic.Uint64
|
||||
BytesRx atomic.Uint64
|
||||
|
||||
DNATOrigPort atomic.Uint32
|
||||
}
|
||||
|
||||
// these small methods will be inlined by the compiler
|
||||
|
||||
@@ -157,7 +157,7 @@ func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp
|
||||
return tracker
|
||||
}
|
||||
|
||||
func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, bool) {
|
||||
func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, uint16, bool) {
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
@@ -171,28 +171,30 @@ func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort ui
|
||||
|
||||
if exists {
|
||||
t.updateState(key, conn, flags, direction, size)
|
||||
return key, true
|
||||
return key, uint16(conn.DNATOrigPort.Load()), true
|
||||
}
|
||||
|
||||
return key, false
|
||||
return key, 0, false
|
||||
}
|
||||
|
||||
// TrackOutbound records an outbound TCP connection
|
||||
func (t *TCPTracker) TrackOutbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) {
|
||||
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags, nftypes.Egress, size); !exists {
|
||||
// if (inverted direction) conn is not tracked, track this direction
|
||||
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress, nil, size)
|
||||
// TrackOutbound records an outbound TCP connection and returns the original port if DNAT reversal is needed
|
||||
func (t *TCPTracker) TrackOutbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) uint16 {
|
||||
if _, origPort, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags, nftypes.Egress, size); exists {
|
||||
return origPort
|
||||
}
|
||||
// if (inverted direction) conn is not tracked, track this direction
|
||||
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress, nil, size, 0)
|
||||
return 0
|
||||
}
|
||||
|
||||
// TrackInbound processes an inbound TCP packet and updates connection state
|
||||
func (t *TCPTracker) TrackInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, ruleID []byte, size int) {
|
||||
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress, ruleID, size)
|
||||
func (t *TCPTracker) TrackInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, ruleID []byte, size int, dnatOrigPort uint16) {
|
||||
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress, ruleID, size, dnatOrigPort)
|
||||
}
|
||||
|
||||
// track is the common implementation for tracking both inbound and outbound connections
|
||||
func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, ruleID []byte, size int) {
|
||||
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags, direction, size)
|
||||
func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, ruleID []byte, size int, origPort uint16) {
|
||||
key, _, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags, direction, size)
|
||||
if exists || flags&TCPSyn == 0 {
|
||||
return
|
||||
}
|
||||
@@ -210,8 +212,13 @@ func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, fla
|
||||
|
||||
conn.tombstone.Store(false)
|
||||
conn.state.Store(int32(TCPStateNew))
|
||||
conn.DNATOrigPort.Store(uint32(origPort))
|
||||
|
||||
t.logger.Trace2("New %s TCP connection: %s", direction, key)
|
||||
if origPort != 0 {
|
||||
t.logger.Trace4("New %s TCP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort)
|
||||
} else {
|
||||
t.logger.Trace2("New %s TCP connection: %s", direction, key)
|
||||
}
|
||||
t.updateState(key, conn, flags, direction, size)
|
||||
|
||||
t.mutex.Lock()
|
||||
@@ -449,6 +456,21 @@ func (t *TCPTracker) cleanup() {
|
||||
}
|
||||
}
|
||||
|
||||
// GetConnection safely retrieves a connection state
|
||||
func (t *TCPTracker) GetConnection(srcIP netip.Addr, srcPort uint16, dstIP netip.Addr, dstPort uint16) (*TCPConnTrack, bool) {
|
||||
t.mutex.RLock()
|
||||
defer t.mutex.RUnlock()
|
||||
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
conn, exists := t.connections[key]
|
||||
return conn, exists
|
||||
}
|
||||
|
||||
// Close stops the cleanup routine and releases resources
|
||||
func (t *TCPTracker) Close() {
|
||||
t.tickerCancel()
|
||||
|
||||
@@ -603,7 +603,7 @@ func TestTCPInboundInitiatedConnection(t *testing.T) {
|
||||
serverPort := uint16(80)
|
||||
|
||||
// 1. Client sends SYN (we receive it as inbound)
|
||||
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100)
|
||||
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100, 0)
|
||||
|
||||
key := ConnKey{
|
||||
SrcIP: clientIP,
|
||||
@@ -623,12 +623,12 @@ func TestTCPInboundInitiatedConnection(t *testing.T) {
|
||||
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPSyn|TCPAck, 100)
|
||||
|
||||
// 3. Client sends ACK to complete handshake
|
||||
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100)
|
||||
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100, 0)
|
||||
require.Equal(t, TCPStateEstablished, conn.GetState(), "Connection should be ESTABLISHED after handshake completion")
|
||||
|
||||
// 4. Test data transfer
|
||||
// Client sends data
|
||||
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPPush|TCPAck, nil, 1000)
|
||||
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPPush|TCPAck, nil, 1000, 0)
|
||||
|
||||
// Server sends ACK for data
|
||||
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPAck, 100)
|
||||
@@ -637,7 +637,7 @@ func TestTCPInboundInitiatedConnection(t *testing.T) {
|
||||
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPPush|TCPAck, 1500)
|
||||
|
||||
// Client sends ACK for data
|
||||
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100)
|
||||
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100, 0)
|
||||
|
||||
// Verify state and counters
|
||||
require.Equal(t, TCPStateEstablished, conn.GetState())
|
||||
|
||||
@@ -58,20 +58,23 @@ func NewUDPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp
|
||||
return tracker
|
||||
}
|
||||
|
||||
// TrackOutbound records an outbound UDP connection
|
||||
func (t *UDPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) {
|
||||
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, nftypes.Egress, size); !exists {
|
||||
// if (inverted direction) conn is not tracked, track this direction
|
||||
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress, nil, size)
|
||||
// TrackOutbound records an outbound UDP connection and returns the original port if DNAT reversal is needed
|
||||
func (t *UDPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) uint16 {
|
||||
_, origPort, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, nftypes.Egress, size)
|
||||
if exists {
|
||||
return origPort
|
||||
}
|
||||
// if (inverted direction) conn is not tracked, track this direction
|
||||
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress, nil, size, 0)
|
||||
return 0
|
||||
}
|
||||
|
||||
// TrackInbound records an inbound UDP connection
|
||||
func (t *UDPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, ruleID []byte, size int) {
|
||||
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress, ruleID, size)
|
||||
func (t *UDPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, ruleID []byte, size int, dnatOrigPort uint16) {
|
||||
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress, ruleID, size, dnatOrigPort)
|
||||
}
|
||||
|
||||
func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, size int) (ConnKey, bool) {
|
||||
func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, size int) (ConnKey, uint16, bool) {
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
@@ -86,15 +89,15 @@ func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort
|
||||
if exists {
|
||||
conn.UpdateLastSeen()
|
||||
conn.UpdateCounters(direction, size)
|
||||
return key, true
|
||||
return key, uint16(conn.DNATOrigPort.Load()), true
|
||||
}
|
||||
|
||||
return key, false
|
||||
return key, 0, false
|
||||
}
|
||||
|
||||
// track is the common implementation for tracking both inbound and outbound connections
|
||||
func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, ruleID []byte, size int) {
|
||||
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, direction, size)
|
||||
func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, ruleID []byte, size int, origPort uint16) {
|
||||
key, _, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, direction, size)
|
||||
if exists {
|
||||
return
|
||||
}
|
||||
@@ -109,6 +112,7 @@ func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d
|
||||
SourcePort: srcPort,
|
||||
DestPort: dstPort,
|
||||
}
|
||||
conn.DNATOrigPort.Store(uint32(origPort))
|
||||
conn.UpdateLastSeen()
|
||||
conn.UpdateCounters(direction, size)
|
||||
|
||||
@@ -116,7 +120,11 @@ func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d
|
||||
t.connections[key] = conn
|
||||
t.mutex.Unlock()
|
||||
|
||||
t.logger.Trace2("New %s UDP connection: %s", direction, key)
|
||||
if origPort != 0 {
|
||||
t.logger.Trace4("New %s UDP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort)
|
||||
} else {
|
||||
t.logger.Trace2("New %s UDP connection: %s", direction, key)
|
||||
}
|
||||
t.sendEvent(nftypes.TypeStart, conn, ruleID)
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
@@ -27,7 +28,18 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
const layerTypeAll = 0
|
||||
const (
|
||||
layerTypeAll = 0
|
||||
|
||||
// ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation
|
||||
ipTCPHeaderMinSize = 40
|
||||
)
|
||||
|
||||
// serviceKey represents a protocol/port combination for netstack service registry
|
||||
type serviceKey struct {
|
||||
protocol gopacket.LayerType
|
||||
port uint16
|
||||
}
|
||||
|
||||
const (
|
||||
// EnvDisableConntrack disables the stateful filter, replies to outbound traffic won't be allowed.
|
||||
@@ -36,6 +48,9 @@ const (
|
||||
// EnvDisableUserspaceRouting disables userspace routing, to-be-routed packets will be dropped.
|
||||
EnvDisableUserspaceRouting = "NB_DISABLE_USERSPACE_ROUTING"
|
||||
|
||||
// EnvDisableMSSClamping disables TCP MSS clamping for forwarded traffic.
|
||||
EnvDisableMSSClamping = "NB_DISABLE_MSS_CLAMPING"
|
||||
|
||||
// EnvForceUserspaceRouter forces userspace routing even if native routing is available.
|
||||
EnvForceUserspaceRouter = "NB_FORCE_USERSPACE_ROUTER"
|
||||
|
||||
@@ -109,6 +124,17 @@ type Manager struct {
|
||||
dnatMappings map[netip.Addr]netip.Addr
|
||||
dnatMutex sync.RWMutex
|
||||
dnatBiMap *biDNATMap
|
||||
|
||||
portDNATEnabled atomic.Bool
|
||||
portDNATRules []portDNATRule
|
||||
portDNATMutex sync.RWMutex
|
||||
|
||||
netstackServices map[serviceKey]struct{}
|
||||
netstackServiceMutex sync.RWMutex
|
||||
|
||||
mtu uint16
|
||||
mssClampValue uint16
|
||||
mssClampEnabled bool
|
||||
}
|
||||
|
||||
// decoder for packages
|
||||
@@ -122,19 +148,21 @@ type decoder struct {
|
||||
icmp6 layers.ICMPv6
|
||||
decoded []gopacket.LayerType
|
||||
parser *gopacket.DecodingLayerParser
|
||||
|
||||
dnatOrigPort uint16
|
||||
}
|
||||
|
||||
// Create userspace firewall manager constructor
|
||||
func Create(iface common.IFaceMapper, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (*Manager, error) {
|
||||
return create(iface, nil, disableServerRoutes, flowLogger)
|
||||
func Create(iface common.IFaceMapper, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) {
|
||||
return create(iface, nil, disableServerRoutes, flowLogger, mtu)
|
||||
}
|
||||
|
||||
func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (*Manager, error) {
|
||||
func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) {
|
||||
if nativeFirewall == nil {
|
||||
return nil, errors.New("native firewall is nil")
|
||||
}
|
||||
|
||||
mgr, err := create(iface, nativeFirewall, disableServerRoutes, flowLogger)
|
||||
mgr, err := create(iface, nativeFirewall, disableServerRoutes, flowLogger, mtu)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -142,8 +170,8 @@ func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.
|
||||
return mgr, nil
|
||||
}
|
||||
|
||||
func parseCreateEnv() (bool, bool) {
|
||||
var disableConntrack, enableLocalForwarding bool
|
||||
func parseCreateEnv() (bool, bool, bool) {
|
||||
var disableConntrack, enableLocalForwarding, disableMSSClamping bool
|
||||
var err error
|
||||
if val := os.Getenv(EnvDisableConntrack); val != "" {
|
||||
disableConntrack, err = strconv.ParseBool(val)
|
||||
@@ -162,12 +190,18 @@ func parseCreateEnv() (bool, bool) {
|
||||
log.Warnf("failed to parse %s: %v", EnvEnableLocalForwarding, err)
|
||||
}
|
||||
}
|
||||
if val := os.Getenv(EnvDisableMSSClamping); val != "" {
|
||||
disableMSSClamping, err = strconv.ParseBool(val)
|
||||
if err != nil {
|
||||
log.Warnf("failed to parse %s: %v", EnvDisableMSSClamping, err)
|
||||
}
|
||||
}
|
||||
|
||||
return disableConntrack, enableLocalForwarding
|
||||
return disableConntrack, enableLocalForwarding, disableMSSClamping
|
||||
}
|
||||
|
||||
func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (*Manager, error) {
|
||||
disableConntrack, enableLocalForwarding := parseCreateEnv()
|
||||
func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) {
|
||||
disableConntrack, enableLocalForwarding, disableMSSClamping := parseCreateEnv()
|
||||
|
||||
m := &Manager{
|
||||
decoders: sync.Pool{
|
||||
@@ -196,13 +230,19 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
||||
netstack: netstack.IsEnabled(),
|
||||
localForwarding: enableLocalForwarding,
|
||||
dnatMappings: make(map[netip.Addr]netip.Addr),
|
||||
portDNATRules: []portDNATRule{},
|
||||
netstackServices: make(map[serviceKey]struct{}),
|
||||
mtu: mtu,
|
||||
}
|
||||
m.routingEnabled.Store(false)
|
||||
|
||||
if !disableMSSClamping {
|
||||
m.mssClampEnabled = true
|
||||
m.mssClampValue = mtu - ipTCPHeaderMinSize
|
||||
}
|
||||
if err := m.localipmanager.UpdateLocalIPs(iface); err != nil {
|
||||
return nil, fmt.Errorf("update local IPs: %w", err)
|
||||
}
|
||||
|
||||
if disableConntrack {
|
||||
log.Info("conntrack is disabled")
|
||||
} else {
|
||||
@@ -210,14 +250,11 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
||||
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger, flowLogger)
|
||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, flowLogger)
|
||||
}
|
||||
|
||||
// netstack needs the forwarder for local traffic
|
||||
if m.netstack && m.localForwarding {
|
||||
if err := m.initForwarder(); err != nil {
|
||||
log.Errorf("failed to initialize forwarder: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := iface.SetFilter(m); err != nil {
|
||||
return nil, fmt.Errorf("set filter: %w", err)
|
||||
}
|
||||
@@ -320,7 +357,7 @@ func (m *Manager) initForwarder() error {
|
||||
return errors.New("forwarding not supported")
|
||||
}
|
||||
|
||||
forwarder, err := forwarder.New(m.wgIface, m.logger, m.flowLogger, m.netstack)
|
||||
forwarder, err := forwarder.New(m.wgIface, m.logger, m.flowLogger, m.netstack, m.mtu)
|
||||
if err != nil {
|
||||
m.routingEnabled.Store(false)
|
||||
return fmt.Errorf("create forwarder: %w", err)
|
||||
@@ -626,11 +663,20 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
if d.decoded[1] == layers.LayerTypeUDP && m.udpHooksDrop(uint16(d.udp.DstPort), dstIP, packetData) {
|
||||
return true
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeUDP:
|
||||
if m.udpHooksDrop(uint16(d.udp.DstPort), dstIP, packetData) {
|
||||
return true
|
||||
}
|
||||
case layers.LayerTypeTCP:
|
||||
// Clamp MSS on all TCP SYN packets, including those from local IPs.
|
||||
// SNATed routed traffic may appear as local IP but still requires clamping.
|
||||
if m.mssClampEnabled {
|
||||
m.clampTCPMSS(packetData, d)
|
||||
}
|
||||
}
|
||||
|
||||
m.trackOutbound(d, srcIP, dstIP, size)
|
||||
m.trackOutbound(d, srcIP, dstIP, packetData, size)
|
||||
m.translateOutboundDNAT(packetData, d)
|
||||
|
||||
return false
|
||||
@@ -674,14 +720,117 @@ func getTCPFlags(tcp *layers.TCP) uint8 {
|
||||
return flags
|
||||
}
|
||||
|
||||
func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, size int) {
|
||||
// clampTCPMSS clamps the TCP MSS option in SYN and SYN-ACK packets to prevent fragmentation.
|
||||
// Both sides advertise their MSS during connection establishment, so we need to clamp both.
|
||||
func (m *Manager) clampTCPMSS(packetData []byte, d *decoder) bool {
|
||||
if !d.tcp.SYN {
|
||||
return false
|
||||
}
|
||||
if len(d.tcp.Options) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
mssOptionIndex := -1
|
||||
var currentMSS uint16
|
||||
for i, opt := range d.tcp.Options {
|
||||
if opt.OptionType == layers.TCPOptionKindMSS && len(opt.OptionData) == 2 {
|
||||
currentMSS = binary.BigEndian.Uint16(opt.OptionData)
|
||||
if currentMSS > m.mssClampValue {
|
||||
mssOptionIndex = i
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if mssOptionIndex == -1 {
|
||||
return false
|
||||
}
|
||||
|
||||
ipHeaderSize := int(d.ip4.IHL) * 4
|
||||
if ipHeaderSize < 20 {
|
||||
return false
|
||||
}
|
||||
|
||||
if !m.updateMSSOption(packetData, d, mssOptionIndex, ipHeaderSize) {
|
||||
return false
|
||||
}
|
||||
|
||||
m.logger.Trace2("Clamped TCP MSS from %d to %d", currentMSS, m.mssClampValue)
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *Manager) updateMSSOption(packetData []byte, d *decoder, mssOptionIndex, ipHeaderSize int) bool {
|
||||
tcpHeaderStart := ipHeaderSize
|
||||
tcpOptionsStart := tcpHeaderStart + 20
|
||||
|
||||
optOffset := tcpOptionsStart
|
||||
for j := 0; j < mssOptionIndex; j++ {
|
||||
switch d.tcp.Options[j].OptionType {
|
||||
case layers.TCPOptionKindEndList, layers.TCPOptionKindNop:
|
||||
optOffset++
|
||||
default:
|
||||
optOffset += 2 + len(d.tcp.Options[j].OptionData)
|
||||
}
|
||||
}
|
||||
|
||||
mssValueOffset := optOffset + 2
|
||||
binary.BigEndian.PutUint16(packetData[mssValueOffset:mssValueOffset+2], m.mssClampValue)
|
||||
|
||||
m.recalculateTCPChecksum(packetData, d, tcpHeaderStart)
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *Manager) recalculateTCPChecksum(packetData []byte, d *decoder, tcpHeaderStart int) {
|
||||
tcpLayer := packetData[tcpHeaderStart:]
|
||||
tcpLength := len(packetData) - tcpHeaderStart
|
||||
|
||||
tcpLayer[16] = 0
|
||||
tcpLayer[17] = 0
|
||||
|
||||
var pseudoSum uint32
|
||||
pseudoSum += uint32(d.ip4.SrcIP[0])<<8 | uint32(d.ip4.SrcIP[1])
|
||||
pseudoSum += uint32(d.ip4.SrcIP[2])<<8 | uint32(d.ip4.SrcIP[3])
|
||||
pseudoSum += uint32(d.ip4.DstIP[0])<<8 | uint32(d.ip4.DstIP[1])
|
||||
pseudoSum += uint32(d.ip4.DstIP[2])<<8 | uint32(d.ip4.DstIP[3])
|
||||
pseudoSum += uint32(d.ip4.Protocol)
|
||||
pseudoSum += uint32(tcpLength)
|
||||
|
||||
var sum uint32 = pseudoSum
|
||||
for i := 0; i < tcpLength-1; i += 2 {
|
||||
sum += uint32(tcpLayer[i])<<8 | uint32(tcpLayer[i+1])
|
||||
}
|
||||
if tcpLength%2 == 1 {
|
||||
sum += uint32(tcpLayer[tcpLength-1]) << 8
|
||||
}
|
||||
|
||||
for sum > 0xFFFF {
|
||||
sum = (sum & 0xFFFF) + (sum >> 16)
|
||||
}
|
||||
|
||||
checksum := ^uint16(sum)
|
||||
binary.BigEndian.PutUint16(tcpLayer[16:18], checksum)
|
||||
}
|
||||
|
||||
func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) {
|
||||
transport := d.decoded[1]
|
||||
switch transport {
|
||||
case layers.LayerTypeUDP:
|
||||
m.udpTracker.TrackOutbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), size)
|
||||
origPort := m.udpTracker.TrackOutbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), size)
|
||||
if origPort == 0 {
|
||||
break
|
||||
}
|
||||
if err := m.rewriteUDPPort(packetData, d, origPort, sourcePortOffset); err != nil {
|
||||
m.logger.Error1("failed to rewrite UDP port: %v", err)
|
||||
}
|
||||
case layers.LayerTypeTCP:
|
||||
flags := getTCPFlags(&d.tcp)
|
||||
m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size)
|
||||
origPort := m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size)
|
||||
if origPort == 0 {
|
||||
break
|
||||
}
|
||||
if err := m.rewriteTCPPort(packetData, d, origPort, sourcePortOffset); err != nil {
|
||||
m.logger.Error1("failed to rewrite TCP port: %v", err)
|
||||
}
|
||||
case layers.LayerTypeICMPv4:
|
||||
m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, d.icmp4.Payload, size)
|
||||
}
|
||||
@@ -691,13 +840,15 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byt
|
||||
transport := d.decoded[1]
|
||||
switch transport {
|
||||
case layers.LayerTypeUDP:
|
||||
m.udpTracker.TrackInbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), ruleID, size)
|
||||
m.udpTracker.TrackInbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), ruleID, size, d.dnatOrigPort)
|
||||
case layers.LayerTypeTCP:
|
||||
flags := getTCPFlags(&d.tcp)
|
||||
m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size)
|
||||
m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size, d.dnatOrigPort)
|
||||
case layers.LayerTypeICMPv4:
|
||||
m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, d.icmp4.Payload, size)
|
||||
}
|
||||
|
||||
d.dnatOrigPort = 0
|
||||
}
|
||||
|
||||
// udpHooksDrop checks if any UDP hooks should drop the packet
|
||||
@@ -759,10 +910,20 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// TODO: optimize port DNAT by caching matched rules in conntrack
|
||||
if translated := m.translateInboundPortDNAT(packetData, d, srcIP, dstIP); translated {
|
||||
// Re-decode after port DNAT translation to update port information
|
||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||
m.logger.Error1("failed to re-decode packet after port DNAT: %v", err)
|
||||
return true
|
||||
}
|
||||
srcIP, dstIP = m.extractIPs(d)
|
||||
}
|
||||
|
||||
if translated := m.translateInboundReverse(packetData, d); translated {
|
||||
// Re-decode after translation to get original addresses
|
||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||
m.logger.Error1("Failed to re-decode packet after reverse DNAT: %v", err)
|
||||
m.logger.Error1("failed to re-decode packet after reverse DNAT: %v", err)
|
||||
return true
|
||||
}
|
||||
srcIP, dstIP = m.extractIPs(d)
|
||||
@@ -807,9 +968,7 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packet
|
||||
return true
|
||||
}
|
||||
|
||||
// If requested we pass local traffic to internal interfaces to the forwarder.
|
||||
// netstack doesn't have an interface to forward packets to the native stack so we always need to use the forwarder.
|
||||
if m.localForwarding && (m.netstack || dstIP != m.wgIface.Address().IP) {
|
||||
if m.shouldForward(d, dstIP) {
|
||||
return m.handleForwardedLocalTraffic(packetData)
|
||||
}
|
||||
|
||||
@@ -1243,3 +1402,86 @@ func (m *Manager) DisableRouting() error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RegisterNetstackService registers a service as listening on the netstack for the given protocol and port
|
||||
func (m *Manager) RegisterNetstackService(protocol nftypes.Protocol, port uint16) {
|
||||
m.netstackServiceMutex.Lock()
|
||||
defer m.netstackServiceMutex.Unlock()
|
||||
layerType := m.protocolToLayerType(protocol)
|
||||
key := serviceKey{protocol: layerType, port: port}
|
||||
m.netstackServices[key] = struct{}{}
|
||||
m.logger.Debug3("RegisterNetstackService: registered %s:%d (layerType=%s)", protocol, port, layerType)
|
||||
m.logger.Debug1("RegisterNetstackService: current registry size: %d", len(m.netstackServices))
|
||||
}
|
||||
|
||||
// UnregisterNetstackService removes a service from the netstack registry
|
||||
func (m *Manager) UnregisterNetstackService(protocol nftypes.Protocol, port uint16) {
|
||||
m.netstackServiceMutex.Lock()
|
||||
defer m.netstackServiceMutex.Unlock()
|
||||
layerType := m.protocolToLayerType(protocol)
|
||||
key := serviceKey{protocol: layerType, port: port}
|
||||
delete(m.netstackServices, key)
|
||||
m.logger.Debug2("Unregistered netstack service on protocol %s port %d", protocol, port)
|
||||
}
|
||||
|
||||
// protocolToLayerType converts nftypes.Protocol to gopacket.LayerType for internal use
|
||||
func (m *Manager) protocolToLayerType(protocol nftypes.Protocol) gopacket.LayerType {
|
||||
switch protocol {
|
||||
case nftypes.TCP:
|
||||
return layers.LayerTypeTCP
|
||||
case nftypes.UDP:
|
||||
return layers.LayerTypeUDP
|
||||
case nftypes.ICMP:
|
||||
return layers.LayerTypeICMPv4
|
||||
default:
|
||||
return gopacket.LayerType(0) // Invalid/unknown
|
||||
}
|
||||
}
|
||||
|
||||
// shouldForward determines if a packet should be forwarded to the forwarder.
|
||||
// The forwarder handles routing packets to the native OS network stack.
|
||||
// Returns true if packet should go to the forwarder, false if it should go to netstack listeners or the native stack directly.
|
||||
func (m *Manager) shouldForward(d *decoder, dstIP netip.Addr) bool {
|
||||
// not enabled, never forward
|
||||
if !m.localForwarding {
|
||||
return false
|
||||
}
|
||||
|
||||
// netstack always needs to forward because it's lacking a native interface
|
||||
// exception for registered netstack services, those should go to netstack listeners
|
||||
if m.netstack {
|
||||
return !m.hasMatchingNetstackService(d)
|
||||
}
|
||||
|
||||
// traffic to our other local interfaces (not NetBird IP) - always forward
|
||||
if dstIP != m.wgIface.Address().IP {
|
||||
return true
|
||||
}
|
||||
|
||||
// traffic to our NetBird IP, not netstack mode - send to netstack listeners
|
||||
return false
|
||||
}
|
||||
|
||||
// hasMatchingNetstackService checks if there's a registered netstack service for this packet
|
||||
func (m *Manager) hasMatchingNetstackService(d *decoder) bool {
|
||||
if len(d.decoded) < 2 {
|
||||
return false
|
||||
}
|
||||
|
||||
var dstPort uint16
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeTCP:
|
||||
dstPort = uint16(d.tcp.DstPort)
|
||||
case layers.LayerTypeUDP:
|
||||
dstPort = uint16(d.udp.DstPort)
|
||||
default:
|
||||
return false
|
||||
}
|
||||
|
||||
key := serviceKey{protocol: d.decoded[1], port: dstPort}
|
||||
m.netstackServiceMutex.RLock()
|
||||
_, exists := m.netstackServices[key]
|
||||
m.netstackServiceMutex.RUnlock()
|
||||
|
||||
return exists
|
||||
}
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
)
|
||||
|
||||
@@ -169,7 +170,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
||||
// Create manager and basic setup
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
defer b.Cleanup(func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
@@ -209,7 +210,7 @@ func BenchmarkStateScaling(b *testing.B) {
|
||||
b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) {
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
b.Cleanup(func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
@@ -252,7 +253,7 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
|
||||
b.Run(sc.name, func(b *testing.B) {
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
b.Cleanup(func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
@@ -410,7 +411,7 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
||||
b.Run(sc.name, func(b *testing.B) {
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
b.Cleanup(func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
@@ -537,7 +538,7 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
||||
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
defer b.Cleanup(func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
@@ -620,7 +621,7 @@ func BenchmarkShortLivedConnections(b *testing.B) {
|
||||
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
defer b.Cleanup(func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
@@ -731,7 +732,7 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
||||
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
defer b.Cleanup(func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
@@ -811,7 +812,7 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
|
||||
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
defer b.Cleanup(func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
@@ -896,38 +897,6 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
|
||||
}
|
||||
}
|
||||
|
||||
// generateTCPPacketWithFlags creates a TCP packet with specific flags
|
||||
func generateTCPPacketWithFlags(b *testing.B, srcIP, dstIP net.IP, srcPort, dstPort, flags uint16) []byte {
|
||||
b.Helper()
|
||||
|
||||
ipv4 := &layers.IPv4{
|
||||
TTL: 64,
|
||||
Version: 4,
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
Protocol: layers.IPProtocolTCP,
|
||||
}
|
||||
|
||||
tcp := &layers.TCP{
|
||||
SrcPort: layers.TCPPort(srcPort),
|
||||
DstPort: layers.TCPPort(dstPort),
|
||||
}
|
||||
|
||||
// Set TCP flags
|
||||
tcp.SYN = (flags & uint16(conntrack.TCPSyn)) != 0
|
||||
tcp.ACK = (flags & uint16(conntrack.TCPAck)) != 0
|
||||
tcp.PSH = (flags & uint16(conntrack.TCPPush)) != 0
|
||||
tcp.RST = (flags & uint16(conntrack.TCPRst)) != 0
|
||||
tcp.FIN = (flags & uint16(conntrack.TCPFin)) != 0
|
||||
|
||||
require.NoError(b, tcp.SetNetworkLayerForChecksum(ipv4))
|
||||
|
||||
buf := gopacket.NewSerializeBuffer()
|
||||
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
|
||||
require.NoError(b, gopacket.SerializeLayers(buf, opts, ipv4, tcp, gopacket.Payload("test")))
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
func BenchmarkRouteACLs(b *testing.B) {
|
||||
manager := setupRoutedManager(b, "10.10.0.100/16")
|
||||
|
||||
@@ -990,3 +959,231 @@ func BenchmarkRouteACLs(b *testing.B) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkMSSClamping benchmarks the MSS clamping impact on filterOutbound.
|
||||
// This shows the overhead difference between the common case (non-SYN packets, fast path)
|
||||
// and the rare case (SYN packets that need clamping, expensive path).
|
||||
func BenchmarkMSSClamping(b *testing.B) {
|
||||
scenarios := []struct {
|
||||
name string
|
||||
description string
|
||||
genPacket func(*testing.B, net.IP, net.IP) []byte
|
||||
frequency string
|
||||
}{
|
||||
{
|
||||
name: "syn_needs_clamp",
|
||||
description: "SYN packet needing MSS clamping",
|
||||
genPacket: func(b *testing.B, src, dst net.IP) []byte {
|
||||
return generateSYNPacketWithMSS(b, src, dst, 12345, 80, 1460)
|
||||
},
|
||||
frequency: "~0.1% of traffic - EXPENSIVE",
|
||||
},
|
||||
{
|
||||
name: "syn_no_clamp_needed",
|
||||
description: "SYN packet with already-small MSS",
|
||||
genPacket: func(b *testing.B, src, dst net.IP) []byte {
|
||||
return generateSYNPacketWithMSS(b, src, dst, 12345, 80, 1200)
|
||||
},
|
||||
frequency: "~0.05% of traffic",
|
||||
},
|
||||
{
|
||||
name: "tcp_ack",
|
||||
description: "Non-SYN TCP packet (ACK, data transfer)",
|
||||
genPacket: func(b *testing.B, src, dst net.IP) []byte {
|
||||
return generateTCPPacketWithFlags(b, src, dst, 12345, 80, uint16(conntrack.TCPAck))
|
||||
},
|
||||
frequency: "~60-70% of traffic - FAST PATH",
|
||||
},
|
||||
{
|
||||
name: "tcp_psh_ack",
|
||||
description: "TCP data packet (PSH+ACK)",
|
||||
genPacket: func(b *testing.B, src, dst net.IP) []byte {
|
||||
return generateTCPPacketWithFlags(b, src, dst, 12345, 80, uint16(conntrack.TCPPush|conntrack.TCPAck))
|
||||
},
|
||||
frequency: "~10-20% of traffic - FAST PATH",
|
||||
},
|
||||
{
|
||||
name: "udp",
|
||||
description: "UDP packet",
|
||||
genPacket: func(b *testing.B, src, dst net.IP) []byte {
|
||||
return generatePacket(b, src, dst, 12345, 80, layers.IPProtocolUDP)
|
||||
},
|
||||
frequency: "~20-30% of traffic - FAST PATH",
|
||||
},
|
||||
}
|
||||
|
||||
for _, sc := range scenarios {
|
||||
b.Run(sc.name, func(b *testing.B) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(b, err)
|
||||
defer func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
}()
|
||||
|
||||
manager.mssClampEnabled = true
|
||||
manager.mssClampValue = 1240
|
||||
|
||||
srcIP := net.ParseIP("100.64.0.2")
|
||||
dstIP := net.ParseIP("8.8.8.8")
|
||||
packet := sc.genPacket(b, srcIP, dstIP)
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
manager.filterOutbound(packet, len(packet))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkMSSClampingOverhead compares overhead of MSS clamping enabled vs disabled
|
||||
// for the common case (non-SYN TCP packets).
|
||||
func BenchmarkMSSClampingOverhead(b *testing.B) {
|
||||
scenarios := []struct {
|
||||
name string
|
||||
enabled bool
|
||||
genPacket func(*testing.B, net.IP, net.IP) []byte
|
||||
}{
|
||||
{
|
||||
name: "disabled_tcp_ack",
|
||||
enabled: false,
|
||||
genPacket: func(b *testing.B, src, dst net.IP) []byte {
|
||||
return generateTCPPacketWithFlags(b, src, dst, 12345, 80, uint16(conntrack.TCPAck))
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "enabled_tcp_ack",
|
||||
enabled: true,
|
||||
genPacket: func(b *testing.B, src, dst net.IP) []byte {
|
||||
return generateTCPPacketWithFlags(b, src, dst, 12345, 80, uint16(conntrack.TCPAck))
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "disabled_syn_needs_clamp",
|
||||
enabled: false,
|
||||
genPacket: func(b *testing.B, src, dst net.IP) []byte {
|
||||
return generateSYNPacketWithMSS(b, src, dst, 12345, 80, 1460)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "enabled_syn_needs_clamp",
|
||||
enabled: true,
|
||||
genPacket: func(b *testing.B, src, dst net.IP) []byte {
|
||||
return generateSYNPacketWithMSS(b, src, dst, 12345, 80, 1460)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, sc := range scenarios {
|
||||
b.Run(sc.name, func(b *testing.B) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(b, err)
|
||||
defer func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
}()
|
||||
|
||||
manager.mssClampEnabled = sc.enabled
|
||||
if sc.enabled {
|
||||
manager.mssClampValue = 1240
|
||||
}
|
||||
|
||||
srcIP := net.ParseIP("100.64.0.2")
|
||||
dstIP := net.ParseIP("8.8.8.8")
|
||||
packet := sc.genPacket(b, srcIP, dstIP)
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
manager.filterOutbound(packet, len(packet))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkMSSClampingMemory measures memory allocations for common vs rare cases
|
||||
func BenchmarkMSSClampingMemory(b *testing.B) {
|
||||
scenarios := []struct {
|
||||
name string
|
||||
genPacket func(*testing.B, net.IP, net.IP) []byte
|
||||
}{
|
||||
{
|
||||
name: "tcp_ack_fast_path",
|
||||
genPacket: func(b *testing.B, src, dst net.IP) []byte {
|
||||
return generateTCPPacketWithFlags(b, src, dst, 12345, 80, uint16(conntrack.TCPAck))
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "syn_needs_clamp",
|
||||
genPacket: func(b *testing.B, src, dst net.IP) []byte {
|
||||
return generateSYNPacketWithMSS(b, src, dst, 12345, 80, 1460)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "udp_fast_path",
|
||||
genPacket: func(b *testing.B, src, dst net.IP) []byte {
|
||||
return generatePacket(b, src, dst, 12345, 80, layers.IPProtocolUDP)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, sc := range scenarios {
|
||||
b.Run(sc.name, func(b *testing.B) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(b, err)
|
||||
defer func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
}()
|
||||
|
||||
manager.mssClampEnabled = true
|
||||
manager.mssClampValue = 1240
|
||||
|
||||
srcIP := net.ParseIP("100.64.0.2")
|
||||
dstIP := net.ParseIP("8.8.8.8")
|
||||
packet := sc.genPacket(b, srcIP, dstIP)
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
manager.filterOutbound(packet, len(packet))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func generateSYNPacketNoMSS(b *testing.B, srcIP, dstIP net.IP, srcPort, dstPort uint16) []byte {
|
||||
b.Helper()
|
||||
|
||||
ip := &layers.IPv4{
|
||||
Version: 4,
|
||||
IHL: 5,
|
||||
TTL: 64,
|
||||
Protocol: layers.IPProtocolTCP,
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
}
|
||||
|
||||
tcp := &layers.TCP{
|
||||
SrcPort: layers.TCPPort(srcPort),
|
||||
DstPort: layers.TCPPort(dstPort),
|
||||
SYN: true,
|
||||
Seq: 1000,
|
||||
Window: 65535,
|
||||
}
|
||||
|
||||
require.NoError(b, tcp.SetNetworkLayerForChecksum(ip))
|
||||
|
||||
buf := gopacket.NewSerializeBuffer()
|
||||
opts := gopacket.SerializeOptions{
|
||||
FixLengths: true,
|
||||
ComputeChecksums: true,
|
||||
}
|
||||
|
||||
require.NoError(b, gopacket.SerializeLayers(buf, opts, ip, tcp, gopacket.Payload([]byte{})))
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/mocks"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
@@ -31,7 +32,7 @@ func TestPeerACLFiltering(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, false, flowLogger)
|
||||
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, manager)
|
||||
|
||||
@@ -616,7 +617,7 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
|
||||
},
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, false, flowLogger)
|
||||
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(tb, err)
|
||||
require.NoError(tb, manager.EnableRouting())
|
||||
require.NotNil(tb, manager)
|
||||
@@ -1462,7 +1463,7 @@ func TestRouteACLSet(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, false, flowLogger)
|
||||
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
@@ -17,9 +18,11 @@ import (
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
nbiface "github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/netflow"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
@@ -66,7 +69,7 @@ func TestManagerCreate(t *testing.T) {
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock, false, flowLogger)
|
||||
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||
if err != nil {
|
||||
t.Errorf("failed to create Manager: %v", err)
|
||||
return
|
||||
@@ -86,7 +89,7 @@ func TestManagerAddPeerFiltering(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock, false, flowLogger)
|
||||
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||
if err != nil {
|
||||
t.Errorf("failed to create Manager: %v", err)
|
||||
return
|
||||
@@ -119,7 +122,7 @@ func TestManagerDeleteRule(t *testing.T) {
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock, false, flowLogger)
|
||||
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||
if err != nil {
|
||||
t.Errorf("failed to create Manager: %v", err)
|
||||
return
|
||||
@@ -215,7 +218,7 @@ func TestAddUDPPacketHook(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
}, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
|
||||
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
|
||||
@@ -265,7 +268,7 @@ func TestManagerReset(t *testing.T) {
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock, false, flowLogger)
|
||||
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||
if err != nil {
|
||||
t.Errorf("failed to create Manager: %v", err)
|
||||
return
|
||||
@@ -304,7 +307,7 @@ func TestNotMatchByIP(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock, false, flowLogger)
|
||||
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||
if err != nil {
|
||||
t.Errorf("failed to create Manager: %v", err)
|
||||
return
|
||||
@@ -367,7 +370,7 @@ func TestRemovePacketHook(t *testing.T) {
|
||||
}
|
||||
|
||||
// creating manager instance
|
||||
manager, err := Create(iface, false, flowLogger)
|
||||
manager, err := Create(iface, false, flowLogger, nbiface.DefaultMTU)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create Manager: %s", err)
|
||||
}
|
||||
@@ -413,7 +416,7 @@ func TestRemovePacketHook(t *testing.T) {
|
||||
func TestProcessOutgoingHooks(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
}, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
|
||||
manager.udpTracker.Close()
|
||||
@@ -495,7 +498,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
manager, err := Create(ifaceMock, false, flowLogger)
|
||||
manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
time.Sleep(time.Second)
|
||||
|
||||
@@ -522,7 +525,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
||||
func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
}, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
|
||||
manager.udpTracker.Close() // Close the existing tracker
|
||||
@@ -729,7 +732,7 @@ func TestUpdateSetMerge(t *testing.T) {
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, false, flowLogger)
|
||||
manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
@@ -815,7 +818,7 @@ func TestUpdateSetDeduplication(t *testing.T) {
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, false, flowLogger)
|
||||
manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
@@ -923,3 +926,327 @@ func TestUpdateSetDeduplication(t *testing.T) {
|
||||
require.Equal(t, tc.expected, isAllowed, tc.desc)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMSSClamping(t *testing.T) {
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: netip.MustParseAddr("100.10.0.100"),
|
||||
Network: netip.MustParsePrefix("100.10.0.0/16"),
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, false, flowLogger, 1280)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
}()
|
||||
|
||||
require.True(t, manager.mssClampEnabled, "MSS clamping should be enabled by default")
|
||||
expectedMSSValue := uint16(1280 - ipTCPHeaderMinSize)
|
||||
require.Equal(t, expectedMSSValue, manager.mssClampValue, "MSS clamp value should be MTU - 40")
|
||||
|
||||
err = manager.UpdateLocalIPs()
|
||||
require.NoError(t, err)
|
||||
|
||||
srcIP := net.ParseIP("100.10.0.2")
|
||||
dstIP := net.ParseIP("8.8.8.8")
|
||||
|
||||
t.Run("SYN packet with high MSS gets clamped", func(t *testing.T) {
|
||||
highMSS := uint16(1460)
|
||||
packet := generateSYNPacketWithMSS(t, srcIP, dstIP, 12345, 80, highMSS)
|
||||
|
||||
manager.filterOutbound(packet, len(packet))
|
||||
|
||||
d := parsePacket(t, packet)
|
||||
require.Len(t, d.tcp.Options, 1, "Should have MSS option")
|
||||
require.Equal(t, uint8(layers.TCPOptionKindMSS), uint8(d.tcp.Options[0].OptionType))
|
||||
actualMSS := binary.BigEndian.Uint16(d.tcp.Options[0].OptionData)
|
||||
require.Equal(t, expectedMSSValue, actualMSS, "MSS should be clamped to MTU - 40")
|
||||
})
|
||||
|
||||
t.Run("SYN packet with low MSS unchanged", func(t *testing.T) {
|
||||
lowMSS := uint16(1200)
|
||||
packet := generateSYNPacketWithMSS(t, srcIP, dstIP, 12345, 80, lowMSS)
|
||||
|
||||
manager.filterOutbound(packet, len(packet))
|
||||
|
||||
d := parsePacket(t, packet)
|
||||
require.Len(t, d.tcp.Options, 1, "Should have MSS option")
|
||||
actualMSS := binary.BigEndian.Uint16(d.tcp.Options[0].OptionData)
|
||||
require.Equal(t, lowMSS, actualMSS, "Low MSS should not be modified")
|
||||
})
|
||||
|
||||
t.Run("SYN-ACK packet gets clamped", func(t *testing.T) {
|
||||
highMSS := uint16(1460)
|
||||
packet := generateSYNACKPacketWithMSS(t, srcIP, dstIP, 12345, 80, highMSS)
|
||||
|
||||
manager.filterOutbound(packet, len(packet))
|
||||
|
||||
d := parsePacket(t, packet)
|
||||
require.Len(t, d.tcp.Options, 1, "Should have MSS option")
|
||||
actualMSS := binary.BigEndian.Uint16(d.tcp.Options[0].OptionData)
|
||||
require.Equal(t, expectedMSSValue, actualMSS, "MSS in SYN-ACK should be clamped")
|
||||
})
|
||||
|
||||
t.Run("Non-SYN packet unchanged", func(t *testing.T) {
|
||||
packet := generateTCPPacketWithFlags(t, srcIP, dstIP, 12345, 80, uint16(conntrack.TCPAck))
|
||||
|
||||
manager.filterOutbound(packet, len(packet))
|
||||
|
||||
d := parsePacket(t, packet)
|
||||
require.Empty(t, d.tcp.Options, "ACK packet should have no options")
|
||||
})
|
||||
}
|
||||
|
||||
func generateSYNPacketWithMSS(tb testing.TB, srcIP, dstIP net.IP, srcPort, dstPort uint16, mss uint16) []byte {
|
||||
tb.Helper()
|
||||
|
||||
ipLayer := &layers.IPv4{
|
||||
Version: 4,
|
||||
TTL: 64,
|
||||
Protocol: layers.IPProtocolTCP,
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
}
|
||||
|
||||
tcpLayer := &layers.TCP{
|
||||
SrcPort: layers.TCPPort(srcPort),
|
||||
DstPort: layers.TCPPort(dstPort),
|
||||
SYN: true,
|
||||
Window: 65535,
|
||||
Options: []layers.TCPOption{
|
||||
{
|
||||
OptionType: layers.TCPOptionKindMSS,
|
||||
OptionLength: 4,
|
||||
OptionData: binary.BigEndian.AppendUint16(nil, mss),
|
||||
},
|
||||
},
|
||||
}
|
||||
err := tcpLayer.SetNetworkLayerForChecksum(ipLayer)
|
||||
require.NoError(tb, err)
|
||||
|
||||
buf := gopacket.NewSerializeBuffer()
|
||||
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
|
||||
err = gopacket.SerializeLayers(buf, opts, ipLayer, tcpLayer, gopacket.Payload([]byte{}))
|
||||
require.NoError(tb, err)
|
||||
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
func generateSYNACKPacketWithMSS(tb testing.TB, srcIP, dstIP net.IP, srcPort, dstPort uint16, mss uint16) []byte {
|
||||
tb.Helper()
|
||||
|
||||
ipLayer := &layers.IPv4{
|
||||
Version: 4,
|
||||
TTL: 64,
|
||||
Protocol: layers.IPProtocolTCP,
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
}
|
||||
|
||||
tcpLayer := &layers.TCP{
|
||||
SrcPort: layers.TCPPort(srcPort),
|
||||
DstPort: layers.TCPPort(dstPort),
|
||||
SYN: true,
|
||||
ACK: true,
|
||||
Window: 65535,
|
||||
Options: []layers.TCPOption{
|
||||
{
|
||||
OptionType: layers.TCPOptionKindMSS,
|
||||
OptionLength: 4,
|
||||
OptionData: binary.BigEndian.AppendUint16(nil, mss),
|
||||
},
|
||||
},
|
||||
}
|
||||
err := tcpLayer.SetNetworkLayerForChecksum(ipLayer)
|
||||
require.NoError(tb, err)
|
||||
|
||||
buf := gopacket.NewSerializeBuffer()
|
||||
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
|
||||
err = gopacket.SerializeLayers(buf, opts, ipLayer, tcpLayer, gopacket.Payload([]byte{}))
|
||||
require.NoError(tb, err)
|
||||
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
func generateTCPPacketWithFlags(tb testing.TB, srcIP, dstIP net.IP, srcPort, dstPort uint16, flags uint16) []byte {
|
||||
tb.Helper()
|
||||
|
||||
ipLayer := &layers.IPv4{
|
||||
Version: 4,
|
||||
TTL: 64,
|
||||
Protocol: layers.IPProtocolTCP,
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
}
|
||||
|
||||
tcpLayer := &layers.TCP{
|
||||
SrcPort: layers.TCPPort(srcPort),
|
||||
DstPort: layers.TCPPort(dstPort),
|
||||
Window: 65535,
|
||||
}
|
||||
|
||||
if flags&uint16(conntrack.TCPSyn) != 0 {
|
||||
tcpLayer.SYN = true
|
||||
}
|
||||
if flags&uint16(conntrack.TCPAck) != 0 {
|
||||
tcpLayer.ACK = true
|
||||
}
|
||||
if flags&uint16(conntrack.TCPFin) != 0 {
|
||||
tcpLayer.FIN = true
|
||||
}
|
||||
if flags&uint16(conntrack.TCPRst) != 0 {
|
||||
tcpLayer.RST = true
|
||||
}
|
||||
if flags&uint16(conntrack.TCPPush) != 0 {
|
||||
tcpLayer.PSH = true
|
||||
}
|
||||
|
||||
err := tcpLayer.SetNetworkLayerForChecksum(ipLayer)
|
||||
require.NoError(tb, err)
|
||||
|
||||
buf := gopacket.NewSerializeBuffer()
|
||||
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
|
||||
err = gopacket.SerializeLayers(buf, opts, ipLayer, tcpLayer, gopacket.Payload([]byte{}))
|
||||
require.NoError(tb, err)
|
||||
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
func TestShouldForward(t *testing.T) {
|
||||
// Set up test addresses
|
||||
wgIP := netip.MustParseAddr("100.10.0.1")
|
||||
otherIP := netip.MustParseAddr("100.10.0.2")
|
||||
|
||||
// Create test manager with mock interface
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
// Set the mock to return our test WG IP
|
||||
ifaceMock.AddressFunc = func() wgaddr.Address {
|
||||
return wgaddr.Address{IP: wgIP, Network: netip.PrefixFrom(wgIP, 24)}
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
}()
|
||||
|
||||
// Helper to create decoder with TCP packet
|
||||
createTCPDecoder := func(dstPort uint16) *decoder {
|
||||
ipv4 := &layers.IPv4{
|
||||
Version: 4,
|
||||
Protocol: layers.IPProtocolTCP,
|
||||
SrcIP: net.ParseIP("192.168.1.100"),
|
||||
DstIP: wgIP.AsSlice(),
|
||||
}
|
||||
tcp := &layers.TCP{
|
||||
SrcPort: 54321,
|
||||
DstPort: layers.TCPPort(dstPort),
|
||||
}
|
||||
|
||||
err := tcp.SetNetworkLayerForChecksum(ipv4)
|
||||
require.NoError(t, err)
|
||||
|
||||
buf := gopacket.NewSerializeBuffer()
|
||||
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
|
||||
err = gopacket.SerializeLayers(buf, opts, ipv4, tcp, gopacket.Payload("test"))
|
||||
require.NoError(t, err)
|
||||
|
||||
d := &decoder{
|
||||
decoded: []gopacket.LayerType{},
|
||||
}
|
||||
d.parser = gopacket.NewDecodingLayerParser(
|
||||
layers.LayerTypeIPv4,
|
||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||
)
|
||||
d.parser.IgnoreUnsupported = true
|
||||
|
||||
err = d.parser.DecodeLayers(buf.Bytes(), &d.decoded)
|
||||
require.NoError(t, err)
|
||||
|
||||
return d
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
localForwarding bool
|
||||
netstack bool
|
||||
dstIP netip.Addr
|
||||
serviceRegistered bool
|
||||
servicePort uint16
|
||||
expected bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "no local forwarding",
|
||||
localForwarding: false,
|
||||
netstack: true,
|
||||
dstIP: wgIP,
|
||||
expected: false,
|
||||
description: "should never forward when local forwarding disabled",
|
||||
},
|
||||
{
|
||||
name: "traffic to other local interface",
|
||||
localForwarding: true,
|
||||
netstack: false,
|
||||
dstIP: otherIP,
|
||||
expected: true,
|
||||
description: "should forward traffic to our other local interfaces (not NetBird IP)",
|
||||
},
|
||||
{
|
||||
name: "traffic to NetBird IP, no netstack",
|
||||
localForwarding: true,
|
||||
netstack: false,
|
||||
dstIP: wgIP,
|
||||
expected: false,
|
||||
description: "should send to netstack listeners (final return false path)",
|
||||
},
|
||||
{
|
||||
name: "traffic to our IP, netstack mode, no service",
|
||||
localForwarding: true,
|
||||
netstack: true,
|
||||
dstIP: wgIP,
|
||||
expected: true,
|
||||
description: "should forward when in netstack mode with no matching service",
|
||||
},
|
||||
{
|
||||
name: "traffic to our IP, netstack mode, with service",
|
||||
localForwarding: true,
|
||||
netstack: true,
|
||||
dstIP: wgIP,
|
||||
serviceRegistered: true,
|
||||
servicePort: 22,
|
||||
expected: false,
|
||||
description: "should send to netstack listeners when service is registered",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Configure manager
|
||||
manager.localForwarding = tt.localForwarding
|
||||
manager.netstack = tt.netstack
|
||||
|
||||
// Register service if needed
|
||||
if tt.serviceRegistered {
|
||||
manager.RegisterNetstackService(nftypes.TCP, tt.servicePort)
|
||||
defer manager.UnregisterNetstackService(nftypes.TCP, tt.servicePort)
|
||||
}
|
||||
|
||||
// Create decoder for the test
|
||||
decoder := createTCPDecoder(tt.servicePort)
|
||||
if !tt.serviceRegistered {
|
||||
decoder = createTCPDecoder(8080) // Use non-registered port
|
||||
}
|
||||
|
||||
// Test the method
|
||||
result := manager.shouldForward(decoder, tt.dstIP)
|
||||
require.Equal(t, tt.expected, result, tt.description)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -45,7 +45,7 @@ type Forwarder struct {
|
||||
netstack bool
|
||||
}
|
||||
|
||||
func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool) (*Forwarder, error) {
|
||||
func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool, mtu uint16) (*Forwarder, error) {
|
||||
s := stack.New(stack.Options{
|
||||
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
|
||||
TransportProtocols: []stack.TransportProtocolFactory{
|
||||
@@ -56,10 +56,6 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
||||
HandleLocal: false,
|
||||
})
|
||||
|
||||
mtu, err := iface.GetDevice().MTU()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get MTU: %w", err)
|
||||
}
|
||||
nicID := tcpip.NICID(1)
|
||||
endpoint := &endpoint{
|
||||
logger: logger,
|
||||
@@ -68,7 +64,7 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
||||
}
|
||||
|
||||
if err := s.CreateNIC(nicID, endpoint); err != nil {
|
||||
return nil, fmt.Errorf("failed to create NIC: %v", err)
|
||||
return nil, fmt.Errorf("create NIC: %v", err)
|
||||
}
|
||||
|
||||
protoAddr := tcpip.ProtocolAddress{
|
||||
|
||||
@@ -49,7 +49,7 @@ type idleConn struct {
|
||||
conn *udpPacketConn
|
||||
}
|
||||
|
||||
func newUDPForwarder(mtu int, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *udpForwarder {
|
||||
func newUDPForwarder(mtu uint16, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *udpForwarder {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
f := &udpForwarder{
|
||||
logger: logger,
|
||||
|
||||
@@ -50,6 +50,8 @@ type logMessage struct {
|
||||
arg4 any
|
||||
arg5 any
|
||||
arg6 any
|
||||
arg7 any
|
||||
arg8 any
|
||||
}
|
||||
|
||||
// Logger is a high-performance, non-blocking logger
|
||||
@@ -94,7 +96,6 @@ func (l *Logger) SetLevel(level Level) {
|
||||
log.Debugf("Set uspfilter logger loglevel to %v", levelStrings[level])
|
||||
}
|
||||
|
||||
|
||||
func (l *Logger) Error(format string) {
|
||||
if l.level.Load() >= uint32(LevelError) {
|
||||
select {
|
||||
@@ -185,6 +186,15 @@ func (l *Logger) Debug2(format string, arg1, arg2 any) {
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Logger) Debug3(format string, arg1, arg2, arg3 any) {
|
||||
if l.level.Load() >= uint32(LevelDebug) {
|
||||
select {
|
||||
case l.msgChannel <- logMessage{level: LevelDebug, format: format, arg1: arg1, arg2: arg2, arg3: arg3}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Logger) Trace1(format string, arg1 any) {
|
||||
if l.level.Load() >= uint32(LevelTrace) {
|
||||
select {
|
||||
@@ -239,6 +249,16 @@ func (l *Logger) Trace6(format string, arg1, arg2, arg3, arg4, arg5, arg6 any) {
|
||||
}
|
||||
}
|
||||
|
||||
// Trace8 logs a trace message with 8 arguments (8 placeholder in format string)
|
||||
func (l *Logger) Trace8(format string, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8 any) {
|
||||
if l.level.Load() >= uint32(LevelTrace) {
|
||||
select {
|
||||
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5, arg6: arg6, arg7: arg7, arg8: arg8}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Logger) formatMessage(buf *[]byte, msg logMessage) {
|
||||
*buf = (*buf)[:0]
|
||||
*buf = time.Now().AppendFormat(*buf, "2006-01-02T15:04:05-07:00")
|
||||
@@ -260,6 +280,12 @@ func (l *Logger) formatMessage(buf *[]byte, msg logMessage) {
|
||||
argCount++
|
||||
if msg.arg6 != nil {
|
||||
argCount++
|
||||
if msg.arg7 != nil {
|
||||
argCount++
|
||||
if msg.arg8 != nil {
|
||||
argCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -283,6 +309,10 @@ func (l *Logger) formatMessage(buf *[]byte, msg logMessage) {
|
||||
formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4, msg.arg5)
|
||||
case 6:
|
||||
formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4, msg.arg5, msg.arg6)
|
||||
case 7:
|
||||
formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4, msg.arg5, msg.arg6, msg.arg7)
|
||||
case 8:
|
||||
formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4, msg.arg5, msg.arg6, msg.arg7, msg.arg8)
|
||||
}
|
||||
|
||||
*buf = append(*buf, formatted...)
|
||||
@@ -390,4 +420,4 @@ func (l *Logger) Stop(ctx context.Context) error {
|
||||
case <-done:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,7 +5,9 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"slices"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
@@ -13,6 +15,21 @@ import (
|
||||
|
||||
var ErrIPv4Only = errors.New("only IPv4 is supported for DNAT")
|
||||
|
||||
var (
|
||||
errInvalidIPHeaderLength = errors.New("invalid IP header length")
|
||||
)
|
||||
|
||||
const (
|
||||
// Port offsets in TCP/UDP headers
|
||||
sourcePortOffset = 0
|
||||
destinationPortOffset = 2
|
||||
|
||||
// IP address offsets in IPv4 header
|
||||
sourceIPOffset = 12
|
||||
destinationIPOffset = 16
|
||||
)
|
||||
|
||||
// ipv4Checksum calculates IPv4 header checksum.
|
||||
func ipv4Checksum(header []byte) uint16 {
|
||||
if len(header) < 20 {
|
||||
return 0
|
||||
@@ -52,6 +69,7 @@ func ipv4Checksum(header []byte) uint16 {
|
||||
return ^uint16(sum)
|
||||
}
|
||||
|
||||
// icmpChecksum calculates ICMP checksum.
|
||||
func icmpChecksum(data []byte) uint16 {
|
||||
var sum1, sum2, sum3, sum4 uint32
|
||||
i := 0
|
||||
@@ -89,11 +107,21 @@ func icmpChecksum(data []byte) uint16 {
|
||||
return ^uint16(sum)
|
||||
}
|
||||
|
||||
// biDNATMap maintains bidirectional DNAT mappings.
|
||||
type biDNATMap struct {
|
||||
forward map[netip.Addr]netip.Addr
|
||||
reverse map[netip.Addr]netip.Addr
|
||||
}
|
||||
|
||||
// portDNATRule represents a port-specific DNAT rule.
|
||||
type portDNATRule struct {
|
||||
protocol gopacket.LayerType
|
||||
origPort uint16
|
||||
targetPort uint16
|
||||
targetIP netip.Addr
|
||||
}
|
||||
|
||||
// newBiDNATMap creates a new bidirectional DNAT mapping structure.
|
||||
func newBiDNATMap() *biDNATMap {
|
||||
return &biDNATMap{
|
||||
forward: make(map[netip.Addr]netip.Addr),
|
||||
@@ -101,11 +129,13 @@ func newBiDNATMap() *biDNATMap {
|
||||
}
|
||||
}
|
||||
|
||||
// set adds a bidirectional DNAT mapping between original and translated addresses.
|
||||
func (b *biDNATMap) set(original, translated netip.Addr) {
|
||||
b.forward[original] = translated
|
||||
b.reverse[translated] = original
|
||||
}
|
||||
|
||||
// delete removes a bidirectional DNAT mapping for the given original address.
|
||||
func (b *biDNATMap) delete(original netip.Addr) {
|
||||
if translated, exists := b.forward[original]; exists {
|
||||
delete(b.forward, original)
|
||||
@@ -113,19 +143,25 @@ func (b *biDNATMap) delete(original netip.Addr) {
|
||||
}
|
||||
}
|
||||
|
||||
// getTranslated returns the translated address for a given original address.
|
||||
func (b *biDNATMap) getTranslated(original netip.Addr) (netip.Addr, bool) {
|
||||
translated, exists := b.forward[original]
|
||||
return translated, exists
|
||||
}
|
||||
|
||||
// getOriginal returns the original address for a given translated address.
|
||||
func (b *biDNATMap) getOriginal(translated netip.Addr) (netip.Addr, bool) {
|
||||
original, exists := b.reverse[translated]
|
||||
return original, exists
|
||||
}
|
||||
|
||||
// AddInternalDNATMapping adds a 1:1 IP address mapping for internal DNAT translation.
|
||||
func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr) error {
|
||||
if !originalAddr.IsValid() || !translatedAddr.IsValid() {
|
||||
return fmt.Errorf("invalid IP addresses")
|
||||
if !originalAddr.IsValid() {
|
||||
return fmt.Errorf("invalid original IP address")
|
||||
}
|
||||
if !translatedAddr.IsValid() {
|
||||
return fmt.Errorf("invalid translated IP address")
|
||||
}
|
||||
|
||||
if m.localipmanager.IsLocalIP(translatedAddr) {
|
||||
@@ -135,7 +171,6 @@ func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr
|
||||
m.dnatMutex.Lock()
|
||||
defer m.dnatMutex.Unlock()
|
||||
|
||||
// Initialize both maps together if either is nil
|
||||
if m.dnatMappings == nil || m.dnatBiMap == nil {
|
||||
m.dnatMappings = make(map[netip.Addr]netip.Addr)
|
||||
m.dnatBiMap = newBiDNATMap()
|
||||
@@ -151,7 +186,7 @@ func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveInternalDNATMapping removes a 1:1 IP address mapping
|
||||
// RemoveInternalDNATMapping removes a 1:1 IP address mapping.
|
||||
func (m *Manager) RemoveInternalDNATMapping(originalAddr netip.Addr) error {
|
||||
m.dnatMutex.Lock()
|
||||
defer m.dnatMutex.Unlock()
|
||||
@@ -169,7 +204,7 @@ func (m *Manager) RemoveInternalDNATMapping(originalAddr netip.Addr) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// getDNATTranslation returns the translated address if a mapping exists
|
||||
// getDNATTranslation returns the translated address if a mapping exists.
|
||||
func (m *Manager) getDNATTranslation(addr netip.Addr) (netip.Addr, bool) {
|
||||
if !m.dnatEnabled.Load() {
|
||||
return addr, false
|
||||
@@ -181,7 +216,7 @@ func (m *Manager) getDNATTranslation(addr netip.Addr) (netip.Addr, bool) {
|
||||
return translated, exists
|
||||
}
|
||||
|
||||
// findReverseDNATMapping finds original address for return traffic
|
||||
// findReverseDNATMapping finds original address for return traffic.
|
||||
func (m *Manager) findReverseDNATMapping(translatedAddr netip.Addr) (netip.Addr, bool) {
|
||||
if !m.dnatEnabled.Load() {
|
||||
return translatedAddr, false
|
||||
@@ -193,16 +228,12 @@ func (m *Manager) findReverseDNATMapping(translatedAddr netip.Addr) (netip.Addr,
|
||||
return original, exists
|
||||
}
|
||||
|
||||
// translateOutboundDNAT applies DNAT translation to outbound packets
|
||||
// translateOutboundDNAT applies DNAT translation to outbound packets.
|
||||
func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool {
|
||||
if !m.dnatEnabled.Load() {
|
||||
return false
|
||||
}
|
||||
|
||||
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
|
||||
return false
|
||||
}
|
||||
|
||||
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
|
||||
|
||||
translatedIP, exists := m.getDNATTranslation(dstIP)
|
||||
@@ -210,8 +241,8 @@ func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
if err := m.rewritePacketDestination(packetData, d, translatedIP); err != nil {
|
||||
m.logger.Error1("Failed to rewrite packet destination: %v", err)
|
||||
if err := m.rewritePacketIP(packetData, d, translatedIP, destinationIPOffset); err != nil {
|
||||
m.logger.Error1("failed to rewrite packet destination: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -219,16 +250,12 @@ func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// translateInboundReverse applies reverse DNAT to inbound return traffic
|
||||
// translateInboundReverse applies reverse DNAT to inbound return traffic.
|
||||
func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
|
||||
if !m.dnatEnabled.Load() {
|
||||
return false
|
||||
}
|
||||
|
||||
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
|
||||
return false
|
||||
}
|
||||
|
||||
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
|
||||
|
||||
originalIP, exists := m.findReverseDNATMapping(srcIP)
|
||||
@@ -236,8 +263,8 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
if err := m.rewritePacketSource(packetData, d, originalIP); err != nil {
|
||||
m.logger.Error1("Failed to rewrite packet source: %v", err)
|
||||
if err := m.rewritePacketIP(packetData, d, originalIP, sourceIPOffset); err != nil {
|
||||
m.logger.Error1("failed to rewrite packet source: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -245,21 +272,21 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// rewritePacketDestination replaces destination IP in the packet
|
||||
func (m *Manager) rewritePacketDestination(packetData []byte, d *decoder, newIP netip.Addr) error {
|
||||
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() {
|
||||
// rewritePacketIP replaces an IP address (source or destination) in the packet and updates checksums.
|
||||
func (m *Manager) rewritePacketIP(packetData []byte, d *decoder, newIP netip.Addr, ipOffset int) error {
|
||||
if !newIP.Is4() {
|
||||
return ErrIPv4Only
|
||||
}
|
||||
|
||||
var oldDst [4]byte
|
||||
copy(oldDst[:], packetData[16:20])
|
||||
newDst := newIP.As4()
|
||||
var oldIP [4]byte
|
||||
copy(oldIP[:], packetData[ipOffset:ipOffset+4])
|
||||
newIPBytes := newIP.As4()
|
||||
|
||||
copy(packetData[16:20], newDst[:])
|
||||
copy(packetData[ipOffset:ipOffset+4], newIPBytes[:])
|
||||
|
||||
ipHeaderLen := int(d.ip4.IHL) * 4
|
||||
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
|
||||
return fmt.Errorf("invalid IP header length")
|
||||
return errInvalidIPHeaderLength
|
||||
}
|
||||
|
||||
binary.BigEndian.PutUint16(packetData[10:12], 0)
|
||||
@@ -269,44 +296,9 @@ func (m *Manager) rewritePacketDestination(packetData []byte, d *decoder, newIP
|
||||
if len(d.decoded) > 1 {
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeTCP:
|
||||
m.updateTCPChecksum(packetData, ipHeaderLen, oldDst[:], newDst[:])
|
||||
m.updateTCPChecksum(packetData, ipHeaderLen, oldIP[:], newIPBytes[:])
|
||||
case layers.LayerTypeUDP:
|
||||
m.updateUDPChecksum(packetData, ipHeaderLen, oldDst[:], newDst[:])
|
||||
case layers.LayerTypeICMPv4:
|
||||
m.updateICMPChecksum(packetData, ipHeaderLen)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// rewritePacketSource replaces the source IP address in the packet
|
||||
func (m *Manager) rewritePacketSource(packetData []byte, d *decoder, newIP netip.Addr) error {
|
||||
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() {
|
||||
return ErrIPv4Only
|
||||
}
|
||||
|
||||
var oldSrc [4]byte
|
||||
copy(oldSrc[:], packetData[12:16])
|
||||
newSrc := newIP.As4()
|
||||
|
||||
copy(packetData[12:16], newSrc[:])
|
||||
|
||||
ipHeaderLen := int(d.ip4.IHL) * 4
|
||||
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
|
||||
return fmt.Errorf("invalid IP header length")
|
||||
}
|
||||
|
||||
binary.BigEndian.PutUint16(packetData[10:12], 0)
|
||||
ipChecksum := ipv4Checksum(packetData[:ipHeaderLen])
|
||||
binary.BigEndian.PutUint16(packetData[10:12], ipChecksum)
|
||||
|
||||
if len(d.decoded) > 1 {
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeTCP:
|
||||
m.updateTCPChecksum(packetData, ipHeaderLen, oldSrc[:], newSrc[:])
|
||||
case layers.LayerTypeUDP:
|
||||
m.updateUDPChecksum(packetData, ipHeaderLen, oldSrc[:], newSrc[:])
|
||||
m.updateUDPChecksum(packetData, ipHeaderLen, oldIP[:], newIPBytes[:])
|
||||
case layers.LayerTypeICMPv4:
|
||||
m.updateICMPChecksum(packetData, ipHeaderLen)
|
||||
}
|
||||
@@ -315,6 +307,7 @@ func (m *Manager) rewritePacketSource(packetData []byte, d *decoder, newIP netip
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateTCPChecksum updates TCP checksum after IP address change per RFC 1624.
|
||||
func (m *Manager) updateTCPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) {
|
||||
tcpStart := ipHeaderLen
|
||||
if len(packetData) < tcpStart+18 {
|
||||
@@ -327,6 +320,7 @@ func (m *Manager) updateTCPChecksum(packetData []byte, ipHeaderLen int, oldIP, n
|
||||
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
|
||||
}
|
||||
|
||||
// updateUDPChecksum updates UDP checksum after IP address change per RFC 1624.
|
||||
func (m *Manager) updateUDPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) {
|
||||
udpStart := ipHeaderLen
|
||||
if len(packetData) < udpStart+8 {
|
||||
@@ -344,6 +338,7 @@ func (m *Manager) updateUDPChecksum(packetData []byte, ipHeaderLen int, oldIP, n
|
||||
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
|
||||
}
|
||||
|
||||
// updateICMPChecksum recalculates ICMP checksum after packet modification.
|
||||
func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) {
|
||||
icmpStart := ipHeaderLen
|
||||
if len(packetData) < icmpStart+8 {
|
||||
@@ -356,7 +351,7 @@ func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) {
|
||||
binary.BigEndian.PutUint16(icmpData[2:4], checksum)
|
||||
}
|
||||
|
||||
// incrementalUpdate performs incremental checksum update per RFC 1624
|
||||
// incrementalUpdate performs incremental checksum update per RFC 1624.
|
||||
func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 {
|
||||
sum := uint32(^oldChecksum)
|
||||
|
||||
@@ -391,7 +386,7 @@ func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 {
|
||||
return ^uint16(sum)
|
||||
}
|
||||
|
||||
// AddDNATRule adds a DNAT rule (delegates to native firewall for port forwarding)
|
||||
// AddDNATRule adds outbound DNAT rule for forwarding external traffic to NetBird network.
|
||||
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
||||
if m.nativeFirewall == nil {
|
||||
return nil, errNatNotSupported
|
||||
@@ -399,10 +394,184 @@ func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error)
|
||||
return m.nativeFirewall.AddDNATRule(rule)
|
||||
}
|
||||
|
||||
// DeleteDNATRule deletes a DNAT rule (delegates to native firewall)
|
||||
// DeleteDNATRule deletes outbound DNAT rule.
|
||||
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
|
||||
if m.nativeFirewall == nil {
|
||||
return errNatNotSupported
|
||||
}
|
||||
return m.nativeFirewall.DeleteDNATRule(rule)
|
||||
}
|
||||
|
||||
// addPortRedirection adds a port redirection rule.
|
||||
func (m *Manager) addPortRedirection(targetIP netip.Addr, protocol gopacket.LayerType, sourcePort, targetPort uint16) error {
|
||||
m.portDNATMutex.Lock()
|
||||
defer m.portDNATMutex.Unlock()
|
||||
|
||||
rule := portDNATRule{
|
||||
protocol: protocol,
|
||||
origPort: sourcePort,
|
||||
targetPort: targetPort,
|
||||
targetIP: targetIP,
|
||||
}
|
||||
|
||||
m.portDNATRules = append(m.portDNATRules, rule)
|
||||
m.portDNATEnabled.Store(true)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
|
||||
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
var layerType gopacket.LayerType
|
||||
switch protocol {
|
||||
case firewall.ProtocolTCP:
|
||||
layerType = layers.LayerTypeTCP
|
||||
case firewall.ProtocolUDP:
|
||||
layerType = layers.LayerTypeUDP
|
||||
default:
|
||||
return fmt.Errorf("unsupported protocol: %s", protocol)
|
||||
}
|
||||
|
||||
return m.addPortRedirection(localAddr, layerType, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
// removePortRedirection removes a port redirection rule.
|
||||
func (m *Manager) removePortRedirection(targetIP netip.Addr, protocol gopacket.LayerType, sourcePort, targetPort uint16) error {
|
||||
m.portDNATMutex.Lock()
|
||||
defer m.portDNATMutex.Unlock()
|
||||
|
||||
m.portDNATRules = slices.DeleteFunc(m.portDNATRules, func(rule portDNATRule) bool {
|
||||
return rule.protocol == protocol && rule.origPort == sourcePort && rule.targetPort == targetPort && rule.targetIP.Compare(targetIP) == 0
|
||||
})
|
||||
|
||||
if len(m.portDNATRules) == 0 {
|
||||
m.portDNATEnabled.Store(false)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveInboundDNAT removes an inbound DNAT rule.
|
||||
func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
var layerType gopacket.LayerType
|
||||
switch protocol {
|
||||
case firewall.ProtocolTCP:
|
||||
layerType = layers.LayerTypeTCP
|
||||
case firewall.ProtocolUDP:
|
||||
layerType = layers.LayerTypeUDP
|
||||
default:
|
||||
return fmt.Errorf("unsupported protocol: %s", protocol)
|
||||
}
|
||||
|
||||
return m.removePortRedirection(localAddr, layerType, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
// translateInboundPortDNAT applies port-specific DNAT translation to inbound packets.
|
||||
func (m *Manager) translateInboundPortDNAT(packetData []byte, d *decoder, srcIP, dstIP netip.Addr) bool {
|
||||
if !m.portDNATEnabled.Load() {
|
||||
return false
|
||||
}
|
||||
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeTCP:
|
||||
dstPort := uint16(d.tcp.DstPort)
|
||||
return m.applyPortRule(packetData, d, srcIP, dstIP, dstPort, layers.LayerTypeTCP, m.rewriteTCPPort)
|
||||
case layers.LayerTypeUDP:
|
||||
dstPort := uint16(d.udp.DstPort)
|
||||
return m.applyPortRule(packetData, d, netip.Addr{}, dstIP, dstPort, layers.LayerTypeUDP, m.rewriteUDPPort)
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
type portRewriteFunc func(packetData []byte, d *decoder, newPort uint16, portOffset int) error
|
||||
|
||||
func (m *Manager) applyPortRule(packetData []byte, d *decoder, srcIP, dstIP netip.Addr, port uint16, protocol gopacket.LayerType, rewriteFn portRewriteFunc) bool {
|
||||
m.portDNATMutex.RLock()
|
||||
defer m.portDNATMutex.RUnlock()
|
||||
|
||||
for _, rule := range m.portDNATRules {
|
||||
if rule.protocol != protocol || rule.targetIP.Compare(dstIP) != 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if rule.targetPort == port && rule.targetIP.Compare(srcIP) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
if rule.origPort != port {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := rewriteFn(packetData, d, rule.targetPort, destinationPortOffset); err != nil {
|
||||
m.logger.Error1("failed to rewrite port: %v", err)
|
||||
return false
|
||||
}
|
||||
d.dnatOrigPort = rule.origPort
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// rewriteTCPPort rewrites a TCP port (source or destination) and updates checksum.
|
||||
func (m *Manager) rewriteTCPPort(packetData []byte, d *decoder, newPort uint16, portOffset int) error {
|
||||
ipHeaderLen := int(d.ip4.IHL) * 4
|
||||
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
|
||||
return errInvalidIPHeaderLength
|
||||
}
|
||||
|
||||
tcpStart := ipHeaderLen
|
||||
if len(packetData) < tcpStart+4 {
|
||||
return fmt.Errorf("packet too short for TCP header")
|
||||
}
|
||||
|
||||
portStart := tcpStart + portOffset
|
||||
oldPort := binary.BigEndian.Uint16(packetData[portStart : portStart+2])
|
||||
binary.BigEndian.PutUint16(packetData[portStart:portStart+2], newPort)
|
||||
|
||||
if len(packetData) >= tcpStart+18 {
|
||||
checksumOffset := tcpStart + 16
|
||||
oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2])
|
||||
|
||||
var oldPortBytes, newPortBytes [2]byte
|
||||
binary.BigEndian.PutUint16(oldPortBytes[:], oldPort)
|
||||
binary.BigEndian.PutUint16(newPortBytes[:], newPort)
|
||||
|
||||
newChecksum := incrementalUpdate(oldChecksum, oldPortBytes[:], newPortBytes[:])
|
||||
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// rewriteUDPPort rewrites a UDP port (source or destination) and updates checksum.
|
||||
func (m *Manager) rewriteUDPPort(packetData []byte, d *decoder, newPort uint16, portOffset int) error {
|
||||
ipHeaderLen := int(d.ip4.IHL) * 4
|
||||
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
|
||||
return errInvalidIPHeaderLength
|
||||
}
|
||||
|
||||
udpStart := ipHeaderLen
|
||||
if len(packetData) < udpStart+8 {
|
||||
return fmt.Errorf("packet too short for UDP header")
|
||||
}
|
||||
|
||||
portStart := udpStart + portOffset
|
||||
oldPort := binary.BigEndian.Uint16(packetData[portStart : portStart+2])
|
||||
binary.BigEndian.PutUint16(packetData[portStart:portStart+2], newPort)
|
||||
|
||||
checksumOffset := udpStart + 6
|
||||
if len(packetData) >= udpStart+8 {
|
||||
oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2])
|
||||
if oldChecksum != 0 {
|
||||
var oldPortBytes, newPortBytes [2]byte
|
||||
binary.BigEndian.PutUint16(oldPortBytes[:], oldPort)
|
||||
binary.BigEndian.PutUint16(newPortBytes[:], newPort)
|
||||
|
||||
newChecksum := incrementalUpdate(oldChecksum, oldPortBytes[:], newPortBytes[:])
|
||||
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
)
|
||||
|
||||
@@ -65,7 +66,7 @@ func BenchmarkDNATTranslation(b *testing.B) {
|
||||
b.Run(sc.name, func(b *testing.B) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(b, err)
|
||||
defer func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
@@ -125,7 +126,7 @@ func BenchmarkDNATTranslation(b *testing.B) {
|
||||
func BenchmarkDNATConcurrency(b *testing.B) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(b, err)
|
||||
defer func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
@@ -197,7 +198,7 @@ func BenchmarkDNATScaling(b *testing.B) {
|
||||
b.Run(fmt.Sprintf("mappings_%d", count), func(b *testing.B) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(b, err)
|
||||
defer func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
@@ -309,7 +310,7 @@ func BenchmarkChecksumUpdate(b *testing.B) {
|
||||
func BenchmarkDNATMemoryAllocations(b *testing.B) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(b, err)
|
||||
defer func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
@@ -414,3 +415,127 @@ func BenchmarkChecksumOptimizations(b *testing.B) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// BenchmarkPortDNAT measures the performance of port DNAT operations
|
||||
func BenchmarkPortDNAT(b *testing.B) {
|
||||
scenarios := []struct {
|
||||
name string
|
||||
proto layers.IPProtocol
|
||||
setupDNAT bool
|
||||
useMatchPort bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "tcp_inbound_dnat_match",
|
||||
proto: layers.IPProtocolTCP,
|
||||
setupDNAT: true,
|
||||
useMatchPort: true,
|
||||
description: "TCP inbound port DNAT translation (22 → 22022)",
|
||||
},
|
||||
{
|
||||
name: "tcp_inbound_dnat_nomatch",
|
||||
proto: layers.IPProtocolTCP,
|
||||
setupDNAT: true,
|
||||
useMatchPort: false,
|
||||
description: "TCP inbound with DNAT configured but no port match",
|
||||
},
|
||||
{
|
||||
name: "tcp_inbound_no_dnat",
|
||||
proto: layers.IPProtocolTCP,
|
||||
setupDNAT: false,
|
||||
useMatchPort: false,
|
||||
description: "TCP inbound without DNAT (baseline)",
|
||||
},
|
||||
{
|
||||
name: "udp_inbound_dnat_match",
|
||||
proto: layers.IPProtocolUDP,
|
||||
setupDNAT: true,
|
||||
useMatchPort: true,
|
||||
description: "UDP inbound port DNAT translation (5353 → 22054)",
|
||||
},
|
||||
{
|
||||
name: "udp_inbound_dnat_nomatch",
|
||||
proto: layers.IPProtocolUDP,
|
||||
setupDNAT: true,
|
||||
useMatchPort: false,
|
||||
description: "UDP inbound with DNAT configured but no port match",
|
||||
},
|
||||
{
|
||||
name: "udp_inbound_no_dnat",
|
||||
proto: layers.IPProtocolUDP,
|
||||
setupDNAT: false,
|
||||
useMatchPort: false,
|
||||
description: "UDP inbound without DNAT (baseline)",
|
||||
},
|
||||
}
|
||||
|
||||
for _, sc := range scenarios {
|
||||
b.Run(sc.name, func(b *testing.B) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(b, err)
|
||||
defer func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
}()
|
||||
|
||||
// Set logger to error level to reduce noise during benchmarking
|
||||
manager.SetLogLevel(log.ErrorLevel)
|
||||
defer func() {
|
||||
// Restore to info level after benchmark
|
||||
manager.SetLogLevel(log.InfoLevel)
|
||||
}()
|
||||
|
||||
localAddr := netip.MustParseAddr("100.0.2.175")
|
||||
clientIP := netip.MustParseAddr("100.0.169.249")
|
||||
|
||||
var origPort, targetPort, testPort uint16
|
||||
if sc.proto == layers.IPProtocolTCP {
|
||||
origPort, targetPort = 22, 22022
|
||||
} else {
|
||||
origPort, targetPort = 5353, 22054
|
||||
}
|
||||
|
||||
if sc.useMatchPort {
|
||||
testPort = origPort
|
||||
} else {
|
||||
testPort = 443 // Different port
|
||||
}
|
||||
|
||||
// Setup port DNAT mapping if needed
|
||||
if sc.setupDNAT {
|
||||
err := manager.AddInboundDNAT(localAddr, protocolToFirewall(sc.proto), origPort, targetPort)
|
||||
require.NoError(b, err)
|
||||
}
|
||||
|
||||
// Pre-establish inbound connection for outbound reverse test
|
||||
if sc.setupDNAT && sc.useMatchPort {
|
||||
inboundPacket := generateDNATTestPacket(b, clientIP, localAddr, sc.proto, 54321, origPort)
|
||||
manager.filterInbound(inboundPacket, 0)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
// Benchmark inbound DNAT translation
|
||||
b.Run("inbound", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
// Create fresh packet each time
|
||||
packet := generateDNATTestPacket(b, clientIP, localAddr, sc.proto, 54321, testPort)
|
||||
manager.filterInbound(packet, 0)
|
||||
}
|
||||
})
|
||||
|
||||
// Benchmark outbound reverse DNAT translation (only if DNAT is set up and port matches)
|
||||
if sc.setupDNAT && sc.useMatchPort {
|
||||
b.Run("outbound_reverse", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
// Create fresh return packet (from target port)
|
||||
packet := generateDNATTestPacket(b, localAddr, clientIP, sc.proto, targetPort, 54321)
|
||||
manager.filterOutbound(packet, 0)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
85
client/firewall/uspfilter/nat_stateful_test.go
Normal file
85
client/firewall/uspfilter/nat_stateful_test.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
)
|
||||
|
||||
// TestPortDNATBasic tests basic port DNAT functionality
|
||||
func TestPortDNATBasic(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
}()
|
||||
|
||||
// Define peer IPs
|
||||
peerA := netip.MustParseAddr("100.10.0.50")
|
||||
peerB := netip.MustParseAddr("100.10.0.51")
|
||||
|
||||
// Add SSH port redirection rule for peer B (the target)
|
||||
err = manager.addPortRedirection(peerB, layers.LayerTypeTCP, 22, 22022)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Scenario: Peer A connects to Peer B on port 22 (should get NAT)
|
||||
packetAtoB := generateDNATTestPacket(t, peerA, peerB, layers.IPProtocolTCP, 54321, 22)
|
||||
d := parsePacket(t, packetAtoB)
|
||||
translatedAtoB := manager.translateInboundPortDNAT(packetAtoB, d, peerA, peerB)
|
||||
require.True(t, translatedAtoB, "Peer A to Peer B should be translated (NAT applied)")
|
||||
|
||||
// Verify port was translated to 22022
|
||||
d = parsePacket(t, packetAtoB)
|
||||
require.Equal(t, uint16(22022), uint16(d.tcp.DstPort), "Port should be rewritten to 22022")
|
||||
|
||||
// Scenario: Return traffic from Peer B to Peer A should NOT be translated
|
||||
// (prevents double NAT - original port stored in conntrack)
|
||||
returnPacket := generateDNATTestPacket(t, peerB, peerA, layers.IPProtocolTCP, 22022, 54321)
|
||||
d2 := parsePacket(t, returnPacket)
|
||||
translatedReturn := manager.translateInboundPortDNAT(returnPacket, d2, peerB, peerA)
|
||||
require.False(t, translatedReturn, "Return traffic from same IP should not be translated")
|
||||
}
|
||||
|
||||
// TestPortDNATMultipleRules tests multiple port DNAT rules
|
||||
func TestPortDNATMultipleRules(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
}()
|
||||
|
||||
// Define peer IPs
|
||||
peerA := netip.MustParseAddr("100.10.0.50")
|
||||
peerB := netip.MustParseAddr("100.10.0.51")
|
||||
|
||||
// Add SSH port redirection rules for both peers
|
||||
err = manager.addPortRedirection(peerA, layers.LayerTypeTCP, 22, 22022)
|
||||
require.NoError(t, err)
|
||||
err = manager.addPortRedirection(peerB, layers.LayerTypeTCP, 22, 22022)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test traffic to peer B gets translated
|
||||
packetToB := generateDNATTestPacket(t, peerA, peerB, layers.IPProtocolTCP, 54321, 22)
|
||||
d1 := parsePacket(t, packetToB)
|
||||
translatedToB := manager.translateInboundPortDNAT(packetToB, d1, peerA, peerB)
|
||||
require.True(t, translatedToB, "Traffic to peer B should be translated")
|
||||
d1 = parsePacket(t, packetToB)
|
||||
require.Equal(t, uint16(22022), uint16(d1.tcp.DstPort), "Port should be 22022")
|
||||
|
||||
// Test traffic to peer A gets translated
|
||||
packetToA := generateDNATTestPacket(t, peerB, peerA, layers.IPProtocolTCP, 54322, 22)
|
||||
d2 := parsePacket(t, packetToA)
|
||||
translatedToA := manager.translateInboundPortDNAT(packetToA, d2, peerB, peerA)
|
||||
require.True(t, translatedToA, "Traffic to peer A should be translated")
|
||||
d2 = parsePacket(t, packetToA)
|
||||
require.Equal(t, uint16(22022), uint16(d2.tcp.DstPort), "Port should be 22022")
|
||||
}
|
||||
@@ -8,6 +8,8 @@ import (
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
)
|
||||
|
||||
@@ -15,7 +17,7 @@ import (
|
||||
func TestDNATTranslationCorrectness(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
@@ -99,7 +101,7 @@ func parsePacket(t testing.TB, packetData []byte) *decoder {
|
||||
func TestDNATMappingManagement(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
@@ -143,3 +145,111 @@ func TestDNATMappingManagement(t *testing.T) {
|
||||
err = manager.RemoveInternalDNATMapping(originalIP)
|
||||
require.Error(t, err, "Should error when removing non-existent mapping")
|
||||
}
|
||||
|
||||
func TestInboundPortDNAT(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
}()
|
||||
|
||||
localAddr := netip.MustParseAddr("100.0.2.175")
|
||||
clientIP := netip.MustParseAddr("100.0.169.249")
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
protocol layers.IPProtocol
|
||||
sourcePort uint16
|
||||
targetPort uint16
|
||||
}{
|
||||
{"TCP SSH", layers.IPProtocolTCP, 22, 22022},
|
||||
{"UDP DNS", layers.IPProtocolUDP, 5353, 22054},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := manager.AddInboundDNAT(localAddr, protocolToFirewall(tc.protocol), tc.sourcePort, tc.targetPort)
|
||||
require.NoError(t, err)
|
||||
|
||||
inboundPacket := generateDNATTestPacket(t, clientIP, localAddr, tc.protocol, 54321, tc.sourcePort)
|
||||
d := parsePacket(t, inboundPacket)
|
||||
|
||||
translated := manager.translateInboundPortDNAT(inboundPacket, d, clientIP, localAddr)
|
||||
require.True(t, translated, "Inbound packet should be translated")
|
||||
|
||||
d = parsePacket(t, inboundPacket)
|
||||
var dstPort uint16
|
||||
switch tc.protocol {
|
||||
case layers.IPProtocolTCP:
|
||||
dstPort = uint16(d.tcp.DstPort)
|
||||
case layers.IPProtocolUDP:
|
||||
dstPort = uint16(d.udp.DstPort)
|
||||
}
|
||||
|
||||
require.Equal(t, tc.targetPort, dstPort, "Destination port should be rewritten to target port")
|
||||
|
||||
err = manager.RemoveInboundDNAT(localAddr, protocolToFirewall(tc.protocol), tc.sourcePort, tc.targetPort)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestInboundPortDNATNegative(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
}()
|
||||
|
||||
localAddr := netip.MustParseAddr("100.0.2.175")
|
||||
clientIP := netip.MustParseAddr("100.0.169.249")
|
||||
|
||||
err = manager.AddInboundDNAT(localAddr, firewall.ProtocolTCP, 22, 22022)
|
||||
require.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
protocol layers.IPProtocol
|
||||
srcIP netip.Addr
|
||||
dstIP netip.Addr
|
||||
srcPort uint16
|
||||
dstPort uint16
|
||||
}{
|
||||
{"Wrong port", layers.IPProtocolTCP, clientIP, localAddr, 54321, 80},
|
||||
{"Wrong IP", layers.IPProtocolTCP, clientIP, netip.MustParseAddr("100.64.0.99"), 54321, 22},
|
||||
{"Wrong protocol", layers.IPProtocolUDP, clientIP, localAddr, 54321, 22},
|
||||
{"ICMP", layers.IPProtocolICMPv4, clientIP, localAddr, 0, 0},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
packet := generateDNATTestPacket(t, tc.srcIP, tc.dstIP, tc.protocol, tc.srcPort, tc.dstPort)
|
||||
d := parsePacket(t, packet)
|
||||
|
||||
translated := manager.translateInboundPortDNAT(packet, d, tc.srcIP, tc.dstIP)
|
||||
require.False(t, translated, "Packet should NOT be translated for %s", tc.name)
|
||||
|
||||
d = parsePacket(t, packet)
|
||||
if tc.protocol == layers.IPProtocolTCP {
|
||||
require.Equal(t, tc.dstPort, uint16(d.tcp.DstPort), "Port should remain unchanged")
|
||||
} else if tc.protocol == layers.IPProtocolUDP {
|
||||
require.Equal(t, tc.dstPort, uint16(d.udp.DstPort), "Port should remain unchanged")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func protocolToFirewall(proto layers.IPProtocol) firewall.Protocol {
|
||||
switch proto {
|
||||
case layers.IPProtocolTCP:
|
||||
return firewall.ProtocolTCP
|
||||
case layers.IPProtocolUDP:
|
||||
return firewall.ProtocolUDP
|
||||
default:
|
||||
return firewall.ProtocolALL
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,25 +16,33 @@ type PacketStage int
|
||||
|
||||
const (
|
||||
StageReceived PacketStage = iota
|
||||
StageInboundPortDNAT
|
||||
StageInbound1to1NAT
|
||||
StageConntrack
|
||||
StagePeerACL
|
||||
StageRouting
|
||||
StageRouteACL
|
||||
StageForwarding
|
||||
StageCompleted
|
||||
StageOutbound1to1NAT
|
||||
StageOutboundPortReverse
|
||||
)
|
||||
|
||||
const msgProcessingCompleted = "Processing completed"
|
||||
|
||||
func (s PacketStage) String() string {
|
||||
return map[PacketStage]string{
|
||||
StageReceived: "Received",
|
||||
StageConntrack: "Connection Tracking",
|
||||
StagePeerACL: "Peer ACL",
|
||||
StageRouting: "Routing",
|
||||
StageRouteACL: "Route ACL",
|
||||
StageForwarding: "Forwarding",
|
||||
StageCompleted: "Completed",
|
||||
StageReceived: "Received",
|
||||
StageInboundPortDNAT: "Inbound Port DNAT",
|
||||
StageInbound1to1NAT: "Inbound 1:1 NAT",
|
||||
StageConntrack: "Connection Tracking",
|
||||
StagePeerACL: "Peer ACL",
|
||||
StageRouting: "Routing",
|
||||
StageRouteACL: "Route ACL",
|
||||
StageForwarding: "Forwarding",
|
||||
StageCompleted: "Completed",
|
||||
StageOutbound1to1NAT: "Outbound 1:1 NAT",
|
||||
StageOutboundPortReverse: "Outbound DNAT Reverse",
|
||||
}[s]
|
||||
}
|
||||
|
||||
@@ -261,6 +269,10 @@ func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *Pa
|
||||
}
|
||||
|
||||
func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder, srcIP netip.Addr, dstIP netip.Addr) *PacketTrace {
|
||||
if m.handleInboundDNAT(trace, packetData, d, &srcIP, &dstIP) {
|
||||
return trace
|
||||
}
|
||||
|
||||
if m.stateful && m.handleConntrackState(trace, d, srcIP, dstIP) {
|
||||
return trace
|
||||
}
|
||||
@@ -400,7 +412,16 @@ func (m *Manager) addForwardingResult(trace *PacketTrace, action, remoteAddr str
|
||||
}
|
||||
|
||||
func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace {
|
||||
// will create or update the connection state
|
||||
d := m.decoders.Get().(*decoder)
|
||||
defer m.decoders.Put(d)
|
||||
|
||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||
trace.AddResult(StageCompleted, "Packet dropped - decode error", false)
|
||||
return trace
|
||||
}
|
||||
|
||||
m.handleOutboundDNAT(trace, packetData, d)
|
||||
|
||||
dropped := m.filterOutbound(packetData, 0)
|
||||
if dropped {
|
||||
trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false)
|
||||
@@ -409,3 +430,199 @@ func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTr
|
||||
}
|
||||
return trace
|
||||
}
|
||||
|
||||
func (m *Manager) handleInboundDNAT(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP *netip.Addr) bool {
|
||||
portDNATApplied := m.traceInboundPortDNAT(trace, packetData, d)
|
||||
if portDNATApplied {
|
||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||
trace.AddResult(StageInboundPortDNAT, "Failed to re-decode after port DNAT", false)
|
||||
return true
|
||||
}
|
||||
*srcIP, *dstIP = m.extractIPs(d)
|
||||
trace.DestinationPort = m.getDestPort(d)
|
||||
}
|
||||
|
||||
nat1to1Applied := m.traceInbound1to1NAT(trace, packetData, d)
|
||||
if nat1to1Applied {
|
||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||
trace.AddResult(StageInbound1to1NAT, "Failed to re-decode after 1:1 NAT", false)
|
||||
return true
|
||||
}
|
||||
*srcIP, *dstIP = m.extractIPs(d)
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Manager) traceInboundPortDNAT(trace *PacketTrace, packetData []byte, d *decoder) bool {
|
||||
if !m.portDNATEnabled.Load() {
|
||||
trace.AddResult(StageInboundPortDNAT, "Port DNAT not enabled", true)
|
||||
return false
|
||||
}
|
||||
|
||||
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
|
||||
trace.AddResult(StageInboundPortDNAT, "Not IPv4, skipping port DNAT", true)
|
||||
return false
|
||||
}
|
||||
|
||||
if len(d.decoded) < 2 {
|
||||
trace.AddResult(StageInboundPortDNAT, "No transport layer, skipping port DNAT", true)
|
||||
return false
|
||||
}
|
||||
|
||||
protocol := d.decoded[1]
|
||||
if protocol != layers.LayerTypeTCP && protocol != layers.LayerTypeUDP {
|
||||
trace.AddResult(StageInboundPortDNAT, "Not TCP/UDP, skipping port DNAT", true)
|
||||
return false
|
||||
}
|
||||
|
||||
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
|
||||
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
|
||||
var originalPort uint16
|
||||
if protocol == layers.LayerTypeTCP {
|
||||
originalPort = uint16(d.tcp.DstPort)
|
||||
} else {
|
||||
originalPort = uint16(d.udp.DstPort)
|
||||
}
|
||||
|
||||
translated := m.translateInboundPortDNAT(packetData, d, srcIP, dstIP)
|
||||
if translated {
|
||||
ipHeaderLen := int((packetData[0] & 0x0F) * 4)
|
||||
translatedPort := uint16(packetData[ipHeaderLen+2])<<8 | uint16(packetData[ipHeaderLen+3])
|
||||
|
||||
protoStr := "TCP"
|
||||
if protocol == layers.LayerTypeUDP {
|
||||
protoStr = "UDP"
|
||||
}
|
||||
msg := fmt.Sprintf("%s port DNAT applied: %s:%d -> %s:%d", protoStr, dstIP, originalPort, dstIP, translatedPort)
|
||||
trace.AddResult(StageInboundPortDNAT, msg, true)
|
||||
return true
|
||||
}
|
||||
|
||||
trace.AddResult(StageInboundPortDNAT, "No matching port DNAT rule", true)
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Manager) traceInbound1to1NAT(trace *PacketTrace, packetData []byte, d *decoder) bool {
|
||||
if !m.dnatEnabled.Load() {
|
||||
trace.AddResult(StageInbound1to1NAT, "1:1 NAT not enabled", true)
|
||||
return false
|
||||
}
|
||||
|
||||
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
|
||||
|
||||
translated := m.translateInboundReverse(packetData, d)
|
||||
if translated {
|
||||
m.dnatMutex.RLock()
|
||||
translatedIP, exists := m.dnatBiMap.getOriginal(srcIP)
|
||||
m.dnatMutex.RUnlock()
|
||||
|
||||
if exists {
|
||||
msg := fmt.Sprintf("1:1 NAT reverse applied: %s -> %s", srcIP, translatedIP)
|
||||
trace.AddResult(StageInbound1to1NAT, msg, true)
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
trace.AddResult(StageInbound1to1NAT, "No matching 1:1 NAT rule", true)
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Manager) handleOutboundDNAT(trace *PacketTrace, packetData []byte, d *decoder) {
|
||||
m.traceOutbound1to1NAT(trace, packetData, d)
|
||||
m.traceOutboundPortReverse(trace, packetData, d)
|
||||
}
|
||||
|
||||
func (m *Manager) traceOutbound1to1NAT(trace *PacketTrace, packetData []byte, d *decoder) bool {
|
||||
if !m.dnatEnabled.Load() {
|
||||
trace.AddResult(StageOutbound1to1NAT, "1:1 NAT not enabled", true)
|
||||
return false
|
||||
}
|
||||
|
||||
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
|
||||
|
||||
translated := m.translateOutboundDNAT(packetData, d)
|
||||
if translated {
|
||||
m.dnatMutex.RLock()
|
||||
translatedIP, exists := m.dnatMappings[dstIP]
|
||||
m.dnatMutex.RUnlock()
|
||||
|
||||
if exists {
|
||||
msg := fmt.Sprintf("1:1 NAT applied: %s -> %s", dstIP, translatedIP)
|
||||
trace.AddResult(StageOutbound1to1NAT, msg, true)
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
trace.AddResult(StageOutbound1to1NAT, "No matching 1:1 NAT rule", true)
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Manager) traceOutboundPortReverse(trace *PacketTrace, packetData []byte, d *decoder) bool {
|
||||
if !m.portDNATEnabled.Load() {
|
||||
trace.AddResult(StageOutboundPortReverse, "Port DNAT not enabled", true)
|
||||
return false
|
||||
}
|
||||
|
||||
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
|
||||
trace.AddResult(StageOutboundPortReverse, "Not IPv4, skipping port reverse", true)
|
||||
return false
|
||||
}
|
||||
|
||||
if len(d.decoded) < 2 {
|
||||
trace.AddResult(StageOutboundPortReverse, "No transport layer, skipping port reverse", true)
|
||||
return false
|
||||
}
|
||||
|
||||
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
|
||||
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
|
||||
|
||||
var origPort uint16
|
||||
transport := d.decoded[1]
|
||||
switch transport {
|
||||
case layers.LayerTypeTCP:
|
||||
srcPort := uint16(d.tcp.SrcPort)
|
||||
dstPort := uint16(d.tcp.DstPort)
|
||||
conn, exists := m.tcpTracker.GetConnection(dstIP, dstPort, srcIP, srcPort)
|
||||
if exists {
|
||||
origPort = uint16(conn.DNATOrigPort.Load())
|
||||
}
|
||||
if origPort != 0 {
|
||||
msg := fmt.Sprintf("TCP DNAT reverse (tracked connection): %s:%d -> %s:%d", srcIP, srcPort, srcIP, origPort)
|
||||
trace.AddResult(StageOutboundPortReverse, msg, true)
|
||||
return true
|
||||
}
|
||||
case layers.LayerTypeUDP:
|
||||
srcPort := uint16(d.udp.SrcPort)
|
||||
dstPort := uint16(d.udp.DstPort)
|
||||
conn, exists := m.udpTracker.GetConnection(dstIP, dstPort, srcIP, srcPort)
|
||||
if exists {
|
||||
origPort = uint16(conn.DNATOrigPort.Load())
|
||||
}
|
||||
if origPort != 0 {
|
||||
msg := fmt.Sprintf("UDP DNAT reverse (tracked connection): %s:%d -> %s:%d", srcIP, srcPort, srcIP, origPort)
|
||||
trace.AddResult(StageOutboundPortReverse, msg, true)
|
||||
return true
|
||||
}
|
||||
default:
|
||||
trace.AddResult(StageOutboundPortReverse, "Not TCP/UDP, skipping port reverse", true)
|
||||
return false
|
||||
}
|
||||
|
||||
trace.AddResult(StageOutboundPortReverse, "No tracked connection for DNAT reverse", true)
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Manager) getDestPort(d *decoder) uint16 {
|
||||
if len(d.decoded) < 2 {
|
||||
return 0
|
||||
}
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeTCP:
|
||||
return uint16(d.tcp.DstPort)
|
||||
case layers.LayerTypeUDP:
|
||||
return uint16(d.udp.DstPort)
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
@@ -44,7 +45,7 @@ func TestTracePacket(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock, false, flowLogger)
|
||||
m, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
|
||||
if !statefulMode {
|
||||
@@ -104,6 +105,8 @@ func TestTracePacket(t *testing.T) {
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageInboundPortDNAT,
|
||||
StageInbound1to1NAT,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StagePeerACL,
|
||||
@@ -126,6 +129,8 @@ func TestTracePacket(t *testing.T) {
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageInboundPortDNAT,
|
||||
StageInbound1to1NAT,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StagePeerACL,
|
||||
@@ -153,6 +158,8 @@ func TestTracePacket(t *testing.T) {
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageInboundPortDNAT,
|
||||
StageInbound1to1NAT,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StagePeerACL,
|
||||
@@ -179,6 +186,8 @@ func TestTracePacket(t *testing.T) {
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageInboundPortDNAT,
|
||||
StageInbound1to1NAT,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StagePeerACL,
|
||||
@@ -204,6 +213,8 @@ func TestTracePacket(t *testing.T) {
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageInboundPortDNAT,
|
||||
StageInbound1to1NAT,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StageRouteACL,
|
||||
@@ -228,6 +239,8 @@ func TestTracePacket(t *testing.T) {
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageInboundPortDNAT,
|
||||
StageInbound1to1NAT,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StageRouteACL,
|
||||
@@ -246,6 +259,8 @@ func TestTracePacket(t *testing.T) {
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageInboundPortDNAT,
|
||||
StageInbound1to1NAT,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StageRouteACL,
|
||||
@@ -264,6 +279,8 @@ func TestTracePacket(t *testing.T) {
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageInboundPortDNAT,
|
||||
StageInbound1to1NAT,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StageCompleted,
|
||||
@@ -287,6 +304,8 @@ func TestTracePacket(t *testing.T) {
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageInboundPortDNAT,
|
||||
StageInbound1to1NAT,
|
||||
StageConntrack,
|
||||
StageCompleted,
|
||||
},
|
||||
@@ -301,6 +320,8 @@ func TestTracePacket(t *testing.T) {
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageOutbound1to1NAT,
|
||||
StageOutboundPortReverse,
|
||||
StageCompleted,
|
||||
},
|
||||
expectedAllow: true,
|
||||
@@ -319,6 +340,8 @@ func TestTracePacket(t *testing.T) {
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageInboundPortDNAT,
|
||||
StageInbound1to1NAT,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StagePeerACL,
|
||||
@@ -340,6 +363,8 @@ func TestTracePacket(t *testing.T) {
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageInboundPortDNAT,
|
||||
StageInbound1to1NAT,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StagePeerACL,
|
||||
@@ -362,6 +387,8 @@ func TestTracePacket(t *testing.T) {
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageInboundPortDNAT,
|
||||
StageInbound1to1NAT,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StagePeerACL,
|
||||
@@ -382,6 +409,8 @@ func TestTracePacket(t *testing.T) {
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageInboundPortDNAT,
|
||||
StageInbound1to1NAT,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StagePeerACL,
|
||||
@@ -406,6 +435,8 @@ func TestTracePacket(t *testing.T) {
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageInboundPortDNAT,
|
||||
StageInbound1to1NAT,
|
||||
StageRouting,
|
||||
StagePeerACL,
|
||||
StageCompleted,
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
@@ -29,7 +30,8 @@ func Backoff(ctx context.Context) backoff.BackOff {
|
||||
// The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal").
|
||||
func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) {
|
||||
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
|
||||
if tlsEnabled {
|
||||
// for js, the outer websocket layer takes care of tls
|
||||
if tlsEnabled && runtime.GOOS != "js" {
|
||||
certPool, err := x509.SystemCertPool()
|
||||
if err != nil || certPool == nil {
|
||||
log.Debugf("System cert pool not available; falling back to embedded cert, error: %v", err)
|
||||
@@ -37,9 +39,7 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, compone
|
||||
}
|
||||
|
||||
transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{
|
||||
// for js, outer websocket layer takes care of tls verification via WithCustomDialer
|
||||
InsecureSkipVerify: runtime.GOOS == "js",
|
||||
RootCAs: certPool,
|
||||
RootCAs: certPool,
|
||||
}))
|
||||
}
|
||||
|
||||
@@ -58,8 +58,7 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, compone
|
||||
}),
|
||||
)
|
||||
if err != nil {
|
||||
log.Printf("DialContext error: %v", err)
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("dial context: %w", err)
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
|
||||
@@ -18,7 +18,7 @@ import (
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
func WithCustomDialer(tlsEnabled bool, component string) grpc.DialOption {
|
||||
func WithCustomDialer(_ bool, _ string) grpc.DialOption {
|
||||
return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
|
||||
if runtime.GOOS == "linux" {
|
||||
currentUser, err := user.Current()
|
||||
@@ -36,7 +36,6 @@ func WithCustomDialer(tlsEnabled bool, component string) grpc.DialOption {
|
||||
|
||||
conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to dial: %s", err)
|
||||
return nil, fmt.Errorf("nbnet.NewDialer().DialContext: %w", err)
|
||||
}
|
||||
return conn, nil
|
||||
|
||||
@@ -73,6 +73,44 @@ func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix,
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *KernelConfigurer) RemoveEndpointAddress(peerKey string) error {
|
||||
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Get the existing peer to preserve its allowed IPs
|
||||
existingPeer, err := c.getPeer(c.deviceName, peerKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get peer: %w", err)
|
||||
}
|
||||
|
||||
removePeerCfg := wgtypes.PeerConfig{
|
||||
PublicKey: peerKeyParsed,
|
||||
Remove: true,
|
||||
}
|
||||
|
||||
if err := c.configure(wgtypes.Config{Peers: []wgtypes.PeerConfig{removePeerCfg}}); err != nil {
|
||||
return fmt.Errorf(`error removing peer %s from interface %s: %w`, peerKey, c.deviceName, err)
|
||||
}
|
||||
|
||||
//Re-add the peer without the endpoint but same AllowedIPs
|
||||
reAddPeerCfg := wgtypes.PeerConfig{
|
||||
PublicKey: peerKeyParsed,
|
||||
AllowedIPs: existingPeer.AllowedIPs,
|
||||
ReplaceAllowedIPs: true,
|
||||
}
|
||||
|
||||
if err := c.configure(wgtypes.Config{Peers: []wgtypes.PeerConfig{reAddPeerCfg}}); err != nil {
|
||||
return fmt.Errorf(
|
||||
`error re-adding peer %s to interface %s with allowed IPs %v: %w`,
|
||||
peerKey, c.deviceName, existingPeer.AllowedIPs, err,
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *KernelConfigurer) RemovePeer(peerKey string) error {
|
||||
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
||||
if err != nil {
|
||||
|
||||
@@ -106,6 +106,67 @@ func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix,
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *WGUSPConfigurer) RemoveEndpointAddress(peerKey string) error {
|
||||
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse peer key: %w", err)
|
||||
}
|
||||
|
||||
ipcStr, err := c.device.IpcGet()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get IPC config: %w", err)
|
||||
}
|
||||
|
||||
// Parse current status to get allowed IPs for the peer
|
||||
stats, err := parseStatus(c.deviceName, ipcStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse IPC config: %w", err)
|
||||
}
|
||||
|
||||
var allowedIPs []net.IPNet
|
||||
found := false
|
||||
for _, peer := range stats.Peers {
|
||||
if peer.PublicKey == peerKey {
|
||||
allowedIPs = peer.AllowedIPs
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return fmt.Errorf("peer %s not found", peerKey)
|
||||
}
|
||||
|
||||
// remove the peer from the WireGuard configuration
|
||||
peer := wgtypes.PeerConfig{
|
||||
PublicKey: peerKeyParsed,
|
||||
Remove: true,
|
||||
}
|
||||
|
||||
config := wgtypes.Config{
|
||||
Peers: []wgtypes.PeerConfig{peer},
|
||||
}
|
||||
if ipcErr := c.device.IpcSet(toWgUserspaceString(config)); ipcErr != nil {
|
||||
return fmt.Errorf("failed to remove peer: %s", ipcErr)
|
||||
}
|
||||
|
||||
// Build the peer config
|
||||
peer = wgtypes.PeerConfig{
|
||||
PublicKey: peerKeyParsed,
|
||||
ReplaceAllowedIPs: true,
|
||||
AllowedIPs: allowedIPs,
|
||||
}
|
||||
|
||||
config = wgtypes.Config{
|
||||
Peers: []wgtypes.PeerConfig{peer},
|
||||
}
|
||||
|
||||
if err := c.device.IpcSet(toWgUserspaceString(config)); err != nil {
|
||||
return fmt.Errorf("remove endpoint address: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *WGUSPConfigurer) RemovePeer(peerKey string) error {
|
||||
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
||||
if err != nil {
|
||||
|
||||
@@ -23,4 +23,5 @@ type WGTunDevice interface {
|
||||
FilteredDevice() *device.FilteredDevice
|
||||
Device() *wgdevice.Device
|
||||
GetNet() *netstack.Net
|
||||
GetICEBind() device.EndpointManager
|
||||
}
|
||||
|
||||
@@ -150,6 +150,11 @@ func (t *WGTunDevice) GetNet() *netstack.Net {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetICEBind returns the ICEBind instance
|
||||
func (t *WGTunDevice) GetICEBind() EndpointManager {
|
||||
return t.iceBind
|
||||
}
|
||||
|
||||
func routesToString(routes []string) string {
|
||||
return strings.Join(routes, ";")
|
||||
}
|
||||
|
||||
@@ -154,3 +154,8 @@ func (t *TunDevice) assignAddr() error {
|
||||
func (t *TunDevice) GetNet() *netstack.Net {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetICEBind returns the ICEBind instance
|
||||
func (t *TunDevice) GetICEBind() EndpointManager {
|
||||
return t.iceBind
|
||||
}
|
||||
|
||||
@@ -144,3 +144,8 @@ func (t *TunDevice) FilteredDevice() *FilteredDevice {
|
||||
func (t *TunDevice) GetNet() *netstack.Net {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetICEBind returns the ICEBind instance
|
||||
func (t *TunDevice) GetICEBind() EndpointManager {
|
||||
return t.iceBind
|
||||
}
|
||||
|
||||
@@ -179,3 +179,8 @@ func (t *TunKernelDevice) assignAddr() error {
|
||||
func (t *TunKernelDevice) GetNet() *netstack.Net {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetICEBind returns nil for kernel mode devices
|
||||
func (t *TunKernelDevice) GetICEBind() EndpointManager {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -21,6 +21,7 @@ type Bind interface {
|
||||
conn.Bind
|
||||
GetICEMux() (*udpmux.UniversalUDPMuxDefault, error)
|
||||
ActivityRecorder() *bind.ActivityRecorder
|
||||
EndpointManager
|
||||
}
|
||||
|
||||
type TunNetstackDevice struct {
|
||||
@@ -155,3 +156,8 @@ func (t *TunNetstackDevice) Device() *device.Device {
|
||||
func (t *TunNetstackDevice) GetNet() *netstack.Net {
|
||||
return t.net
|
||||
}
|
||||
|
||||
// GetICEBind returns the bind instance
|
||||
func (t *TunNetstackDevice) GetICEBind() EndpointManager {
|
||||
return t.bind
|
||||
}
|
||||
|
||||
@@ -146,3 +146,8 @@ func (t *USPDevice) assignAddr() error {
|
||||
func (t *USPDevice) GetNet() *netstack.Net {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetICEBind returns the ICEBind instance
|
||||
func (t *USPDevice) GetICEBind() EndpointManager {
|
||||
return t.iceBind
|
||||
}
|
||||
|
||||
@@ -185,3 +185,8 @@ func (t *TunDevice) assignAddr() error {
|
||||
func (t *TunDevice) GetNet() *netstack.Net {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetICEBind returns the ICEBind instance
|
||||
func (t *TunDevice) GetICEBind() EndpointManager {
|
||||
return t.iceBind
|
||||
}
|
||||
|
||||
13
client/iface/device/endpoint_manager.go
Normal file
13
client/iface/device/endpoint_manager.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package device
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
// EndpointManager manages fake IP to connection mappings for userspace bind implementations.
|
||||
// Implemented by bind.ICEBind and bind.RelayBindJS.
|
||||
type EndpointManager interface {
|
||||
SetEndpoint(fakeIP netip.Addr, conn net.Conn)
|
||||
RemoveEndpoint(fakeIP netip.Addr)
|
||||
}
|
||||
@@ -21,4 +21,5 @@ type WGConfigurer interface {
|
||||
GetStats() (map[string]configurer.WGStats, error)
|
||||
FullStats() (*configurer.Stats, error)
|
||||
LastActivities() map[string]monotime.Time
|
||||
RemoveEndpointAddress(peerKey string) error
|
||||
}
|
||||
|
||||
@@ -21,4 +21,5 @@ type WGTunDevice interface {
|
||||
FilteredDevice() *device.FilteredDevice
|
||||
Device() *wgdevice.Device
|
||||
GetNet() *netstack.Net
|
||||
GetICEBind() device.EndpointManager
|
||||
}
|
||||
|
||||
@@ -80,6 +80,17 @@ func (w *WGIface) GetProxy() wgproxy.Proxy {
|
||||
return w.wgProxyFactory.GetProxy()
|
||||
}
|
||||
|
||||
// GetBind returns the EndpointManager userspace bind mode.
|
||||
func (w *WGIface) GetBind() device.EndpointManager {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
if w.tun == nil {
|
||||
return nil
|
||||
}
|
||||
return w.tun.GetICEBind()
|
||||
}
|
||||
|
||||
// IsUserspaceBind indicates whether this interfaces is userspace with bind.ICEBind
|
||||
func (w *WGIface) IsUserspaceBind() bool {
|
||||
return w.userspaceBind
|
||||
@@ -148,6 +159,17 @@ func (w *WGIface) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAliv
|
||||
return w.configurer.UpdatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
|
||||
}
|
||||
|
||||
func (w *WGIface) RemoveEndpointAddress(peerKey string) error {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
if w.configurer == nil {
|
||||
return ErrIfaceNotFound
|
||||
}
|
||||
|
||||
log.Debugf("Removing endpoint address: %s", peerKey)
|
||||
return w.configurer.RemoveEndpointAddress(peerKey)
|
||||
}
|
||||
|
||||
// RemovePeer removes a Wireguard Peer from the interface iface
|
||||
func (w *WGIface) RemovePeer(peerKey string) error {
|
||||
w.mu.Lock()
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package iface
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
@@ -9,13 +10,13 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/pion/transport/v3/stdnet"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.zx2c4.com/wireguard/wgctrl"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
)
|
||||
|
||||
// keep darwin compatibility
|
||||
@@ -40,7 +41,7 @@ func TestWGIface_UpdateAddr(t *testing.T) {
|
||||
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
|
||||
addr := "100.64.0.1/8"
|
||||
wgPort := 33100
|
||||
newNet, err := stdnet.NewNet()
|
||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -123,7 +124,7 @@ func getIfaceAddrs(ifaceName string) ([]net.Addr, error) {
|
||||
func Test_CreateInterface(t *testing.T) {
|
||||
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+1)
|
||||
wgIP := "10.99.99.1/32"
|
||||
newNet, err := stdnet.NewNet()
|
||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -166,7 +167,7 @@ func Test_Close(t *testing.T) {
|
||||
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+2)
|
||||
wgIP := "10.99.99.2/32"
|
||||
wgPort := 33100
|
||||
newNet, err := stdnet.NewNet()
|
||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -211,7 +212,7 @@ func TestRecreation(t *testing.T) {
|
||||
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+2)
|
||||
wgIP := "10.99.99.2/32"
|
||||
wgPort := 33100
|
||||
newNet, err := stdnet.NewNet()
|
||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -284,7 +285,7 @@ func Test_ConfigureInterface(t *testing.T) {
|
||||
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+3)
|
||||
wgIP := "10.99.99.5/30"
|
||||
wgPort := 33100
|
||||
newNet, err := stdnet.NewNet()
|
||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -339,7 +340,7 @@ func Test_ConfigureInterface(t *testing.T) {
|
||||
func Test_UpdatePeer(t *testing.T) {
|
||||
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
|
||||
wgIP := "10.99.99.9/30"
|
||||
newNet, err := stdnet.NewNet()
|
||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -409,7 +410,7 @@ func Test_UpdatePeer(t *testing.T) {
|
||||
func Test_RemovePeer(t *testing.T) {
|
||||
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
|
||||
wgIP := "10.99.99.13/30"
|
||||
newNet, err := stdnet.NewNet()
|
||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -471,7 +472,7 @@ func Test_ConnectPeers(t *testing.T) {
|
||||
peer2wgPort := 33200
|
||||
|
||||
keepAlive := 1 * time.Second
|
||||
newNet, err := stdnet.NewNet()
|
||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -514,7 +515,7 @@ func Test_ConnectPeers(t *testing.T) {
|
||||
guid = fmt.Sprintf("{%s}", uuid.New().String())
|
||||
device.CustomWindowsGUIDString = strings.ToLower(guid)
|
||||
|
||||
newNet, err = stdnet.NewNet()
|
||||
newNet, err = stdnet.NewNet(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package udpmux
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
@@ -12,8 +13,9 @@ import (
|
||||
"github.com/pion/logging"
|
||||
"github.com/pion/stun/v3"
|
||||
"github.com/pion/transport/v3"
|
||||
"github.com/pion/transport/v3/stdnet"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
)
|
||||
|
||||
/*
|
||||
@@ -199,7 +201,7 @@ func (m *SingleSocketUDPMux) updateLocalAddresses() {
|
||||
if len(networks) > 0 {
|
||||
if m.params.Net == nil {
|
||||
var err error
|
||||
if m.params.Net, err = stdnet.NewNet(); err != nil {
|
||||
if m.params.Net, err = stdnet.NewNet(context.Background(), nil); err != nil {
|
||||
m.params.Logger.Errorf("failed to get create network: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,7 +17,6 @@ import (
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/acl/id"
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
@@ -29,11 +28,6 @@ type Manager interface {
|
||||
ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRouteFeatureFlag bool)
|
||||
}
|
||||
|
||||
type protoMatch struct {
|
||||
ips map[string]int
|
||||
policyID []byte
|
||||
}
|
||||
|
||||
// DefaultManager uses firewall manager to handle
|
||||
type DefaultManager struct {
|
||||
firewall firewall.Manager
|
||||
@@ -86,30 +80,7 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout
|
||||
}
|
||||
|
||||
func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
|
||||
rules, squashedProtocols := d.squashAcceptRules(networkMap)
|
||||
|
||||
enableSSH := networkMap.PeerConfig != nil &&
|
||||
networkMap.PeerConfig.SshConfig != nil &&
|
||||
networkMap.PeerConfig.SshConfig.SshEnabled
|
||||
if _, ok := squashedProtocols[mgmProto.RuleProtocol_ALL]; ok {
|
||||
enableSSH = enableSSH && !ok
|
||||
}
|
||||
if _, ok := squashedProtocols[mgmProto.RuleProtocol_TCP]; ok {
|
||||
enableSSH = enableSSH && !ok
|
||||
}
|
||||
|
||||
// if TCP protocol rules not squashed and SSH enabled
|
||||
// we add default firewall rule which accepts connection to any peer
|
||||
// in the network by SSH (TCP 22 port).
|
||||
if enableSSH {
|
||||
rules = append(rules, &mgmProto.FirewallRule{
|
||||
PeerIP: "0.0.0.0",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: strconv.Itoa(ssh.DefaultSSHPort),
|
||||
})
|
||||
}
|
||||
rules := networkMap.FirewallRules
|
||||
|
||||
// if we got empty rules list but management not set networkMap.FirewallRulesIsEmpty flag
|
||||
// we have old version of management without rules handling, we should allow all traffic
|
||||
@@ -368,145 +339,6 @@ func (d *DefaultManager) getPeerRuleID(
|
||||
return id.RuleID(hex.EncodeToString(md5.New().Sum([]byte(idStr))))
|
||||
}
|
||||
|
||||
// squashAcceptRules does complex logic to convert many rules which allows connection by traffic type
|
||||
// to all peers in the network map to one rule which just accepts that type of the traffic.
|
||||
//
|
||||
// NOTE: It will not squash two rules for same protocol if one covers all peers in the network,
|
||||
// but other has port definitions or has drop policy.
|
||||
func (d *DefaultManager) squashAcceptRules(
|
||||
networkMap *mgmProto.NetworkMap,
|
||||
) ([]*mgmProto.FirewallRule, map[mgmProto.RuleProtocol]struct{}) {
|
||||
totalIPs := 0
|
||||
for _, p := range append(networkMap.RemotePeers, networkMap.OfflinePeers...) {
|
||||
for range p.AllowedIps {
|
||||
totalIPs++
|
||||
}
|
||||
}
|
||||
|
||||
in := map[mgmProto.RuleProtocol]*protoMatch{}
|
||||
out := map[mgmProto.RuleProtocol]*protoMatch{}
|
||||
|
||||
// trace which type of protocols was squashed
|
||||
squashedRules := []*mgmProto.FirewallRule{}
|
||||
squashedProtocols := map[mgmProto.RuleProtocol]struct{}{}
|
||||
|
||||
// this function we use to do calculation, can we squash the rules by protocol or not.
|
||||
// We summ amount of Peers IP for given protocol we found in original rules list.
|
||||
// But we zeroed the IP's for protocol if:
|
||||
// 1. Any of the rule has DROP action type.
|
||||
// 2. Any of rule contains Port.
|
||||
//
|
||||
// We zeroed this to notify squash function that this protocol can't be squashed.
|
||||
addRuleToCalculationMap := func(i int, r *mgmProto.FirewallRule, protocols map[mgmProto.RuleProtocol]*protoMatch) {
|
||||
hasPortRestrictions := r.Action == mgmProto.RuleAction_DROP ||
|
||||
r.Port != "" || !portInfoEmpty(r.PortInfo)
|
||||
|
||||
if hasPortRestrictions {
|
||||
// Don't squash rules with port restrictions
|
||||
protocols[r.Protocol] = &protoMatch{ips: map[string]int{}}
|
||||
return
|
||||
}
|
||||
|
||||
if _, ok := protocols[r.Protocol]; !ok {
|
||||
protocols[r.Protocol] = &protoMatch{
|
||||
ips: map[string]int{},
|
||||
// store the first encountered PolicyID for this protocol
|
||||
policyID: r.PolicyID,
|
||||
}
|
||||
}
|
||||
|
||||
// special case, when we receive this all network IP address
|
||||
// it means that rules for that protocol was already optimized on the
|
||||
// management side
|
||||
if r.PeerIP == "0.0.0.0" {
|
||||
squashedRules = append(squashedRules, r)
|
||||
squashedProtocols[r.Protocol] = struct{}{}
|
||||
return
|
||||
}
|
||||
|
||||
ipset := protocols[r.Protocol].ips
|
||||
|
||||
if _, ok := ipset[r.PeerIP]; ok {
|
||||
return
|
||||
}
|
||||
ipset[r.PeerIP] = i
|
||||
}
|
||||
|
||||
for i, r := range networkMap.FirewallRules {
|
||||
// calculate squash for different directions
|
||||
if r.Direction == mgmProto.RuleDirection_IN {
|
||||
addRuleToCalculationMap(i, r, in)
|
||||
} else {
|
||||
addRuleToCalculationMap(i, r, out)
|
||||
}
|
||||
}
|
||||
|
||||
// order of squashing by protocol is important
|
||||
// only for their first element ALL, it must be done first
|
||||
protocolOrders := []mgmProto.RuleProtocol{
|
||||
mgmProto.RuleProtocol_ALL,
|
||||
mgmProto.RuleProtocol_ICMP,
|
||||
mgmProto.RuleProtocol_TCP,
|
||||
mgmProto.RuleProtocol_UDP,
|
||||
}
|
||||
|
||||
squash := func(matches map[mgmProto.RuleProtocol]*protoMatch, direction mgmProto.RuleDirection) {
|
||||
for _, protocol := range protocolOrders {
|
||||
match, ok := matches[protocol]
|
||||
if !ok || len(match.ips) != totalIPs || len(match.ips) < 2 {
|
||||
// don't squash if :
|
||||
// 1. Rules not cover all peers in the network
|
||||
// 2. Rules cover only one peer in the network.
|
||||
continue
|
||||
}
|
||||
|
||||
// add special rule 0.0.0.0 which allows all IP's in our firewall implementations
|
||||
squashedRules = append(squashedRules, &mgmProto.FirewallRule{
|
||||
PeerIP: "0.0.0.0",
|
||||
Direction: direction,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: protocol,
|
||||
PolicyID: match.policyID,
|
||||
})
|
||||
squashedProtocols[protocol] = struct{}{}
|
||||
|
||||
if protocol == mgmProto.RuleProtocol_ALL {
|
||||
// if we have ALL traffic type squashed rule
|
||||
// it allows all other type of traffic, so we can stop processing
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
squash(in, mgmProto.RuleDirection_IN)
|
||||
squash(out, mgmProto.RuleDirection_OUT)
|
||||
|
||||
// if all protocol was squashed everything is allow and we can ignore all other rules
|
||||
if _, ok := squashedProtocols[mgmProto.RuleProtocol_ALL]; ok {
|
||||
return squashedRules, squashedProtocols
|
||||
}
|
||||
|
||||
if len(squashedRules) == 0 {
|
||||
return networkMap.FirewallRules, squashedProtocols
|
||||
}
|
||||
|
||||
var rules []*mgmProto.FirewallRule
|
||||
// filter out rules which was squashed from final list
|
||||
// if we also have other not squashed rules.
|
||||
for i, r := range networkMap.FirewallRules {
|
||||
if _, ok := squashedProtocols[r.Protocol]; ok {
|
||||
if m, ok := in[r.Protocol]; ok && m.ips[r.PeerIP] == i {
|
||||
continue
|
||||
} else if m, ok := out[r.Protocol]; ok && m.ips[r.PeerIP] == i {
|
||||
continue
|
||||
}
|
||||
}
|
||||
rules = append(rules, r)
|
||||
}
|
||||
|
||||
return append(rules, squashedRules...), squashedProtocols
|
||||
}
|
||||
|
||||
// getRuleGroupingSelector takes all rule properties except IP address to build selector
|
||||
func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) string {
|
||||
return fmt.Sprintf("%v:%v:%v:%s:%v", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port, rule.PortInfo)
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/acl/mocks"
|
||||
"github.com/netbirdio/netbird/client/internal/netflow"
|
||||
@@ -52,7 +53,7 @@ func TestDefaultManager(t *testing.T) {
|
||||
}).AnyTimes()
|
||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||
|
||||
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
|
||||
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
err = fw.Close(nil)
|
||||
@@ -170,7 +171,7 @@ func TestDefaultManagerStateless(t *testing.T) {
|
||||
}).AnyTimes()
|
||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||
|
||||
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
|
||||
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
err = fw.Close(nil)
|
||||
@@ -188,492 +189,6 @@ func TestDefaultManagerStateless(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestDefaultManagerSquashRules(t *testing.T) {
|
||||
networkMap := &mgmProto.NetworkMap{
|
||||
RemotePeers: []*mgmProto.RemotePeerConfig{
|
||||
{AllowedIps: []string{"10.93.0.1"}},
|
||||
{AllowedIps: []string{"10.93.0.2"}},
|
||||
{AllowedIps: []string{"10.93.0.3"}},
|
||||
{AllowedIps: []string{"10.93.0.4"}},
|
||||
},
|
||||
FirewallRules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_ALL,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_ALL,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_ALL,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.4",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_ALL,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_OUT,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_ALL,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_OUT,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_ALL,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_OUT,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_ALL,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.4",
|
||||
Direction: mgmProto.RuleDirection_OUT,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_ALL,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
manager := &DefaultManager{}
|
||||
rules, _ := manager.squashAcceptRules(networkMap)
|
||||
assert.Equal(t, 2, len(rules))
|
||||
|
||||
r := rules[0]
|
||||
assert.Equal(t, "0.0.0.0", r.PeerIP)
|
||||
assert.Equal(t, mgmProto.RuleDirection_IN, r.Direction)
|
||||
assert.Equal(t, mgmProto.RuleProtocol_ALL, r.Protocol)
|
||||
assert.Equal(t, mgmProto.RuleAction_ACCEPT, r.Action)
|
||||
|
||||
r = rules[1]
|
||||
assert.Equal(t, "0.0.0.0", r.PeerIP)
|
||||
assert.Equal(t, mgmProto.RuleDirection_OUT, r.Direction)
|
||||
assert.Equal(t, mgmProto.RuleProtocol_ALL, r.Protocol)
|
||||
assert.Equal(t, mgmProto.RuleAction_ACCEPT, r.Action)
|
||||
}
|
||||
|
||||
func TestDefaultManagerSquashRulesNoAffect(t *testing.T) {
|
||||
networkMap := &mgmProto.NetworkMap{
|
||||
RemotePeers: []*mgmProto.RemotePeerConfig{
|
||||
{AllowedIps: []string{"10.93.0.1"}},
|
||||
{AllowedIps: []string{"10.93.0.2"}},
|
||||
{AllowedIps: []string{"10.93.0.3"}},
|
||||
{AllowedIps: []string{"10.93.0.4"}},
|
||||
},
|
||||
FirewallRules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_ALL,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_ALL,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_ALL,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.4",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_OUT,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_ALL,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_OUT,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_ALL,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_OUT,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_ALL,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.4",
|
||||
Direction: mgmProto.RuleDirection_OUT,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_UDP,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
manager := &DefaultManager{}
|
||||
rules, _ := manager.squashAcceptRules(networkMap)
|
||||
assert.Equal(t, len(networkMap.FirewallRules), len(rules))
|
||||
}
|
||||
|
||||
func TestDefaultManagerSquashRulesWithPortRestrictions(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
rules []*mgmProto.FirewallRule
|
||||
expectedCount int
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "should not squash rules with port ranges",
|
||||
rules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
PortInfo: &mgmProto.PortInfo{
|
||||
PortSelection: &mgmProto.PortInfo_Range_{
|
||||
Range: &mgmProto.PortInfo_Range{
|
||||
Start: 8080,
|
||||
End: 8090,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
PortInfo: &mgmProto.PortInfo{
|
||||
PortSelection: &mgmProto.PortInfo_Range_{
|
||||
Range: &mgmProto.PortInfo_Range{
|
||||
Start: 8080,
|
||||
End: 8090,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
PortInfo: &mgmProto.PortInfo{
|
||||
PortSelection: &mgmProto.PortInfo_Range_{
|
||||
Range: &mgmProto.PortInfo_Range{
|
||||
Start: 8080,
|
||||
End: 8090,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.4",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
PortInfo: &mgmProto.PortInfo{
|
||||
PortSelection: &mgmProto.PortInfo_Range_{
|
||||
Range: &mgmProto.PortInfo_Range{
|
||||
Start: 8080,
|
||||
End: 8090,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedCount: 4,
|
||||
description: "Rules with port ranges should not be squashed even if they cover all peers",
|
||||
},
|
||||
{
|
||||
name: "should not squash rules with specific ports",
|
||||
rules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
PortInfo: &mgmProto.PortInfo{
|
||||
PortSelection: &mgmProto.PortInfo_Port{
|
||||
Port: 80,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
PortInfo: &mgmProto.PortInfo{
|
||||
PortSelection: &mgmProto.PortInfo_Port{
|
||||
Port: 80,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
PortInfo: &mgmProto.PortInfo{
|
||||
PortSelection: &mgmProto.PortInfo_Port{
|
||||
Port: 80,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.4",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
PortInfo: &mgmProto.PortInfo{
|
||||
PortSelection: &mgmProto.PortInfo_Port{
|
||||
Port: 80,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedCount: 4,
|
||||
description: "Rules with specific ports should not be squashed even if they cover all peers",
|
||||
},
|
||||
{
|
||||
name: "should not squash rules with legacy port field",
|
||||
rules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.4",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
},
|
||||
expectedCount: 4,
|
||||
description: "Rules with legacy port field should not be squashed",
|
||||
},
|
||||
{
|
||||
name: "should not squash rules with DROP action",
|
||||
rules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_DROP,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_DROP,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_DROP,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.4",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_DROP,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
},
|
||||
},
|
||||
expectedCount: 4,
|
||||
description: "Rules with DROP action should not be squashed",
|
||||
},
|
||||
{
|
||||
name: "should squash rules without port restrictions",
|
||||
rules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.4",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
},
|
||||
},
|
||||
expectedCount: 1,
|
||||
description: "Rules without port restrictions should be squashed into a single 0.0.0.0 rule",
|
||||
},
|
||||
{
|
||||
name: "mixed rules should not squash protocol with port restrictions",
|
||||
rules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
PortInfo: &mgmProto.PortInfo{
|
||||
PortSelection: &mgmProto.PortInfo_Port{
|
||||
Port: 80,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.4",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
},
|
||||
},
|
||||
expectedCount: 4,
|
||||
description: "TCP should not be squashed because one rule has port restrictions",
|
||||
},
|
||||
{
|
||||
name: "should squash UDP but not TCP when TCP has port restrictions",
|
||||
rules: []*mgmProto.FirewallRule{
|
||||
// TCP rules with port restrictions - should NOT be squashed
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.4",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
// UDP rules without port restrictions - SHOULD be squashed
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_UDP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_UDP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_UDP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.4",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_UDP,
|
||||
},
|
||||
},
|
||||
expectedCount: 5, // 4 TCP rules + 1 squashed UDP rule (0.0.0.0)
|
||||
description: "UDP should be squashed to 0.0.0.0 rule, but TCP should remain as individual rules due to port restrictions",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
networkMap := &mgmProto.NetworkMap{
|
||||
RemotePeers: []*mgmProto.RemotePeerConfig{
|
||||
{AllowedIps: []string{"10.93.0.1"}},
|
||||
{AllowedIps: []string{"10.93.0.2"}},
|
||||
{AllowedIps: []string{"10.93.0.3"}},
|
||||
{AllowedIps: []string{"10.93.0.4"}},
|
||||
},
|
||||
FirewallRules: tt.rules,
|
||||
}
|
||||
|
||||
manager := &DefaultManager{}
|
||||
rules, _ := manager.squashAcceptRules(networkMap)
|
||||
|
||||
assert.Equal(t, tt.expectedCount, len(rules), tt.description)
|
||||
|
||||
// For squashed rules, verify we get the expected 0.0.0.0 rule
|
||||
if tt.expectedCount == 1 {
|
||||
assert.Equal(t, "0.0.0.0", rules[0].PeerIP)
|
||||
assert.Equal(t, mgmProto.RuleDirection_IN, rules[0].Direction)
|
||||
assert.Equal(t, mgmProto.RuleAction_ACCEPT, rules[0].Action)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPortInfoEmpty(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -757,70 +272,3 @@ func TestPortInfoEmpty(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultManagerEnableSSHRules(t *testing.T) {
|
||||
networkMap := &mgmProto.NetworkMap{
|
||||
PeerConfig: &mgmProto.PeerConfig{
|
||||
SshConfig: &mgmProto.SSHConfig{
|
||||
SshEnabled: true,
|
||||
},
|
||||
},
|
||||
RemotePeers: []*mgmProto.RemotePeerConfig{
|
||||
{AllowedIps: []string{"10.93.0.1"}},
|
||||
{AllowedIps: []string{"10.93.0.2"}},
|
||||
{AllowedIps: []string{"10.93.0.3"}},
|
||||
},
|
||||
FirewallRules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_OUT,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_UDP,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
||||
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
||||
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
||||
network := netip.MustParsePrefix("172.0.0.1/32")
|
||||
|
||||
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
||||
IP: network.Addr(),
|
||||
Network: network,
|
||||
}).AnyTimes()
|
||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||
|
||||
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
err = fw.Close(nil)
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
acl := NewDefaultManager(fw)
|
||||
|
||||
acl.ApplyFiltering(networkMap, false)
|
||||
|
||||
expectedRules := 3
|
||||
if fw.IsStateful() {
|
||||
expectedRules = 3 // 2 inbound rules + SSH rule
|
||||
}
|
||||
assert.Equal(t, expectedRules, len(acl.peerRulesPairs))
|
||||
}
|
||||
|
||||
@@ -128,9 +128,34 @@ func (d *DeviceAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlow
|
||||
deviceCode.VerificationURIComplete = deviceCode.VerificationURI
|
||||
}
|
||||
|
||||
if d.providerConfig.LoginHint != "" {
|
||||
deviceCode.VerificationURIComplete = appendLoginHint(deviceCode.VerificationURIComplete, d.providerConfig.LoginHint)
|
||||
if deviceCode.VerificationURI != "" {
|
||||
deviceCode.VerificationURI = appendLoginHint(deviceCode.VerificationURI, d.providerConfig.LoginHint)
|
||||
}
|
||||
}
|
||||
|
||||
return deviceCode, err
|
||||
}
|
||||
|
||||
func appendLoginHint(uri, loginHint string) string {
|
||||
if uri == "" || loginHint == "" {
|
||||
return uri
|
||||
}
|
||||
|
||||
parsedURL, err := url.Parse(uri)
|
||||
if err != nil {
|
||||
log.Debugf("failed to parse verification URI for login_hint: %v", err)
|
||||
return uri
|
||||
}
|
||||
|
||||
query := parsedURL.Query()
|
||||
query.Set("login_hint", loginHint)
|
||||
parsedURL.RawQuery = query.Encode()
|
||||
|
||||
return parsedURL.String()
|
||||
}
|
||||
|
||||
func (d *DeviceAuthorizationFlow) requestToken(info AuthFlowInfo) (TokenRequestResponse, error) {
|
||||
form := url.Values{}
|
||||
form.Add("client_id", d.providerConfig.ClientID)
|
||||
|
||||
@@ -66,32 +66,34 @@ func (t TokenInfo) GetTokenToUse() string {
|
||||
// and if that also fails, the authentication process is deemed unsuccessful
|
||||
//
|
||||
// On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow
|
||||
func NewOAuthFlow(ctx context.Context, config *profilemanager.Config, isUnixDesktopClient bool) (OAuthFlow, error) {
|
||||
func NewOAuthFlow(ctx context.Context, config *profilemanager.Config, isUnixDesktopClient bool, hint string) (OAuthFlow, error) {
|
||||
if (runtime.GOOS == "linux" || runtime.GOOS == "freebsd") && !isUnixDesktopClient {
|
||||
return authenticateWithDeviceCodeFlow(ctx, config)
|
||||
return authenticateWithDeviceCodeFlow(ctx, config, hint)
|
||||
}
|
||||
|
||||
pkceFlow, err := authenticateWithPKCEFlow(ctx, config)
|
||||
pkceFlow, err := authenticateWithPKCEFlow(ctx, config, hint)
|
||||
if err != nil {
|
||||
// fallback to device code flow
|
||||
log.Debugf("failed to initialize pkce authentication with error: %v\n", err)
|
||||
log.Debug("falling back to device code flow")
|
||||
return authenticateWithDeviceCodeFlow(ctx, config)
|
||||
return authenticateWithDeviceCodeFlow(ctx, config, hint)
|
||||
}
|
||||
return pkceFlow, nil
|
||||
}
|
||||
|
||||
// authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow
|
||||
func authenticateWithPKCEFlow(ctx context.Context, config *profilemanager.Config) (OAuthFlow, error) {
|
||||
func authenticateWithPKCEFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) {
|
||||
pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL, config.ClientCertKeyPair)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err)
|
||||
}
|
||||
|
||||
pkceFlowInfo.ProviderConfig.LoginHint = hint
|
||||
|
||||
return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig)
|
||||
}
|
||||
|
||||
// authenticateWithDeviceCodeFlow initializes the Device Code auth Flow
|
||||
func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.Config) (OAuthFlow, error) {
|
||||
func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) {
|
||||
deviceFlowInfo, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
|
||||
if err != nil {
|
||||
switch s, ok := gstatus.FromError(err); {
|
||||
@@ -107,5 +109,7 @@ func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.
|
||||
}
|
||||
}
|
||||
|
||||
deviceFlowInfo.ProviderConfig.LoginHint = hint
|
||||
|
||||
return NewDeviceAuthorizationFlow(deviceFlowInfo.ProviderConfig)
|
||||
}
|
||||
|
||||
@@ -109,6 +109,9 @@ func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowIn
|
||||
params = append(params, oauth2.SetAuthURLParam("max_age", "0"))
|
||||
}
|
||||
}
|
||||
if p.providerConfig.LoginHint != "" {
|
||||
params = append(params, oauth2.SetAuthURLParam("login_hint", p.providerConfig.LoginHint))
|
||||
}
|
||||
|
||||
authURL := p.oAuthConfig.AuthCodeURL(state, params...)
|
||||
|
||||
@@ -189,17 +192,20 @@ func (p *PKCEAuthorizationFlow) handleRequest(req *http.Request) (*oauth2.Token,
|
||||
|
||||
if authError := query.Get(queryError); authError != "" {
|
||||
authErrorDesc := query.Get(queryErrorDesc)
|
||||
return nil, fmt.Errorf("%s.%s", authError, authErrorDesc)
|
||||
if authErrorDesc != "" {
|
||||
return nil, fmt.Errorf("authentication failed: %s", authErrorDesc)
|
||||
}
|
||||
return nil, fmt.Errorf("authentication failed: %s", authError)
|
||||
}
|
||||
|
||||
// Prevent timing attacks on the state
|
||||
if state := query.Get(queryState); subtle.ConstantTimeCompare([]byte(p.state), []byte(state)) == 0 {
|
||||
return nil, fmt.Errorf("invalid state")
|
||||
return nil, fmt.Errorf("authentication failed: Invalid state")
|
||||
}
|
||||
|
||||
code := query.Get(queryCode)
|
||||
if code == "" {
|
||||
return nil, fmt.Errorf("missing code")
|
||||
return nil, fmt.Errorf("authentication failed: missing code")
|
||||
}
|
||||
|
||||
return p.oAuthConfig.Exchange(
|
||||
@@ -228,7 +234,7 @@ func (p *PKCEAuthorizationFlow) parseOAuthToken(token *oauth2.Token) (TokenInfo,
|
||||
}
|
||||
|
||||
if err := isValidAccessToken(tokenInfo.GetTokenToUse(), audience); err != nil {
|
||||
return TokenInfo{}, fmt.Errorf("validate access token failed with error: %v", err)
|
||||
return TokenInfo{}, fmt.Errorf("authentication failed: invalid access token - %w", err)
|
||||
}
|
||||
|
||||
email, err := parseEmailFromIDToken(tokenInfo.IDToken)
|
||||
|
||||
@@ -25,6 +25,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
cProto "github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
@@ -34,7 +35,6 @@ import (
|
||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||
signal "github.com/netbirdio/netbird/shared/signal/client"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
@@ -52,7 +52,6 @@ func NewConnectClient(
|
||||
ctx context.Context,
|
||||
config *profilemanager.Config,
|
||||
statusRecorder *peer.Status,
|
||||
|
||||
) *ConnectClient {
|
||||
return &ConnectClient{
|
||||
ctx: ctx,
|
||||
@@ -63,8 +62,8 @@ func NewConnectClient(
|
||||
}
|
||||
|
||||
// Run with main logic.
|
||||
func (c *ConnectClient) Run(runningChan chan struct{}) error {
|
||||
return c.run(MobileDependency{}, runningChan)
|
||||
func (c *ConnectClient) Run(runningChan chan struct{}, logPath string) error {
|
||||
return c.run(MobileDependency{}, runningChan, logPath)
|
||||
}
|
||||
|
||||
// RunOnAndroid with main logic on mobile system
|
||||
@@ -83,7 +82,7 @@ func (c *ConnectClient) RunOnAndroid(
|
||||
HostDNSAddresses: dnsAddresses,
|
||||
DnsReadyListener: dnsReadyListener,
|
||||
}
|
||||
return c.run(mobileDependency, nil)
|
||||
return c.run(mobileDependency, nil, "")
|
||||
}
|
||||
|
||||
func (c *ConnectClient) RunOniOS(
|
||||
@@ -101,10 +100,10 @@ func (c *ConnectClient) RunOniOS(
|
||||
DnsManager: dnsManager,
|
||||
StateFilePath: stateFilePath,
|
||||
}
|
||||
return c.run(mobileDependency, nil)
|
||||
return c.run(mobileDependency, nil, "")
|
||||
}
|
||||
|
||||
func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan struct{}) error {
|
||||
func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan struct{}, logPath string) error {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
rec := c.statusRecorder
|
||||
@@ -247,7 +246,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
relayURLs, token := parseRelayInfo(loginResp)
|
||||
peerConfig := loginResp.GetPeerConfig()
|
||||
|
||||
engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig)
|
||||
engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig, logPath)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return wrapErr(err)
|
||||
@@ -271,7 +270,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
checks := loginResp.GetChecks()
|
||||
|
||||
c.engineMutex.Lock()
|
||||
c.engine = NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks)
|
||||
c.engine = NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks, c.config)
|
||||
c.engine.SetSyncResponsePersistence(c.persistSyncResponse)
|
||||
c.engineMutex.Unlock()
|
||||
|
||||
@@ -289,15 +288,18 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
}
|
||||
|
||||
<-engineCtx.Done()
|
||||
|
||||
c.engineMutex.Lock()
|
||||
if c.engine != nil && c.engine.wgInterface != nil {
|
||||
log.Infof("ensuring %s is removed, Netbird engine context cancelled", c.engine.wgInterface.Name())
|
||||
if err := c.engine.Stop(); err != nil {
|
||||
engine := c.engine
|
||||
c.engine = nil
|
||||
c.engineMutex.Unlock()
|
||||
|
||||
if engine != nil && engine.wgInterface != nil {
|
||||
log.Infof("ensuring %s is removed, Netbird engine context cancelled", engine.wgInterface.Name())
|
||||
if err := engine.Stop(); err != nil {
|
||||
log.Errorf("Failed to stop engine: %v", err)
|
||||
}
|
||||
c.engine = nil
|
||||
}
|
||||
c.engineMutex.Unlock()
|
||||
c.statusRecorder.ClientTeardown()
|
||||
|
||||
backOff.Reset()
|
||||
@@ -382,19 +384,12 @@ func (c *ConnectClient) Status() StatusType {
|
||||
}
|
||||
|
||||
func (c *ConnectClient) Stop() error {
|
||||
if c == nil {
|
||||
return nil
|
||||
engine := c.Engine()
|
||||
if engine != nil {
|
||||
if err := engine.Stop(); err != nil {
|
||||
return fmt.Errorf("stop engine: %w", err)
|
||||
}
|
||||
}
|
||||
c.engineMutex.Lock()
|
||||
defer c.engineMutex.Unlock()
|
||||
|
||||
if c.engine == nil {
|
||||
return nil
|
||||
}
|
||||
if err := c.engine.Stop(); err != nil {
|
||||
return fmt.Errorf("stop engine: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -414,26 +409,31 @@ func (c *ConnectClient) SetSyncResponsePersistence(enabled bool) {
|
||||
}
|
||||
|
||||
// createEngineConfig converts configuration received from Management Service to EngineConfig
|
||||
func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConfig *mgmProto.PeerConfig) (*EngineConfig, error) {
|
||||
func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConfig *mgmProto.PeerConfig, logPath string) (*EngineConfig, error) {
|
||||
nm := false
|
||||
if config.NetworkMonitor != nil {
|
||||
nm = *config.NetworkMonitor
|
||||
}
|
||||
engineConf := &EngineConfig{
|
||||
WgIfaceName: config.WgIface,
|
||||
WgAddr: peerConfig.Address,
|
||||
IFaceBlackList: config.IFaceBlackList,
|
||||
DisableIPv6Discovery: config.DisableIPv6Discovery,
|
||||
WgPrivateKey: key,
|
||||
WgPort: config.WgPort,
|
||||
NetworkMonitor: nm,
|
||||
SSHKey: []byte(config.SSHKey),
|
||||
NATExternalIPs: config.NATExternalIPs,
|
||||
CustomDNSAddress: config.CustomDNSAddress,
|
||||
RosenpassEnabled: config.RosenpassEnabled,
|
||||
RosenpassPermissive: config.RosenpassPermissive,
|
||||
ServerSSHAllowed: util.ReturnBoolWithDefaultTrue(config.ServerSSHAllowed),
|
||||
DNSRouteInterval: config.DNSRouteInterval,
|
||||
WgIfaceName: config.WgIface,
|
||||
WgAddr: peerConfig.Address,
|
||||
IFaceBlackList: config.IFaceBlackList,
|
||||
DisableIPv6Discovery: config.DisableIPv6Discovery,
|
||||
WgPrivateKey: key,
|
||||
WgPort: config.WgPort,
|
||||
NetworkMonitor: nm,
|
||||
SSHKey: []byte(config.SSHKey),
|
||||
NATExternalIPs: config.NATExternalIPs,
|
||||
CustomDNSAddress: config.CustomDNSAddress,
|
||||
RosenpassEnabled: config.RosenpassEnabled,
|
||||
RosenpassPermissive: config.RosenpassPermissive,
|
||||
ServerSSHAllowed: util.ReturnBoolWithDefaultTrue(config.ServerSSHAllowed),
|
||||
EnableSSHRoot: config.EnableSSHRoot,
|
||||
EnableSSHSFTP: config.EnableSSHSFTP,
|
||||
EnableSSHLocalPortForwarding: config.EnableSSHLocalPortForwarding,
|
||||
EnableSSHRemotePortForwarding: config.EnableSSHRemotePortForwarding,
|
||||
DisableSSHAuth: config.DisableSSHAuth,
|
||||
DNSRouteInterval: config.DNSRouteInterval,
|
||||
|
||||
DisableClientRoutes: config.DisableClientRoutes,
|
||||
DisableServerRoutes: config.DisableServerRoutes || config.BlockInbound,
|
||||
@@ -444,7 +444,10 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf
|
||||
|
||||
LazyConnectionEnabled: config.LazyConnectionEnabled,
|
||||
|
||||
MTU: selectMTU(config.MTU, peerConfig.Mtu),
|
||||
MTU: selectMTU(config.MTU, peerConfig.Mtu),
|
||||
LogPath: logPath,
|
||||
|
||||
ProfileConfig: config,
|
||||
}
|
||||
|
||||
if config.PreSharedKey != "" {
|
||||
@@ -519,6 +522,11 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte,
|
||||
config.BlockLANAccess,
|
||||
config.BlockInbound,
|
||||
config.LazyConnectionEnabled,
|
||||
config.EnableSSHRoot,
|
||||
config.EnableSSHSFTP,
|
||||
config.EnableSSHLocalPortForwarding,
|
||||
config.EnableSSHRemotePortForwarding,
|
||||
config.DisableSSHAuth,
|
||||
)
|
||||
loginResp, err := client.Login(*serverPublicKey, sysInfo, pubSSHKey, config.DNSLabels)
|
||||
if err != nil {
|
||||
|
||||
@@ -27,8 +27,10 @@ import (
|
||||
"github.com/netbirdio/netbird/client/anonymize"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
const readmeContent = `Netbird debug bundle
|
||||
@@ -44,10 +46,12 @@ interfaces.txt: Anonymized network interface information, if --system-info flag
|
||||
ip_rules.txt: Detailed IP routing rules in tabular format including priority, source, destination, interfaces, table, and action information (Linux only), if --system-info flag was provided.
|
||||
iptables.txt: Anonymized iptables rules with packet counters, if --system-info flag was provided.
|
||||
nftables.txt: Anonymized nftables rules with packet counters, if --system-info flag was provided.
|
||||
resolv.conf: DNS resolver configuration from /etc/resolv.conf (Unix systems only), if --system-info flag was provided.
|
||||
scutil_dns.txt: DNS configuration from scutil --dns (macOS only), if --system-info flag was provided.
|
||||
resolved_domains.txt: Anonymized resolved domain IP addresses from the status recorder.
|
||||
config.txt: Anonymized configuration information of the NetBird client.
|
||||
network_map.json: Anonymized sync response containing peer configurations, routes, DNS settings, and firewall rules.
|
||||
state.json: Anonymized client state dump containing netbird states.
|
||||
state.json: Anonymized client state dump containing netbird states for the active profile.
|
||||
mutex.prof: Mutex profiling information.
|
||||
goroutine.prof: Goroutine profiling information.
|
||||
block.prof: Block profiling information.
|
||||
@@ -184,6 +188,20 @@ The ip_rules.txt file contains detailed IP routing rule information:
|
||||
The table format provides comprehensive visibility into the IP routing decision process, including how traffic is directed to different routing tables based on various criteria. This is valuable for troubleshooting advanced routing configurations and policy-based routing.
|
||||
|
||||
For anonymized rules, IP addresses and prefixes are replaced as described above. Interface names are anonymized using string anonymization. Table names, actions, and other non-sensitive information remain unchanged.
|
||||
|
||||
DNS Configuration
|
||||
The debug bundle includes platform-specific DNS configuration files:
|
||||
|
||||
resolv.conf (Unix systems):
|
||||
- Contains DNS resolver configuration from /etc/resolv.conf
|
||||
- Includes nameserver entries, search domains, and resolver options
|
||||
- All IP addresses and domain names are anonymized following the same rules as other files
|
||||
|
||||
scutil_dns.txt (macOS only):
|
||||
- Contains detailed DNS configuration from scutil --dns
|
||||
- Shows DNS configuration for all network interfaces
|
||||
- Includes search domains, nameservers, and DNS resolver settings
|
||||
- All IP addresses and domain names are anonymized
|
||||
`
|
||||
|
||||
const (
|
||||
@@ -202,10 +220,9 @@ type BundleGenerator struct {
|
||||
internalConfig *profilemanager.Config
|
||||
statusRecorder *peer.Status
|
||||
syncResponse *mgmProto.SyncResponse
|
||||
logFile string
|
||||
logPath string
|
||||
|
||||
anonymize bool
|
||||
clientStatus string
|
||||
includeSystemInfo bool
|
||||
logFileCount uint32
|
||||
|
||||
@@ -214,7 +231,6 @@ type BundleGenerator struct {
|
||||
|
||||
type BundleConfig struct {
|
||||
Anonymize bool
|
||||
ClientStatus string
|
||||
IncludeSystemInfo bool
|
||||
LogFileCount uint32
|
||||
}
|
||||
@@ -223,7 +239,7 @@ type GeneratorDependencies struct {
|
||||
InternalConfig *profilemanager.Config
|
||||
StatusRecorder *peer.Status
|
||||
SyncResponse *mgmProto.SyncResponse
|
||||
LogFile string
|
||||
LogPath string
|
||||
}
|
||||
|
||||
func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGenerator {
|
||||
@@ -239,10 +255,9 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen
|
||||
internalConfig: deps.InternalConfig,
|
||||
statusRecorder: deps.StatusRecorder,
|
||||
syncResponse: deps.SyncResponse,
|
||||
logFile: deps.LogFile,
|
||||
logPath: deps.LogPath,
|
||||
|
||||
anonymize: cfg.Anonymize,
|
||||
clientStatus: cfg.ClientStatus,
|
||||
includeSystemInfo: cfg.IncludeSystemInfo,
|
||||
logFileCount: logFileCount,
|
||||
}
|
||||
@@ -288,13 +303,6 @@ func (g *BundleGenerator) createArchive() error {
|
||||
return fmt.Errorf("add status: %w", err)
|
||||
}
|
||||
|
||||
if g.statusRecorder != nil {
|
||||
status := g.statusRecorder.GetFullStatus()
|
||||
seedFromStatus(g.anonymizer, &status)
|
||||
} else {
|
||||
log.Debugf("no status recorder available for seeding")
|
||||
}
|
||||
|
||||
if err := g.addConfig(); err != nil {
|
||||
log.Errorf("failed to add config to debug bundle: %v", err)
|
||||
}
|
||||
@@ -327,7 +335,7 @@ func (g *BundleGenerator) createArchive() error {
|
||||
log.Errorf("failed to add wg show output: %v", err)
|
||||
}
|
||||
|
||||
if g.logFile != "" && !slices.Contains(util.SpecialLogs, g.logFile) {
|
||||
if g.logPath != "" && !slices.Contains(util.SpecialLogs, g.logPath) {
|
||||
if err := g.addLogfile(); err != nil {
|
||||
log.Errorf("failed to add log file to debug bundle: %v", err)
|
||||
if err := g.trySystemdLogFallback(); err != nil {
|
||||
@@ -357,6 +365,10 @@ func (g *BundleGenerator) addSystemInfo() {
|
||||
if err := g.addFirewallRules(); err != nil {
|
||||
log.Errorf("failed to add firewall rules to debug bundle: %v", err)
|
||||
}
|
||||
|
||||
if err := g.addDNSInfo(); err != nil {
|
||||
log.Errorf("failed to add DNS info to debug bundle: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) addReadme() error {
|
||||
@@ -368,11 +380,26 @@ func (g *BundleGenerator) addReadme() error {
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) addStatus() error {
|
||||
if status := g.clientStatus; status != "" {
|
||||
statusReader := strings.NewReader(status)
|
||||
if g.statusRecorder != nil {
|
||||
pm := profilemanager.NewProfileManager()
|
||||
var profName string
|
||||
if activeProf, err := pm.GetActiveProfile(); err == nil {
|
||||
profName = activeProf.Name
|
||||
}
|
||||
|
||||
fullStatus := g.statusRecorder.GetFullStatus()
|
||||
protoFullStatus := nbstatus.ToProtoFullStatus(fullStatus)
|
||||
protoFullStatus.Events = g.statusRecorder.GetEventHistory()
|
||||
overview := nbstatus.ConvertToStatusOutputOverview(protoFullStatus, g.anonymize, version.NetbirdVersion(), "", nil, nil, nil, "", profName)
|
||||
statusOutput := nbstatus.ParseToFullDetailSummary(overview)
|
||||
|
||||
statusReader := strings.NewReader(statusOutput)
|
||||
if err := g.addFileToZip(statusReader, "status.txt"); err != nil {
|
||||
return fmt.Errorf("add status file to zip: %w", err)
|
||||
}
|
||||
seedFromStatus(g.anonymizer, &fullStatus)
|
||||
} else {
|
||||
log.Debugf("no status recorder available for seeding")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -433,6 +460,18 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder)
|
||||
if g.internalConfig.ServerSSHAllowed != nil {
|
||||
configContent.WriteString(fmt.Sprintf("ServerSSHAllowed: %v\n", *g.internalConfig.ServerSSHAllowed))
|
||||
}
|
||||
if g.internalConfig.EnableSSHRoot != nil {
|
||||
configContent.WriteString(fmt.Sprintf("EnableSSHRoot: %v\n", *g.internalConfig.EnableSSHRoot))
|
||||
}
|
||||
if g.internalConfig.EnableSSHSFTP != nil {
|
||||
configContent.WriteString(fmt.Sprintf("EnableSSHSFTP: %v\n", *g.internalConfig.EnableSSHSFTP))
|
||||
}
|
||||
if g.internalConfig.EnableSSHLocalPortForwarding != nil {
|
||||
configContent.WriteString(fmt.Sprintf("EnableSSHLocalPortForwarding: %v\n", *g.internalConfig.EnableSSHLocalPortForwarding))
|
||||
}
|
||||
if g.internalConfig.EnableSSHRemotePortForwarding != nil {
|
||||
configContent.WriteString(fmt.Sprintf("EnableSSHRemotePortForwarding: %v\n", *g.internalConfig.EnableSSHRemotePortForwarding))
|
||||
}
|
||||
|
||||
configContent.WriteString(fmt.Sprintf("DisableClientRoutes: %v\n", g.internalConfig.DisableClientRoutes))
|
||||
configContent.WriteString(fmt.Sprintf("DisableServerRoutes: %v\n", g.internalConfig.DisableServerRoutes))
|
||||
@@ -564,6 +603,8 @@ func (g *BundleGenerator) addStateFile() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Debugf("Adding state file from: %s", path)
|
||||
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
@@ -628,14 +669,14 @@ func (g *BundleGenerator) addCorruptedStateFiles() error {
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) addLogfile() error {
|
||||
if g.logFile == "" {
|
||||
if g.logPath == "" {
|
||||
log.Debugf("skipping empty log file in debug bundle")
|
||||
return nil
|
||||
}
|
||||
|
||||
logDir := filepath.Dir(g.logFile)
|
||||
logDir := filepath.Dir(g.logPath)
|
||||
|
||||
if err := g.addSingleLogfile(g.logFile, clientLogFile); err != nil {
|
||||
if err := g.addSingleLogfile(g.logPath, clientLogFile); err != nil {
|
||||
return fmt.Errorf("add client log file to zip: %w", err)
|
||||
}
|
||||
|
||||
|
||||
53
client/internal/debug/debug_darwin.go
Normal file
53
client/internal/debug/debug_darwin.go
Normal file
@@ -0,0 +1,53 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package debug
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// addDNSInfo collects and adds DNS configuration information to the archive
|
||||
func (g *BundleGenerator) addDNSInfo() error {
|
||||
if err := g.addResolvConf(); err != nil {
|
||||
log.Errorf("failed to add resolv.conf: %v", err)
|
||||
}
|
||||
|
||||
if err := g.addScutilDNS(); err != nil {
|
||||
log.Errorf("failed to add scutil DNS output: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) addScutilDNS() error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cmd := exec.CommandContext(ctx, "scutil", "--dns")
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf("execute scutil --dns: %w", err)
|
||||
}
|
||||
|
||||
if len(bytes.TrimSpace(output)) == 0 {
|
||||
return fmt.Errorf("no scutil DNS output")
|
||||
}
|
||||
|
||||
content := string(output)
|
||||
if g.anonymize {
|
||||
content = g.anonymizer.AnonymizeString(content)
|
||||
}
|
||||
|
||||
if err := g.addFileToZip(strings.NewReader(content), "scutil_dns.txt"); err != nil {
|
||||
return fmt.Errorf("add scutil DNS output to zip: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -5,3 +5,7 @@ package debug
|
||||
func (g *BundleGenerator) addRoutes() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) addDNSInfo() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
16
client/internal/debug/debug_nondarwin.go
Normal file
16
client/internal/debug/debug_nondarwin.go
Normal file
@@ -0,0 +1,16 @@
|
||||
//go:build unix && !darwin && !android
|
||||
|
||||
package debug
|
||||
|
||||
import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// addDNSInfo collects and adds DNS configuration information to the archive
|
||||
func (g *BundleGenerator) addDNSInfo() error {
|
||||
if err := g.addResolvConf(); err != nil {
|
||||
log.Errorf("failed to add resolv.conf: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
7
client/internal/debug/debug_nonunix.go
Normal file
7
client/internal/debug/debug_nonunix.go
Normal file
@@ -0,0 +1,7 @@
|
||||
//go:build !unix
|
||||
|
||||
package debug
|
||||
|
||||
func (g *BundleGenerator) addDNSInfo() error {
|
||||
return nil
|
||||
}
|
||||
29
client/internal/debug/debug_unix.go
Normal file
29
client/internal/debug/debug_unix.go
Normal file
@@ -0,0 +1,29 @@
|
||||
//go:build unix && !android
|
||||
|
||||
package debug
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const resolvConfPath = "/etc/resolv.conf"
|
||||
|
||||
func (g *BundleGenerator) addResolvConf() error {
|
||||
data, err := os.ReadFile(resolvConfPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read %s: %w", resolvConfPath, err)
|
||||
}
|
||||
|
||||
content := string(data)
|
||||
if g.anonymize {
|
||||
content = g.anonymizer.AnonymizeString(content)
|
||||
}
|
||||
|
||||
if err := g.addFileToZip(strings.NewReader(content), "resolv.conf"); err != nil {
|
||||
return fmt.Errorf("add resolv.conf to zip: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
101
client/internal/debug/upload.go
Normal file
101
client/internal/debug/upload.go
Normal file
@@ -0,0 +1,101 @@
|
||||
package debug
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/netbirdio/netbird/upload-server/types"
|
||||
)
|
||||
|
||||
const maxBundleUploadSize = 50 * 1024 * 1024
|
||||
|
||||
func UploadDebugBundle(ctx context.Context, url, managementURL, filePath string) (key string, err error) {
|
||||
response, err := getUploadURL(ctx, url, managementURL)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
err = upload(ctx, filePath, response)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return response.Key, nil
|
||||
}
|
||||
|
||||
func upload(ctx context.Context, filePath string, response *types.GetURLResponse) error {
|
||||
fileData, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open file: %w", err)
|
||||
}
|
||||
|
||||
defer fileData.Close()
|
||||
|
||||
stat, err := fileData.Stat()
|
||||
if err != nil {
|
||||
return fmt.Errorf("stat file: %w", err)
|
||||
}
|
||||
|
||||
if stat.Size() > maxBundleUploadSize {
|
||||
return fmt.Errorf("file size exceeds maximum limit of %d bytes", maxBundleUploadSize)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "PUT", response.URL, fileData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create PUT request: %w", err)
|
||||
}
|
||||
|
||||
req.ContentLength = stat.Size()
|
||||
req.Header.Set("Content-Type", "application/octet-stream")
|
||||
|
||||
putResp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("upload failed: %v", err)
|
||||
}
|
||||
defer putResp.Body.Close()
|
||||
|
||||
if putResp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(putResp.Body)
|
||||
return fmt.Errorf("upload status %d: %s", putResp.StatusCode, string(body))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getUploadURL(ctx context.Context, url string, managementURL string) (*types.GetURLResponse, error) {
|
||||
id := getURLHash(managementURL)
|
||||
getReq, err := http.NewRequestWithContext(ctx, "GET", url+"?id="+id, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create GET request: %w", err)
|
||||
}
|
||||
|
||||
getReq.Header.Set(types.ClientHeader, types.ClientHeaderValue)
|
||||
|
||||
resp, err := http.DefaultClient.Do(getReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get presigned URL: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("get presigned URL status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
urlBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read response body: %w", err)
|
||||
}
|
||||
var response types.GetURLResponse
|
||||
if err := json.Unmarshal(urlBytes, &response); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal response: %w", err)
|
||||
}
|
||||
return &response, nil
|
||||
}
|
||||
|
||||
func getURLHash(url string) string {
|
||||
return fmt.Sprintf("%x", sha256.Sum256([]byte(url)))
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package server
|
||||
package debug
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -38,7 +38,7 @@ func TestUpload(t *testing.T) {
|
||||
fileContent := []byte("test file content")
|
||||
err := os.WriteFile(file, fileContent, 0640)
|
||||
require.NoError(t, err)
|
||||
key, err := uploadDebugBundle(context.Background(), testURL+types.GetURLPath, testURL, file)
|
||||
key, err := UploadDebugBundle(context.Background(), testURL+types.GetURLPath, testURL, file)
|
||||
require.NoError(t, err)
|
||||
id := getURLHash(testURL)
|
||||
require.Contains(t, key, id+"/")
|
||||
@@ -14,6 +14,9 @@ type WGIface interface {
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) addWgShow() error {
|
||||
if g.statusRecorder == nil {
|
||||
return fmt.Errorf("no status recorder available for wg show")
|
||||
}
|
||||
result, err := g.statusRecorder.PeersStatus()
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@@ -38,6 +38,8 @@ type DeviceAuthProviderConfig struct {
|
||||
Scope string
|
||||
// UseIDToken indicates if the id token should be used for authentication
|
||||
UseIDToken bool
|
||||
// LoginHint is used to pre-fill the email/username field during authentication
|
||||
LoginHint string
|
||||
}
|
||||
|
||||
// GetDeviceAuthorizationFlowInfo initialize a DeviceAuthorizationFlow instance and return with it
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
@@ -50,28 +51,21 @@ func (s *systemConfigurator) supportCustomPort() bool {
|
||||
}
|
||||
|
||||
func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
|
||||
var err error
|
||||
|
||||
if err := stateManager.UpdateState(&ShutdownState{}); err != nil {
|
||||
log.Errorf("failed to update shutdown state: %s", err)
|
||||
}
|
||||
|
||||
var (
|
||||
searchDomains []string
|
||||
matchDomains []string
|
||||
)
|
||||
|
||||
err = s.recordSystemDNSSettings(true)
|
||||
if err != nil {
|
||||
if err := s.recordSystemDNSSettings(true); err != nil {
|
||||
log.Errorf("unable to update record of System's DNS config: %s", err.Error())
|
||||
}
|
||||
|
||||
if config.RouteAll {
|
||||
searchDomains = append(searchDomains, "\"\"")
|
||||
err = s.addLocalDNS()
|
||||
if err != nil {
|
||||
log.Infof("failed to enable split DNS")
|
||||
if err := s.addLocalDNS(); err != nil {
|
||||
log.Warnf("failed to add local DNS: %v", err)
|
||||
}
|
||||
s.updateState(stateManager)
|
||||
}
|
||||
|
||||
for _, dConf := range config.Domains {
|
||||
@@ -86,6 +80,7 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
|
||||
}
|
||||
|
||||
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
|
||||
var err error
|
||||
if len(matchDomains) != 0 {
|
||||
err = s.addMatchDomains(matchKey, strings.Join(matchDomains, " "), config.ServerIP, config.ServerPort)
|
||||
} else {
|
||||
@@ -95,6 +90,7 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
|
||||
if err != nil {
|
||||
return fmt.Errorf("add match domains: %w", err)
|
||||
}
|
||||
s.updateState(stateManager)
|
||||
|
||||
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
|
||||
if len(searchDomains) != 0 {
|
||||
@@ -106,6 +102,7 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
|
||||
if err != nil {
|
||||
return fmt.Errorf("add search domains: %w", err)
|
||||
}
|
||||
s.updateState(stateManager)
|
||||
|
||||
if err := s.flushDNSCache(); err != nil {
|
||||
log.Errorf("failed to flush DNS cache: %v", err)
|
||||
@@ -114,6 +111,12 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *systemConfigurator) updateState(stateManager *statemanager.Manager) {
|
||||
if err := stateManager.UpdateState(&ShutdownState{CreatedKeys: maps.Keys(s.createdKeys)}); err != nil {
|
||||
log.Errorf("failed to update shutdown state: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *systemConfigurator) string() string {
|
||||
return "scutil"
|
||||
}
|
||||
@@ -167,18 +170,20 @@ func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error {
|
||||
func (s *systemConfigurator) addLocalDNS() error {
|
||||
if !s.systemDNSSettings.ServerIP.IsValid() || len(s.systemDNSSettings.Domains) == 0 {
|
||||
if err := s.recordSystemDNSSettings(true); err != nil {
|
||||
log.Errorf("Unable to get system DNS configuration")
|
||||
return fmt.Errorf("recordSystemDNSSettings(): %w", err)
|
||||
}
|
||||
}
|
||||
localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)
|
||||
if s.systemDNSSettings.ServerIP.IsValid() && len(s.systemDNSSettings.Domains) != 0 {
|
||||
err := s.addSearchDomains(localKey, strings.Join(s.systemDNSSettings.Domains, " "), s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort)
|
||||
if err != nil {
|
||||
return fmt.Errorf("couldn't add local network DNS conf: %w", err)
|
||||
}
|
||||
} else {
|
||||
if !s.systemDNSSettings.ServerIP.IsValid() || len(s.systemDNSSettings.Domains) == 0 {
|
||||
log.Info("Not enabling local DNS server")
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := s.addSearchDomains(
|
||||
localKey,
|
||||
strings.Join(s.systemDNSSettings.Domains, " "), s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort,
|
||||
); err != nil {
|
||||
return fmt.Errorf("add search domains: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
111
client/internal/dns/host_darwin_test.go
Normal file
111
client/internal/dns/host_darwin_test.go
Normal file
@@ -0,0 +1,111 @@
|
||||
//go:build !ios
|
||||
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
func TestDarwinDNSUncleanShutdownCleanup(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping scutil integration test in short mode")
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
stateFile := filepath.Join(tmpDir, "state.json")
|
||||
|
||||
sm := statemanager.New(stateFile)
|
||||
sm.RegisterState(&ShutdownState{})
|
||||
sm.Start()
|
||||
defer func() {
|
||||
require.NoError(t, sm.Stop(context.Background()))
|
||||
}()
|
||||
|
||||
configurator := &systemConfigurator{
|
||||
createdKeys: make(map[string]struct{}),
|
||||
}
|
||||
|
||||
config := HostDNSConfig{
|
||||
ServerIP: netip.MustParseAddr("100.64.0.1"),
|
||||
ServerPort: 53,
|
||||
RouteAll: true,
|
||||
Domains: []DomainConfig{
|
||||
{Domain: "example.com", MatchOnly: true},
|
||||
},
|
||||
}
|
||||
|
||||
err := configurator.applyDNSConfig(config, sm)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, sm.PersistState(context.Background()))
|
||||
|
||||
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
|
||||
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
|
||||
localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)
|
||||
|
||||
defer func() {
|
||||
for _, key := range []string{searchKey, matchKey, localKey} {
|
||||
_ = removeTestDNSKey(key)
|
||||
}
|
||||
}()
|
||||
|
||||
for _, key := range []string{searchKey, matchKey, localKey} {
|
||||
exists, err := checkDNSKeyExists(key)
|
||||
require.NoError(t, err)
|
||||
if exists {
|
||||
t.Logf("Key %s exists before cleanup", key)
|
||||
}
|
||||
}
|
||||
|
||||
sm2 := statemanager.New(stateFile)
|
||||
sm2.RegisterState(&ShutdownState{})
|
||||
err = sm2.LoadState(&ShutdownState{})
|
||||
require.NoError(t, err)
|
||||
|
||||
state := sm2.GetState(&ShutdownState{})
|
||||
if state == nil {
|
||||
t.Skip("State not saved, skipping cleanup test")
|
||||
}
|
||||
|
||||
shutdownState, ok := state.(*ShutdownState)
|
||||
require.True(t, ok)
|
||||
|
||||
err = shutdownState.Cleanup()
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, key := range []string{searchKey, matchKey, localKey} {
|
||||
exists, err := checkDNSKeyExists(key)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists, "Key %s should NOT exist after cleanup", key)
|
||||
}
|
||||
}
|
||||
|
||||
func checkDNSKeyExists(key string) (bool, error) {
|
||||
cmd := exec.Command(scutilPath)
|
||||
cmd.Stdin = strings.NewReader("show " + key + "\nquit\n")
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
if strings.Contains(string(output), "No such key") {
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
return !strings.Contains(string(output), "No such key"), nil
|
||||
}
|
||||
|
||||
func removeTestDNSKey(key string) error {
|
||||
cmd := exec.Command(scutilPath)
|
||||
cmd.Stdin = strings.NewReader("remove " + key + "\nquit\n")
|
||||
_, err := cmd.CombinedOutput()
|
||||
return err
|
||||
}
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/winregistry"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -178,13 +179,7 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
|
||||
log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP)
|
||||
}
|
||||
|
||||
if err := stateManager.UpdateState(&ShutdownState{
|
||||
Guid: r.guid,
|
||||
GPO: r.gpo,
|
||||
NRPTEntryCount: r.nrptEntryCount,
|
||||
}); err != nil {
|
||||
log.Errorf("failed to update shutdown state: %s", err)
|
||||
}
|
||||
r.updateState(stateManager)
|
||||
|
||||
var searchDomains, matchDomains []string
|
||||
for _, dConf := range config.Domains {
|
||||
@@ -197,6 +192,10 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
|
||||
matchDomains = append(matchDomains, "."+strings.TrimSuffix(dConf.Domain, "."))
|
||||
}
|
||||
|
||||
if err := r.removeDNSMatchPolicies(); err != nil {
|
||||
log.Errorf("cleanup old dns match policies: %s", err)
|
||||
}
|
||||
|
||||
if len(matchDomains) != 0 {
|
||||
count, err := r.addDNSMatchPolicy(matchDomains, config.ServerIP)
|
||||
if err != nil {
|
||||
@@ -204,19 +203,10 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
|
||||
}
|
||||
r.nrptEntryCount = count
|
||||
} else {
|
||||
if err := r.removeDNSMatchPolicies(); err != nil {
|
||||
return fmt.Errorf("remove dns match policies: %w", err)
|
||||
}
|
||||
r.nrptEntryCount = 0
|
||||
}
|
||||
|
||||
if err := stateManager.UpdateState(&ShutdownState{
|
||||
Guid: r.guid,
|
||||
GPO: r.gpo,
|
||||
NRPTEntryCount: r.nrptEntryCount,
|
||||
}); err != nil {
|
||||
log.Errorf("failed to update shutdown state: %s", err)
|
||||
}
|
||||
r.updateState(stateManager)
|
||||
|
||||
if err := r.updateSearchDomains(searchDomains); err != nil {
|
||||
return fmt.Errorf("update search domains: %w", err)
|
||||
@@ -227,6 +217,16 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *registryConfigurator) updateState(stateManager *statemanager.Manager) {
|
||||
if err := stateManager.UpdateState(&ShutdownState{
|
||||
Guid: r.guid,
|
||||
GPO: r.gpo,
|
||||
NRPTEntryCount: r.nrptEntryCount,
|
||||
}); err != nil {
|
||||
log.Errorf("failed to update shutdown state: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *registryConfigurator) addDNSSetupForAll(ip netip.Addr) error {
|
||||
if err := r.setInterfaceRegistryKeyStringValue(interfaceConfigNameServerKey, ip.String()); err != nil {
|
||||
return fmt.Errorf("adding dns setup for all failed: %w", err)
|
||||
@@ -273,9 +273,9 @@ func (r *registryConfigurator) configureDNSPolicy(policyPath string, domains []s
|
||||
return fmt.Errorf("remove existing dns policy: %w", err)
|
||||
}
|
||||
|
||||
regKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, policyPath, registry.SET_VALUE)
|
||||
regKey, _, err := winregistry.CreateVolatileKey(registry.LOCAL_MACHINE, policyPath, registry.SET_VALUE)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create registry key HKEY_LOCAL_MACHINE\\%s: %w", policyPath, err)
|
||||
return fmt.Errorf("create volatile registry key HKEY_LOCAL_MACHINE\\%s: %w", policyPath, err)
|
||||
}
|
||||
defer closer(regKey)
|
||||
|
||||
|
||||
102
client/internal/dns/host_windows_test.go
Normal file
102
client/internal/dns/host_windows_test.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/sys/windows/registry"
|
||||
)
|
||||
|
||||
// TestNRPTEntriesCleanupOnConfigChange tests that old NRPT entries are properly cleaned up
|
||||
// when the number of match domains decreases between configuration changes.
|
||||
func TestNRPTEntriesCleanupOnConfigChange(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping registry integration test in short mode")
|
||||
}
|
||||
|
||||
defer cleanupRegistryKeys(t)
|
||||
cleanupRegistryKeys(t)
|
||||
|
||||
testIP := netip.MustParseAddr("100.64.0.1")
|
||||
|
||||
// Create a test interface registry key so updateSearchDomains doesn't fail
|
||||
testGUID := "{12345678-1234-1234-1234-123456789ABC}"
|
||||
interfacePath := `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\` + testGUID
|
||||
testKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, interfacePath, registry.SET_VALUE)
|
||||
require.NoError(t, err, "Should create test interface registry key")
|
||||
testKey.Close()
|
||||
defer func() {
|
||||
_ = registry.DeleteKey(registry.LOCAL_MACHINE, interfacePath)
|
||||
}()
|
||||
|
||||
cfg := ®istryConfigurator{
|
||||
guid: testGUID,
|
||||
gpo: false,
|
||||
}
|
||||
|
||||
config5 := HostDNSConfig{
|
||||
ServerIP: testIP,
|
||||
Domains: []DomainConfig{
|
||||
{Domain: "domain1.com", MatchOnly: true},
|
||||
{Domain: "domain2.com", MatchOnly: true},
|
||||
{Domain: "domain3.com", MatchOnly: true},
|
||||
{Domain: "domain4.com", MatchOnly: true},
|
||||
{Domain: "domain5.com", MatchOnly: true},
|
||||
},
|
||||
}
|
||||
|
||||
err = cfg.applyDNSConfig(config5, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify all 5 entries exist
|
||||
for i := 0; i < 5; i++ {
|
||||
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "Entry %d should exist after first config", i)
|
||||
}
|
||||
|
||||
config2 := HostDNSConfig{
|
||||
ServerIP: testIP,
|
||||
Domains: []DomainConfig{
|
||||
{Domain: "domain1.com", MatchOnly: true},
|
||||
{Domain: "domain2.com", MatchOnly: true},
|
||||
},
|
||||
}
|
||||
|
||||
err = cfg.applyDNSConfig(config2, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify first 2 entries exist
|
||||
for i := 0; i < 2; i++ {
|
||||
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "Entry %d should exist after second config", i)
|
||||
}
|
||||
|
||||
// Verify entries 2-4 are cleaned up
|
||||
for i := 2; i < 5; i++ {
|
||||
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists, "Entry %d should NOT exist after reducing to 2 domains", i)
|
||||
}
|
||||
}
|
||||
|
||||
func registryKeyExists(path string) (bool, error) {
|
||||
k, err := registry.OpenKey(registry.LOCAL_MACHINE, path, registry.QUERY_VALUE)
|
||||
if err != nil {
|
||||
if err == registry.ErrNotExist {
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
k.Close()
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func cleanupRegistryKeys(*testing.T) {
|
||||
cfg := ®istryConfigurator{nrptEntryCount: 10}
|
||||
_ = cfg.removeDNSMatchPolicies()
|
||||
}
|
||||
@@ -65,8 +65,9 @@ type hostManagerWithOriginalNS interface {
|
||||
|
||||
// DefaultServer dns server object
|
||||
type DefaultServer struct {
|
||||
ctx context.Context
|
||||
ctxCancel context.CancelFunc
|
||||
ctx context.Context
|
||||
ctxCancel context.CancelFunc
|
||||
shutdownWg sync.WaitGroup
|
||||
// disableSys disables system DNS management (e.g., /etc/resolv.conf updates) while keeping the DNS service running.
|
||||
// This is different from ServiceEnable=false from management which completely disables the DNS service.
|
||||
disableSys bool
|
||||
@@ -318,6 +319,7 @@ func (s *DefaultServer) DnsIP() netip.Addr {
|
||||
// Stop stops the server
|
||||
func (s *DefaultServer) Stop() {
|
||||
s.ctxCancel()
|
||||
s.shutdownWg.Wait()
|
||||
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
@@ -507,8 +509,9 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
||||
|
||||
s.applyHostConfig()
|
||||
|
||||
s.shutdownWg.Add(1)
|
||||
go func() {
|
||||
// persist dns state right away
|
||||
defer s.shutdownWg.Done()
|
||||
if err := s.stateManager.PersistState(s.ctx); err != nil {
|
||||
log.Errorf("Failed to persist dns state: %v", err)
|
||||
}
|
||||
|
||||
@@ -335,7 +335,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
for n, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
privKey, _ := wgtypes.GenerateKey()
|
||||
newNet, err := stdnet.NewNet(nil)
|
||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -434,7 +434,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
||||
defer t.Setenv("NB_WG_KERNEL_DISABLED", ov)
|
||||
|
||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||
newNet, err := stdnet.NewNet([]string{"utun2301"})
|
||||
newNet, err := stdnet.NewNet(context.Background(), []string{"utun2301"})
|
||||
if err != nil {
|
||||
t.Errorf("create stdnet: %v", err)
|
||||
return
|
||||
@@ -915,7 +915,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
|
||||
defer t.Setenv("NB_WG_KERNEL_DISABLED", ov)
|
||||
|
||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||
newNet, err := stdnet.NewNet([]string{"utun2301"})
|
||||
newNet, err := stdnet.NewNet(context.Background(), []string{"utun2301"})
|
||||
if err != nil {
|
||||
t.Fatalf("create stdnet: %v", err)
|
||||
return nil, err
|
||||
@@ -944,7 +944,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pf, err := uspfilter.Create(wgIface, false, flowLogger)
|
||||
pf, err := uspfilter.Create(wgIface, false, flowLogger, iface.DefaultMTU)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create uspfilter: %v", err)
|
||||
return nil, err
|
||||
|
||||
@@ -31,6 +31,7 @@ const (
|
||||
systemdDbusSetDefaultRouteMethodSuffix = systemdDbusLinkInterface + ".SetDefaultRoute"
|
||||
systemdDbusSetDomainsMethodSuffix = systemdDbusLinkInterface + ".SetDomains"
|
||||
systemdDbusSetDNSSECMethodSuffix = systemdDbusLinkInterface + ".SetDNSSEC"
|
||||
systemdDbusSetDNSOverTLSMethodSuffix = systemdDbusLinkInterface + ".SetDNSOverTLS"
|
||||
systemdDbusResolvConfModeForeign = "foreign"
|
||||
|
||||
dbusErrorUnknownObject = "org.freedesktop.DBus.Error.UnknownObject"
|
||||
@@ -102,6 +103,11 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateMana
|
||||
log.Warnf("failed to set DNSSEC to 'no': %v", err)
|
||||
}
|
||||
|
||||
// We don't support DNSOverTLS. On some machines this is default on so we explicitly set it to off
|
||||
if err := s.callLinkMethod(systemdDbusSetDNSOverTLSMethodSuffix, dnsSecDisabled); err != nil {
|
||||
log.Warnf("failed to set DNSOverTLS to 'no': %v", err)
|
||||
}
|
||||
|
||||
var (
|
||||
searchDomains []string
|
||||
matchDomains []string
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
)
|
||||
|
||||
type ShutdownState struct {
|
||||
CreatedKeys []string
|
||||
}
|
||||
|
||||
func (s *ShutdownState) Name() string {
|
||||
@@ -19,6 +20,10 @@ func (s *ShutdownState) Cleanup() error {
|
||||
return fmt.Errorf("create host manager: %w", err)
|
||||
}
|
||||
|
||||
for _, key := range s.CreatedKeys {
|
||||
manager.createdKeys[key] = struct{}{}
|
||||
}
|
||||
|
||||
if err := manager.restoreUncleanShutdownDNS(); err != nil {
|
||||
return fmt.Errorf("restore unclean shutdown dns: %w", err)
|
||||
}
|
||||
|
||||
78
client/internal/dnsfwd/cache.go
Normal file
78
client/internal/dnsfwd/cache.go
Normal file
@@ -0,0 +1,78 @@
|
||||
package dnsfwd
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
type cache struct {
|
||||
mu sync.RWMutex
|
||||
records map[string]*cacheEntry
|
||||
}
|
||||
|
||||
type cacheEntry struct {
|
||||
ip4Addrs []netip.Addr
|
||||
ip6Addrs []netip.Addr
|
||||
}
|
||||
|
||||
func newCache() *cache {
|
||||
return &cache{
|
||||
records: make(map[string]*cacheEntry),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *cache) get(domain string, reqType uint16) ([]netip.Addr, bool) {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
entry, exists := c.records[normalizeDomain(domain)]
|
||||
if !exists {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
switch reqType {
|
||||
case dns.TypeA:
|
||||
return slices.Clone(entry.ip4Addrs), true
|
||||
case dns.TypeAAAA:
|
||||
return slices.Clone(entry.ip6Addrs), true
|
||||
default:
|
||||
return nil, false
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (c *cache) set(domain string, reqType uint16, addrs []netip.Addr) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
norm := normalizeDomain(domain)
|
||||
entry, exists := c.records[norm]
|
||||
if !exists {
|
||||
entry = &cacheEntry{}
|
||||
c.records[norm] = entry
|
||||
}
|
||||
|
||||
switch reqType {
|
||||
case dns.TypeA:
|
||||
entry.ip4Addrs = slices.Clone(addrs)
|
||||
case dns.TypeAAAA:
|
||||
entry.ip6Addrs = slices.Clone(addrs)
|
||||
}
|
||||
}
|
||||
|
||||
// unset removes cached entries for the given domain and request type.
|
||||
func (c *cache) unset(domain string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
delete(c.records, normalizeDomain(domain))
|
||||
}
|
||||
|
||||
// normalizeDomain converts an input domain into a canonical form used as cache key:
|
||||
// lowercase and fully-qualified (with trailing dot).
|
||||
func normalizeDomain(domain string) string {
|
||||
// dns.Fqdn ensures trailing dot; ToLower for consistent casing
|
||||
return dns.Fqdn(strings.ToLower(domain))
|
||||
}
|
||||
85
client/internal/dnsfwd/cache_test.go
Normal file
85
client/internal/dnsfwd/cache_test.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package dnsfwd
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func mustAddr(t *testing.T, s string) netip.Addr {
|
||||
t.Helper()
|
||||
a, err := netip.ParseAddr(s)
|
||||
if err != nil {
|
||||
t.Fatalf("parse addr %s: %v", s, err)
|
||||
}
|
||||
return a
|
||||
}
|
||||
|
||||
func TestCacheNormalization(t *testing.T) {
|
||||
c := newCache()
|
||||
|
||||
// Mixed case, without trailing dot
|
||||
domainInput := "ExAmPlE.CoM"
|
||||
ipv4 := []netip.Addr{mustAddr(t, "1.2.3.4")}
|
||||
c.set(domainInput, 1 /* dns.TypeA */, ipv4)
|
||||
|
||||
// Lookup with lower, with trailing dot
|
||||
if got, ok := c.get("example.com.", 1); !ok || len(got) != 1 || got[0].String() != "1.2.3.4" {
|
||||
t.Fatalf("expected cached IPv4 result via normalized key, got=%v ok=%v", got, ok)
|
||||
}
|
||||
|
||||
// Lookup with different casing again
|
||||
if got, ok := c.get("EXAMPLE.COM", 1); !ok || len(got) != 1 || got[0].String() != "1.2.3.4" {
|
||||
t.Fatalf("expected cached IPv4 result via different casing, got=%v ok=%v", got, ok)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheSeparateTypes(t *testing.T) {
|
||||
c := newCache()
|
||||
|
||||
domain := "test.local"
|
||||
ipv4 := []netip.Addr{mustAddr(t, "10.0.0.1")}
|
||||
ipv6 := []netip.Addr{mustAddr(t, "2001:db8::1")}
|
||||
|
||||
c.set(domain, 1 /* A */, ipv4)
|
||||
c.set(domain, 28 /* AAAA */, ipv6)
|
||||
|
||||
got4, ok4 := c.get(domain, 1)
|
||||
if !ok4 || len(got4) != 1 || got4[0] != ipv4[0] {
|
||||
t.Fatalf("expected A record from cache, got=%v ok=%v", got4, ok4)
|
||||
}
|
||||
|
||||
got6, ok6 := c.get(domain, 28)
|
||||
if !ok6 || len(got6) != 1 || got6[0] != ipv6[0] {
|
||||
t.Fatalf("expected AAAA record from cache, got=%v ok=%v", got6, ok6)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheCloneOnGetAndSet(t *testing.T) {
|
||||
c := newCache()
|
||||
domain := "clone.test"
|
||||
|
||||
src := []netip.Addr{mustAddr(t, "8.8.8.8")}
|
||||
c.set(domain, 1, src)
|
||||
|
||||
// Mutate source slice; cache should be unaffected
|
||||
src[0] = mustAddr(t, "9.9.9.9")
|
||||
|
||||
got, ok := c.get(domain, 1)
|
||||
if !ok || len(got) != 1 || got[0].String() != "8.8.8.8" {
|
||||
t.Fatalf("expected cached value to be independent of source slice, got=%v ok=%v", got, ok)
|
||||
}
|
||||
|
||||
// Mutate returned slice; internal cache should remain unchanged
|
||||
got[0] = mustAddr(t, "4.4.4.4")
|
||||
got2, ok2 := c.get(domain, 1)
|
||||
if !ok2 || len(got2) != 1 || got2[0].String() != "8.8.8.8" {
|
||||
t.Fatalf("expected returned slice to be a clone, got=%v ok=%v", got2, ok2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheMiss(t *testing.T) {
|
||||
c := newCache()
|
||||
if got, ok := c.get("missing.example", 1); ok || got != nil {
|
||||
t.Fatalf("expected cache miss, got=%v ok=%v", got, ok)
|
||||
}
|
||||
}
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
@@ -33,7 +34,7 @@ type firewaller interface {
|
||||
}
|
||||
|
||||
type DNSForwarder struct {
|
||||
listenAddress string
|
||||
listenAddress netip.AddrPort
|
||||
ttl uint32
|
||||
statusRecorder *peer.Status
|
||||
|
||||
@@ -46,9 +47,12 @@ type DNSForwarder struct {
|
||||
fwdEntries []*ForwarderEntry
|
||||
firewall firewaller
|
||||
resolver resolver
|
||||
cache *cache
|
||||
|
||||
wgIface wgIface
|
||||
}
|
||||
|
||||
func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, statusRecorder *peer.Status) *DNSForwarder {
|
||||
func NewDNSForwarder(listenAddress netip.AddrPort, ttl uint32, firewall firewaller, statusRecorder *peer.Status, wgIface wgIface) *DNSForwarder {
|
||||
log.Debugf("creating DNS forwarder with listen_address=%s ttl=%d", listenAddress, ttl)
|
||||
return &DNSForwarder{
|
||||
listenAddress: listenAddress,
|
||||
@@ -56,30 +60,47 @@ func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, stat
|
||||
firewall: firewall,
|
||||
statusRecorder: statusRecorder,
|
||||
resolver: net.DefaultResolver,
|
||||
cache: newCache(),
|
||||
wgIface: wgIface,
|
||||
}
|
||||
}
|
||||
|
||||
func (f *DNSForwarder) Listen(entries []*ForwarderEntry) error {
|
||||
log.Infof("starting DNS forwarder on address=%s", f.listenAddress)
|
||||
var netstackNet *netstack.Net
|
||||
if f.wgIface != nil {
|
||||
netstackNet = f.wgIface.GetNet()
|
||||
}
|
||||
|
||||
addrDesc := f.listenAddress.String()
|
||||
if netstackNet != nil {
|
||||
addrDesc = fmt.Sprintf("netstack %s", f.listenAddress)
|
||||
}
|
||||
log.Infof("starting DNS forwarder on address=%s", addrDesc)
|
||||
|
||||
udpLn, err := f.createUDPListener(netstackNet)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create UDP listener: %w", err)
|
||||
}
|
||||
|
||||
tcpLn, err := f.createTCPListener(netstackNet)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create TCP listener: %w", err)
|
||||
}
|
||||
|
||||
// UDP server
|
||||
mux := dns.NewServeMux()
|
||||
f.mux = mux
|
||||
mux.HandleFunc(".", f.handleDNSQueryUDP)
|
||||
f.dnsServer = &dns.Server{
|
||||
Addr: f.listenAddress,
|
||||
Net: "udp",
|
||||
Handler: mux,
|
||||
PacketConn: udpLn,
|
||||
Handler: mux,
|
||||
}
|
||||
|
||||
// TCP server
|
||||
tcpMux := dns.NewServeMux()
|
||||
f.tcpMux = tcpMux
|
||||
tcpMux.HandleFunc(".", f.handleDNSQueryTCP)
|
||||
f.tcpServer = &dns.Server{
|
||||
Addr: f.listenAddress,
|
||||
Net: "tcp",
|
||||
Handler: tcpMux,
|
||||
Listener: tcpLn,
|
||||
Handler: tcpMux,
|
||||
}
|
||||
|
||||
f.UpdateDomains(entries)
|
||||
@@ -87,26 +108,70 @@ func (f *DNSForwarder) Listen(entries []*ForwarderEntry) error {
|
||||
errCh := make(chan error, 2)
|
||||
|
||||
go func() {
|
||||
log.Infof("DNS UDP listener running on %s", f.listenAddress)
|
||||
errCh <- f.dnsServer.ListenAndServe()
|
||||
log.Infof("DNS UDP listener running on %s", addrDesc)
|
||||
errCh <- f.dnsServer.ActivateAndServe()
|
||||
}()
|
||||
go func() {
|
||||
log.Infof("DNS TCP listener running on %s", f.listenAddress)
|
||||
errCh <- f.tcpServer.ListenAndServe()
|
||||
log.Infof("DNS TCP listener running on %s", addrDesc)
|
||||
errCh <- f.tcpServer.ActivateAndServe()
|
||||
}()
|
||||
|
||||
// return the first error we get (e.g. bind failure or shutdown)
|
||||
return <-errCh
|
||||
}
|
||||
|
||||
func (f *DNSForwarder) createUDPListener(netstackNet *netstack.Net) (net.PacketConn, error) {
|
||||
if netstackNet != nil {
|
||||
return netstackNet.ListenUDPAddrPort(f.listenAddress)
|
||||
}
|
||||
|
||||
return net.ListenUDP("udp", net.UDPAddrFromAddrPort(f.listenAddress))
|
||||
}
|
||||
|
||||
func (f *DNSForwarder) createTCPListener(netstackNet *netstack.Net) (net.Listener, error) {
|
||||
if netstackNet != nil {
|
||||
return netstackNet.ListenTCPAddrPort(f.listenAddress)
|
||||
}
|
||||
|
||||
return net.ListenTCP("tcp", net.TCPAddrFromAddrPort(f.listenAddress))
|
||||
}
|
||||
|
||||
func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) {
|
||||
f.mutex.Lock()
|
||||
defer f.mutex.Unlock()
|
||||
|
||||
// remove cache entries for domains that no longer appear
|
||||
f.removeStaleCacheEntries(f.fwdEntries, entries)
|
||||
|
||||
f.fwdEntries = entries
|
||||
log.Debugf("Updated DNS forwarder with %d domains", len(entries))
|
||||
}
|
||||
|
||||
// removeStaleCacheEntries unsets cache items for domains that were present
|
||||
// in the old list but not present in the new list.
|
||||
func (f *DNSForwarder) removeStaleCacheEntries(oldEntries, newEntries []*ForwarderEntry) {
|
||||
if f.cache == nil {
|
||||
return
|
||||
}
|
||||
|
||||
newSet := make(map[string]struct{}, len(newEntries))
|
||||
for _, e := range newEntries {
|
||||
if e == nil {
|
||||
continue
|
||||
}
|
||||
newSet[e.Domain.PunycodeString()] = struct{}{}
|
||||
}
|
||||
|
||||
for _, e := range oldEntries {
|
||||
if e == nil {
|
||||
continue
|
||||
}
|
||||
pattern := e.Domain.PunycodeString()
|
||||
if _, ok := newSet[pattern]; !ok {
|
||||
f.cache.unset(pattern)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (f *DNSForwarder) Close(ctx context.Context) error {
|
||||
var result *multierror.Error
|
||||
|
||||
@@ -171,6 +236,7 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns
|
||||
|
||||
f.updateInternalState(ips, mostSpecificResId, matchingEntries)
|
||||
f.addIPsToResponse(resp, domain, ips)
|
||||
f.cache.set(domain, question.Qtype, ips)
|
||||
|
||||
return resp
|
||||
}
|
||||
@@ -282,29 +348,69 @@ func (f *DNSForwarder) setResponseCodeForNotFound(ctx context.Context, resp *dns
|
||||
resp.Rcode = dns.RcodeSuccess
|
||||
}
|
||||
|
||||
// handleDNSError processes DNS lookup errors and sends an appropriate error response
|
||||
func (f *DNSForwarder) handleDNSError(ctx context.Context, w dns.ResponseWriter, question dns.Question, resp *dns.Msg, domain string, err error) {
|
||||
// handleDNSError processes DNS lookup errors and sends an appropriate error response.
|
||||
func (f *DNSForwarder) handleDNSError(
|
||||
ctx context.Context,
|
||||
w dns.ResponseWriter,
|
||||
question dns.Question,
|
||||
resp *dns.Msg,
|
||||
domain string,
|
||||
err error,
|
||||
) {
|
||||
// Default to SERVFAIL; override below when appropriate.
|
||||
resp.Rcode = dns.RcodeServerFailure
|
||||
|
||||
qType := question.Qtype
|
||||
qTypeName := dns.TypeToString[qType]
|
||||
|
||||
// Prefer typed DNS errors; fall back to generic logging otherwise.
|
||||
var dnsErr *net.DNSError
|
||||
|
||||
switch {
|
||||
case errors.As(err, &dnsErr):
|
||||
resp.Rcode = dns.RcodeServerFailure
|
||||
if dnsErr.IsNotFound {
|
||||
f.setResponseCodeForNotFound(ctx, resp, domain, question.Qtype)
|
||||
if !errors.As(err, &dnsErr) {
|
||||
log.Warnf(errResolveFailed, domain, err)
|
||||
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
||||
log.Errorf("failed to write failure DNS response: %v", writeErr)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if dnsErr.Server != "" {
|
||||
log.Warnf("failed to resolve query for type=%s domain=%s server=%s: %v", dns.TypeToString[question.Qtype], domain, dnsErr.Server, err)
|
||||
} else {
|
||||
log.Warnf(errResolveFailed, domain, err)
|
||||
// NotFound: set NXDOMAIN / appropriate code via helper.
|
||||
if dnsErr.IsNotFound {
|
||||
f.setResponseCodeForNotFound(ctx, resp, domain, qType)
|
||||
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
||||
log.Errorf("failed to write failure DNS response: %v", writeErr)
|
||||
}
|
||||
default:
|
||||
resp.Rcode = dns.RcodeServerFailure
|
||||
f.cache.set(domain, question.Qtype, nil)
|
||||
return
|
||||
}
|
||||
|
||||
// Upstream failed but we might have a cached answer—serve it if present.
|
||||
if ips, ok := f.cache.get(domain, qType); ok {
|
||||
if len(ips) > 0 {
|
||||
log.Debugf("serving cached DNS response after upstream failure: domain=%s type=%s", domain, qTypeName)
|
||||
f.addIPsToResponse(resp, domain, ips)
|
||||
resp.Rcode = dns.RcodeSuccess
|
||||
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
||||
log.Errorf("failed to write cached DNS response: %v", writeErr)
|
||||
}
|
||||
} else { // send NXDOMAIN / appropriate code if cache is empty
|
||||
f.setResponseCodeForNotFound(ctx, resp, domain, qType)
|
||||
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
||||
log.Errorf("failed to write failure DNS response: %v", writeErr)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// No cache. Log with or without the server field for more context.
|
||||
if dnsErr.Server != "" {
|
||||
log.Warnf("failed to resolve: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, err)
|
||||
} else {
|
||||
log.Warnf(errResolveFailed, domain, err)
|
||||
}
|
||||
|
||||
if err := w.WriteMsg(resp); err != nil {
|
||||
log.Errorf("failed to write failure DNS response: %v", err)
|
||||
// Write final failure response.
|
||||
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
||||
log.Errorf("failed to write failure DNS response: %v", writeErr)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -297,7 +297,7 @@ func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) {
|
||||
mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn(tt.queryDomain)).Return([]netip.Addr{fakeIP}, nil)
|
||||
}
|
||||
|
||||
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
|
||||
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, mockFirewall, &peer.Status{}, nil)
|
||||
forwarder.resolver = mockResolver
|
||||
|
||||
d, err := domain.FromString(tt.configuredDomain)
|
||||
@@ -402,7 +402,7 @@ func TestDNSForwarder_FirewallSetUpdates(t *testing.T) {
|
||||
mockResolver := &MockResolver{}
|
||||
|
||||
// Set up forwarder
|
||||
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
|
||||
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, mockFirewall, &peer.Status{}, nil)
|
||||
forwarder.resolver = mockResolver
|
||||
|
||||
// Create entries and track sets
|
||||
@@ -489,7 +489,7 @@ func TestDNSForwarder_MultipleIPsInSingleUpdate(t *testing.T) {
|
||||
mockFirewall := &MockFirewall{}
|
||||
mockResolver := &MockResolver{}
|
||||
|
||||
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
|
||||
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, mockFirewall, &peer.Status{}, nil)
|
||||
forwarder.resolver = mockResolver
|
||||
|
||||
// Configure a single domain
|
||||
@@ -584,7 +584,7 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
|
||||
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil)
|
||||
|
||||
d, err := domain.FromString(tt.configured)
|
||||
require.NoError(t, err)
|
||||
@@ -616,7 +616,7 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
|
||||
func TestDNSForwarder_TCPTruncation(t *testing.T) {
|
||||
// Test that large UDP responses are truncated with TC bit set
|
||||
mockResolver := &MockResolver{}
|
||||
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
|
||||
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil)
|
||||
forwarder.resolver = mockResolver
|
||||
|
||||
d, _ := domain.FromString("example.com")
|
||||
@@ -648,12 +648,101 @@ func TestDNSForwarder_TCPTruncation(t *testing.T) {
|
||||
assert.LessOrEqual(t, writtenResp.Len(), dns.MinMsgSize, "Response should fit in minimum UDP size")
|
||||
}
|
||||
|
||||
// Ensures that when the first query succeeds and populates the cache,
|
||||
// a subsequent upstream failure still returns a successful response from cache.
|
||||
func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
|
||||
mockResolver := &MockResolver{}
|
||||
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil)
|
||||
forwarder.resolver = mockResolver
|
||||
|
||||
d, err := domain.FromString("example.com")
|
||||
require.NoError(t, err)
|
||||
entries := []*ForwarderEntry{{Domain: d, ResID: "res-cache"}}
|
||||
forwarder.UpdateDomains(entries)
|
||||
|
||||
ip := netip.MustParseAddr("1.2.3.4")
|
||||
|
||||
// First call resolves successfully and populates cache
|
||||
mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn("example.com")).
|
||||
Return([]netip.Addr{ip}, nil).Once()
|
||||
|
||||
// Second call fails upstream; forwarder should serve from cache
|
||||
mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn("example.com")).
|
||||
Return([]netip.Addr{}, &net.DNSError{Err: "temporary failure"}).Once()
|
||||
|
||||
// First query: populate cache
|
||||
q1 := &dns.Msg{}
|
||||
q1.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
|
||||
w1 := &test.MockResponseWriter{}
|
||||
resp1 := forwarder.handleDNSQuery(w1, q1)
|
||||
require.NotNil(t, resp1)
|
||||
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
|
||||
require.Len(t, resp1.Answer, 1)
|
||||
|
||||
// Second query: serve from cache after upstream failure
|
||||
q2 := &dns.Msg{}
|
||||
q2.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
|
||||
var writtenResp *dns.Msg
|
||||
w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
|
||||
_ = forwarder.handleDNSQuery(w2, q2)
|
||||
|
||||
require.NotNil(t, writtenResp, "expected response to be written")
|
||||
require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
|
||||
require.Len(t, writtenResp.Answer, 1)
|
||||
|
||||
mockResolver.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// Verifies that cache normalization works across casing and trailing dot variations.
|
||||
func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) {
|
||||
mockResolver := &MockResolver{}
|
||||
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil)
|
||||
forwarder.resolver = mockResolver
|
||||
|
||||
d, err := domain.FromString("ExAmPlE.CoM")
|
||||
require.NoError(t, err)
|
||||
entries := []*ForwarderEntry{{Domain: d, ResID: "res-norm"}}
|
||||
forwarder.UpdateDomains(entries)
|
||||
|
||||
ip := netip.MustParseAddr("9.8.7.6")
|
||||
|
||||
// Initial resolution with mixed case to populate cache
|
||||
mixedQuery := "ExAmPlE.CoM"
|
||||
mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn(strings.ToLower(mixedQuery))).
|
||||
Return([]netip.Addr{ip}, nil).Once()
|
||||
|
||||
q1 := &dns.Msg{}
|
||||
q1.SetQuestion(mixedQuery+".", dns.TypeA)
|
||||
w1 := &test.MockResponseWriter{}
|
||||
resp1 := forwarder.handleDNSQuery(w1, q1)
|
||||
require.NotNil(t, resp1)
|
||||
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
|
||||
require.Len(t, resp1.Answer, 1)
|
||||
|
||||
// Subsequent query without dot and upper case should hit cache even if upstream fails
|
||||
// Forwarder lowercases and uses the question name as-is (no trailing dot here)
|
||||
mockResolver.On("LookupNetIP", mock.Anything, "ip4", strings.ToLower("EXAMPLE.COM")).
|
||||
Return([]netip.Addr{}, &net.DNSError{Err: "temporary failure"}).Once()
|
||||
|
||||
q2 := &dns.Msg{}
|
||||
q2.SetQuestion("EXAMPLE.COM", dns.TypeA)
|
||||
var writtenResp *dns.Msg
|
||||
w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
|
||||
_ = forwarder.handleDNSQuery(w2, q2)
|
||||
|
||||
require.NotNil(t, writtenResp)
|
||||
require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
|
||||
require.Len(t, writtenResp.Answer, 1)
|
||||
|
||||
mockResolver.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) {
|
||||
// Test complex overlapping pattern scenarios
|
||||
mockFirewall := &MockFirewall{}
|
||||
mockResolver := &MockResolver{}
|
||||
|
||||
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
|
||||
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, mockFirewall, &peer.Status{}, nil)
|
||||
forwarder.resolver = mockResolver
|
||||
|
||||
// Set up complex overlapping patterns
|
||||
@@ -715,7 +804,7 @@ func TestDNSForwarder_NodataVsNxdomain(t *testing.T) {
|
||||
mockFirewall := &MockFirewall{}
|
||||
mockResolver := &MockResolver{}
|
||||
|
||||
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
|
||||
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, mockFirewall, &peer.Status{}, nil)
|
||||
forwarder.resolver = mockResolver
|
||||
|
||||
d, err := domain.FromString("example.com")
|
||||
@@ -836,7 +925,7 @@ func TestDNSForwarder_NodataVsNxdomain(t *testing.T) {
|
||||
|
||||
func TestDNSForwarder_EmptyQuery(t *testing.T) {
|
||||
// Test handling of malformed query with no questions
|
||||
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
|
||||
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil)
|
||||
|
||||
query := &dns.Msg{}
|
||||
// Don't set any question
|
||||
|
||||
@@ -4,27 +4,34 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
var (
|
||||
// ListenPort is the port that the DNS forwarder listens on. It has been used by the client peers also
|
||||
listenPort uint16 = 5353
|
||||
listenPortMu sync.RWMutex
|
||||
const (
|
||||
dnsTTL = 60
|
||||
envServerPort = "NB_DNS_FORWARDER_PORT"
|
||||
)
|
||||
|
||||
const (
|
||||
dnsTTL = 60 //seconds
|
||||
)
|
||||
// wgIface defines the interface for WireGuard interface operations needed by the DNS forwarder.
|
||||
type wgIface interface {
|
||||
GetNet() *netstack.Net
|
||||
Address() wgaddr.Address
|
||||
}
|
||||
|
||||
// ForwarderEntry is a mapping from a domain to a resource ID and a hash of the parent domain list.
|
||||
type ForwarderEntry struct {
|
||||
@@ -36,24 +43,30 @@ type ForwarderEntry struct {
|
||||
type Manager struct {
|
||||
firewall firewall.Manager
|
||||
statusRecorder *peer.Status
|
||||
wgIface wgIface
|
||||
serverPort uint16
|
||||
|
||||
fwRules []firewall.Rule
|
||||
tcpRules []firewall.Rule
|
||||
dnsForwarder *DNSForwarder
|
||||
port uint16
|
||||
}
|
||||
|
||||
func ListenPort() uint16 {
|
||||
listenPortMu.RLock()
|
||||
defer listenPortMu.RUnlock()
|
||||
return listenPort
|
||||
}
|
||||
func NewManager(fw firewall.Manager, statusRecorder *peer.Status, wgIface wgIface) *Manager {
|
||||
serverPort := nbdns.ForwarderServerPort
|
||||
if envPort := os.Getenv(envServerPort); envPort != "" {
|
||||
if port, err := strconv.ParseUint(envPort, 10, 16); err == nil && port > 0 {
|
||||
serverPort = uint16(port)
|
||||
log.Infof("using custom DNS forwarder port from %s: %d", envServerPort, serverPort)
|
||||
} else {
|
||||
log.Warnf("invalid %s value %q, using default %d", envServerPort, envPort, nbdns.ForwarderServerPort)
|
||||
}
|
||||
}
|
||||
|
||||
func NewManager(fw firewall.Manager, statusRecorder *peer.Status, port uint16) *Manager {
|
||||
return &Manager{
|
||||
firewall: fw,
|
||||
statusRecorder: statusRecorder,
|
||||
port: port,
|
||||
wgIface: wgIface,
|
||||
serverPort: serverPort,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -67,13 +80,25 @@ func (m *Manager) Start(fwdEntries []*ForwarderEntry) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if m.port > 0 {
|
||||
listenPortMu.Lock()
|
||||
listenPort = m.port
|
||||
listenPortMu.Unlock()
|
||||
localAddr := m.wgIface.Address().IP
|
||||
|
||||
if localAddr.IsValid() && m.firewall != nil {
|
||||
if err := m.firewall.AddInboundDNAT(localAddr, firewall.ProtocolUDP, nbdns.ForwarderClientPort, m.serverPort); err != nil {
|
||||
log.Warnf("failed to add DNS UDP DNAT rule: %v", err)
|
||||
} else {
|
||||
log.Infof("added DNS UDP DNAT rule: %s:%d -> %s:%d", localAddr, nbdns.ForwarderClientPort, localAddr, m.serverPort)
|
||||
}
|
||||
|
||||
if err := m.firewall.AddInboundDNAT(localAddr, firewall.ProtocolTCP, nbdns.ForwarderClientPort, m.serverPort); err != nil {
|
||||
log.Warnf("failed to add DNS TCP DNAT rule: %v", err)
|
||||
} else {
|
||||
log.Infof("added DNS TCP DNAT rule: %s:%d -> %s:%d", localAddr, nbdns.ForwarderClientPort, localAddr, m.serverPort)
|
||||
}
|
||||
}
|
||||
|
||||
m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort()), dnsTTL, m.firewall, m.statusRecorder)
|
||||
listenAddress := netip.AddrPortFrom(localAddr, m.serverPort)
|
||||
m.dnsForwarder = NewDNSForwarder(listenAddress, dnsTTL, m.firewall, m.statusRecorder, m.wgIface)
|
||||
|
||||
go func() {
|
||||
if err := m.dnsForwarder.Listen(fwdEntries); err != nil {
|
||||
// todo handle close error if it is exists
|
||||
@@ -98,6 +123,20 @@ func (m *Manager) Stop(ctx context.Context) error {
|
||||
}
|
||||
|
||||
var mErr *multierror.Error
|
||||
|
||||
localAddr := m.wgIface.Address().IP
|
||||
if localAddr.IsValid() && m.firewall != nil {
|
||||
if err := m.firewall.RemoveInboundDNAT(localAddr, firewall.ProtocolUDP, nbdns.ForwarderClientPort, m.serverPort); err != nil {
|
||||
mErr = multierror.Append(mErr, fmt.Errorf("remove DNS UDP DNAT rule: %w", err))
|
||||
}
|
||||
|
||||
if err := m.firewall.RemoveInboundDNAT(localAddr, firewall.ProtocolTCP, nbdns.ForwarderClientPort, m.serverPort); err != nil {
|
||||
mErr = multierror.Append(mErr, fmt.Errorf("remove DNS TCP DNAT rule: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
m.unregisterNetstackServices()
|
||||
|
||||
if err := m.dropDNSFirewall(); err != nil {
|
||||
mErr = multierror.Append(mErr, err)
|
||||
}
|
||||
@@ -113,7 +152,7 @@ func (m *Manager) Stop(ctx context.Context) error {
|
||||
func (m *Manager) allowDNSFirewall() error {
|
||||
dport := &firewall.Port{
|
||||
IsRange: false,
|
||||
Values: []uint16{ListenPort()},
|
||||
Values: []uint16{m.serverPort},
|
||||
}
|
||||
|
||||
if m.firewall == nil {
|
||||
@@ -122,21 +161,50 @@ func (m *Manager) allowDNSFirewall() error {
|
||||
|
||||
dnsRules, err := m.firewall.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolUDP, nil, dport, firewall.ActionAccept, "")
|
||||
if err != nil {
|
||||
log.Errorf("failed to add allow DNS router rules, err: %v", err)
|
||||
return err
|
||||
return fmt.Errorf("add udp firewall rule: %w", err)
|
||||
}
|
||||
m.fwRules = dnsRules
|
||||
|
||||
tcpRules, err := m.firewall.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolTCP, nil, dport, firewall.ActionAccept, "")
|
||||
if err != nil {
|
||||
log.Errorf("failed to add allow DNS router rules, err: %v", err)
|
||||
return err
|
||||
return fmt.Errorf("add tcp firewall rule: %w", err)
|
||||
}
|
||||
|
||||
if err := m.firewall.Flush(); err != nil {
|
||||
return fmt.Errorf("flush: %w", err)
|
||||
}
|
||||
|
||||
m.fwRules = dnsRules
|
||||
m.tcpRules = tcpRules
|
||||
|
||||
m.registerNetstackServices()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) registerNetstackServices() {
|
||||
if netstackNet := m.wgIface.GetNet(); netstackNet != nil {
|
||||
if registrar, ok := m.firewall.(interface {
|
||||
RegisterNetstackService(protocol nftypes.Protocol, port uint16)
|
||||
}); ok {
|
||||
registrar.RegisterNetstackService(nftypes.TCP, m.serverPort)
|
||||
registrar.RegisterNetstackService(nftypes.UDP, m.serverPort)
|
||||
log.Debugf("registered DNS forwarder service with netstack for UDP/TCP:%d", m.serverPort)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) unregisterNetstackServices() {
|
||||
if netstackNet := m.wgIface.GetNet(); netstackNet != nil {
|
||||
if registrar, ok := m.firewall.(interface {
|
||||
UnregisterNetstackService(protocol nftypes.Protocol, port uint16)
|
||||
}); ok {
|
||||
registrar.UnregisterNetstackService(nftypes.TCP, m.serverPort)
|
||||
registrar.UnregisterNetstackService(nftypes.UDP, m.serverPort)
|
||||
log.Debugf("unregistered DNS forwarder service with netstack for UDP/TCP:%d", m.serverPort)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) dropDNSFirewall() error {
|
||||
var mErr *multierror.Error
|
||||
for _, rule := range m.fwRules {
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"os"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"slices"
|
||||
"sort"
|
||||
@@ -30,9 +29,9 @@ import (
|
||||
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||
"github.com/netbirdio/netbird/client/internal/acl"
|
||||
"github.com/netbirdio/netbird/client/internal/debug"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
|
||||
"github.com/netbirdio/netbird/client/internal/dnsfwd"
|
||||
@@ -50,11 +49,12 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
"github.com/netbirdio/netbird/client/jobexec"
|
||||
cProto "github.com/netbirdio/netbird/client/proto"
|
||||
sshconfig "github.com/netbirdio/netbird/client/ssh/config"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
|
||||
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
@@ -115,7 +115,12 @@ type EngineConfig struct {
|
||||
RosenpassEnabled bool
|
||||
RosenpassPermissive bool
|
||||
|
||||
ServerSSHAllowed bool
|
||||
ServerSSHAllowed bool
|
||||
EnableSSHRoot *bool
|
||||
EnableSSHSFTP *bool
|
||||
EnableSSHLocalPortForwarding *bool
|
||||
EnableSSHRemotePortForwarding *bool
|
||||
DisableSSHAuth *bool
|
||||
|
||||
DNSRouteInterval time.Duration
|
||||
|
||||
@@ -129,6 +134,11 @@ type EngineConfig struct {
|
||||
LazyConnectionEnabled bool
|
||||
|
||||
MTU uint16
|
||||
|
||||
// for debug bundle generation
|
||||
ProfileConfig *profilemanager.Config
|
||||
|
||||
LogPath string
|
||||
}
|
||||
|
||||
// Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers.
|
||||
@@ -173,8 +183,7 @@ type Engine struct {
|
||||
|
||||
networkMonitor *networkmonitor.NetworkMonitor
|
||||
|
||||
sshServerFunc func(hostKeyPEM []byte, addr string) (nbssh.Server, error)
|
||||
sshServer nbssh.Server
|
||||
sshServer sshServer
|
||||
|
||||
statusRecorder *peer.Status
|
||||
|
||||
@@ -193,18 +202,23 @@ type Engine struct {
|
||||
stateManager *statemanager.Manager
|
||||
srWatcher *guard.SRWatcher
|
||||
|
||||
// Sync response persistence
|
||||
// Sync response persistence (protected by syncRespMux)
|
||||
syncRespMux sync.RWMutex
|
||||
persistSyncResponse bool
|
||||
latestSyncResponse *mgmProto.SyncResponse
|
||||
connSemaphore *semaphoregroup.SemaphoreGroup
|
||||
flowManager nftypes.FlowManager
|
||||
|
||||
// WireGuard interface monitor
|
||||
wgIfaceMonitor *WGIfaceMonitor
|
||||
wgIfaceMonitorWg sync.WaitGroup
|
||||
wgIfaceMonitor *WGIfaceMonitor
|
||||
|
||||
// dns forwarder port
|
||||
dnsFwdPort uint16
|
||||
// shutdownWg tracks all long-running goroutines to ensure clean shutdown
|
||||
shutdownWg sync.WaitGroup
|
||||
|
||||
probeStunTurn *relay.StunTurnProbe
|
||||
|
||||
jobExecutor *jobexec.Executor
|
||||
jobExecutorWG sync.WaitGroup
|
||||
}
|
||||
|
||||
// Peer is an instance of the Connection Peer
|
||||
@@ -218,17 +232,7 @@ type localIpUpdater interface {
|
||||
}
|
||||
|
||||
// NewEngine creates a new Connection Engine with probes attached
|
||||
func NewEngine(
|
||||
clientCtx context.Context,
|
||||
clientCancel context.CancelFunc,
|
||||
signalClient signal.Client,
|
||||
mgmClient mgm.Client,
|
||||
relayManager *relayClient.Manager,
|
||||
config *EngineConfig,
|
||||
mobileDep MobileDependency,
|
||||
statusRecorder *peer.Status,
|
||||
checks []*mgmProto.Checks,
|
||||
) *Engine {
|
||||
func NewEngine(clientCtx context.Context, clientCancel context.CancelFunc, signalClient signal.Client, mgmClient mgm.Client, relayManager *relayClient.Manager, config *EngineConfig, mobileDep MobileDependency, statusRecorder *peer.Status, checks []*mgmProto.Checks, c *profilemanager.Config) *Engine {
|
||||
engine := &Engine{
|
||||
clientCtx: clientCtx,
|
||||
clientCancel: clientCancel,
|
||||
@@ -243,11 +247,11 @@ func NewEngine(
|
||||
STUNs: []*stun.URI{},
|
||||
TURNs: []*stun.URI{},
|
||||
networkSerial: 0,
|
||||
sshServerFunc: nbssh.DefaultSSHServer,
|
||||
statusRecorder: statusRecorder,
|
||||
checks: checks,
|
||||
connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit),
|
||||
dnsFwdPort: dnsfwd.ListenPort(),
|
||||
probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL),
|
||||
jobExecutor: jobexec.NewExecutor(),
|
||||
}
|
||||
|
||||
sm := profilemanager.NewServiceManager("")
|
||||
@@ -265,6 +269,7 @@ func NewEngine(
|
||||
path = mobileDep.StateFilePath
|
||||
}
|
||||
engine.stateManager = statemanager.New(path)
|
||||
engine.stateManager.RegisterState(&sshconfig.ShutdownState{})
|
||||
|
||||
log.Infof("I am: %s", config.WgPrivateKey.PublicKey().String())
|
||||
return engine
|
||||
@@ -289,6 +294,12 @@ func (e *Engine) Stop() error {
|
||||
}
|
||||
log.Info("Network monitor: stopped")
|
||||
|
||||
if err := e.stopSSHServer(); err != nil {
|
||||
log.Warnf("failed to stop SSH server: %v", err)
|
||||
}
|
||||
|
||||
e.cleanupSSHConfig()
|
||||
|
||||
// stop/restore DNS first so dbus and friends don't complain because of a missing interface
|
||||
e.stopDNSServer()
|
||||
|
||||
@@ -299,17 +310,12 @@ func (e *Engine) Stop() error {
|
||||
e.ingressGatewayMgr = nil
|
||||
}
|
||||
|
||||
e.stopDNSForwarder()
|
||||
|
||||
if e.routeManager != nil {
|
||||
e.routeManager.Stop(e.stateManager)
|
||||
}
|
||||
|
||||
if e.dnsForwardMgr != nil {
|
||||
if err := e.dnsForwardMgr.Stop(context.Background()); err != nil {
|
||||
log.Errorf("failed to stop DNS forward: %v", err)
|
||||
}
|
||||
e.dnsForwardMgr = nil
|
||||
}
|
||||
|
||||
if e.srWatcher != nil {
|
||||
e.srWatcher.Close()
|
||||
}
|
||||
@@ -326,9 +332,7 @@ func (e *Engine) Stop() error {
|
||||
e.cancel()
|
||||
}
|
||||
|
||||
// very ugly but we want to remove peers from the WireGuard interface first before removing interface.
|
||||
// Removing peers happens in the conn.Close() asynchronously
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
e.jobExecutorWG.Wait() // block until job goroutines finish
|
||||
|
||||
e.close()
|
||||
|
||||
@@ -337,8 +341,6 @@ func (e *Engine) Stop() error {
|
||||
e.flowManager.Close()
|
||||
}
|
||||
|
||||
log.Infof("stopped Netbird Engine")
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
@@ -349,12 +351,52 @@ func (e *Engine) Stop() error {
|
||||
log.Errorf("failed to persist state: %v", err)
|
||||
}
|
||||
|
||||
// Stop WireGuard interface monitor and wait for it to exit
|
||||
e.wgIfaceMonitorWg.Wait()
|
||||
timeout := e.calculateShutdownTimeout()
|
||||
log.Debugf("waiting for goroutines to finish with timeout: %v", timeout)
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
|
||||
if err := waitWithContext(shutdownCtx, &e.shutdownWg); err != nil {
|
||||
log.Warnf("shutdown timeout exceeded after %v, some goroutines may still be running", timeout)
|
||||
}
|
||||
|
||||
log.Infof("stopped Netbird Engine")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// calculateShutdownTimeout returns shutdown timeout: 10s base + 100ms per peer, capped at 30s.
|
||||
func (e *Engine) calculateShutdownTimeout() time.Duration {
|
||||
peerCount := len(e.peerStore.PeersPubKey())
|
||||
|
||||
baseTimeout := 10 * time.Second
|
||||
perPeerTimeout := time.Duration(peerCount) * 100 * time.Millisecond
|
||||
timeout := baseTimeout + perPeerTimeout
|
||||
|
||||
maxTimeout := 30 * time.Second
|
||||
if timeout > maxTimeout {
|
||||
timeout = maxTimeout
|
||||
}
|
||||
|
||||
return timeout
|
||||
}
|
||||
|
||||
// waitWithContext waits for WaitGroup with timeout, returns ctx.Err() on timeout.
|
||||
func waitWithContext(ctx context.Context, wg *sync.WaitGroup) error {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services
|
||||
// Connections to remote peers are not established here.
|
||||
// However, they will be established once an event with a list of peers to connect to will be received from Management Service
|
||||
@@ -478,20 +520,21 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
|
||||
e.receiveSignalEvents()
|
||||
e.receiveManagementEvents()
|
||||
e.receiveJobEvents()
|
||||
|
||||
// starting network monitor at the very last to avoid disruptions
|
||||
e.startNetworkMonitor()
|
||||
|
||||
// monitor WireGuard interface lifecycle and restart engine on changes
|
||||
e.wgIfaceMonitor = NewWGIfaceMonitor()
|
||||
e.wgIfaceMonitorWg.Add(1)
|
||||
e.shutdownWg.Add(1)
|
||||
|
||||
go func() {
|
||||
defer e.wgIfaceMonitorWg.Done()
|
||||
defer e.shutdownWg.Done()
|
||||
|
||||
if shouldRestart, err := e.wgIfaceMonitor.Start(e.ctx, e.wgInterface.Name()); shouldRestart {
|
||||
log.Infof("WireGuard interface monitor: %s, restarting engine", err)
|
||||
e.restartEngine()
|
||||
e.triggerClientRestart()
|
||||
} else if err != nil {
|
||||
log.Warnf("WireGuard interface monitor: %s", err)
|
||||
}
|
||||
@@ -507,7 +550,7 @@ func (e *Engine) createFirewall() error {
|
||||
}
|
||||
|
||||
var err error
|
||||
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes)
|
||||
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes, e.config.MTU)
|
||||
if err != nil || e.firewall == nil {
|
||||
log.Errorf("failed creating firewall manager: %s", err)
|
||||
return nil
|
||||
@@ -671,14 +714,10 @@ func (e *Engine) removeAllPeers() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// removePeer closes an existing peer connection, removes a peer, and clears authorized key of the SSH server
|
||||
// removePeer closes an existing peer connection and removes a peer
|
||||
func (e *Engine) removePeer(peerKey string) error {
|
||||
log.Debugf("removing peer from engine %s", peerKey)
|
||||
|
||||
if !isNil(e.sshServer) {
|
||||
e.sshServer.RemoveAuthorizedKey(peerKey)
|
||||
}
|
||||
|
||||
e.connMgr.RemovePeerConn(peerKey)
|
||||
|
||||
err := e.statusRecorder.RemovePeer(peerKey)
|
||||
@@ -759,9 +798,18 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Persist sync response under the dedicated lock (syncRespMux), not under syncMsgMux.
|
||||
// Read the storage-enabled flag under the syncRespMux too.
|
||||
e.syncRespMux.RLock()
|
||||
enabled := e.persistSyncResponse
|
||||
e.syncRespMux.RUnlock()
|
||||
|
||||
// Store sync response if persistence is enabled
|
||||
if e.persistSyncResponse {
|
||||
if enabled {
|
||||
e.syncRespMux.Lock()
|
||||
e.latestSyncResponse = update
|
||||
e.syncRespMux.Unlock()
|
||||
|
||||
log.Debugf("sync response persisted with serial %d", nm.GetSerial())
|
||||
}
|
||||
|
||||
@@ -850,6 +898,11 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
|
||||
e.config.BlockLANAccess,
|
||||
e.config.BlockInbound,
|
||||
e.config.LazyConnectionEnabled,
|
||||
e.config.EnableSSHRoot,
|
||||
e.config.EnableSSHSFTP,
|
||||
e.config.EnableSSHLocalPortForwarding,
|
||||
e.config.EnableSSHRemotePortForwarding,
|
||||
e.config.DisableSSHAuth,
|
||||
)
|
||||
|
||||
if err := e.mgmClient.SyncMeta(info); err != nil {
|
||||
@@ -859,65 +912,6 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func isNil(server nbssh.Server) bool {
|
||||
return server == nil || reflect.ValueOf(server).IsNil()
|
||||
}
|
||||
|
||||
func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
|
||||
if e.config.BlockInbound {
|
||||
log.Infof("SSH server is disabled because inbound connections are blocked")
|
||||
return nil
|
||||
}
|
||||
|
||||
if !e.config.ServerSSHAllowed {
|
||||
log.Info("SSH server is not enabled")
|
||||
return nil
|
||||
}
|
||||
|
||||
if sshConf.GetSshEnabled() {
|
||||
if runtime.GOOS == "windows" {
|
||||
log.Warnf("running SSH server on %s is not supported", runtime.GOOS)
|
||||
return nil
|
||||
}
|
||||
// start SSH server if it wasn't running
|
||||
if isNil(e.sshServer) {
|
||||
listenAddr := fmt.Sprintf("%s:%d", e.wgInterface.Address().IP.String(), nbssh.DefaultSSHPort)
|
||||
if nbnetstack.IsEnabled() {
|
||||
listenAddr = fmt.Sprintf("127.0.0.1:%d", nbssh.DefaultSSHPort)
|
||||
}
|
||||
// nil sshServer means it has not yet been started
|
||||
var err error
|
||||
e.sshServer, err = e.sshServerFunc(e.config.SSHKey, listenAddr)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("create ssh server: %w", err)
|
||||
}
|
||||
go func() {
|
||||
// blocking
|
||||
err = e.sshServer.Start()
|
||||
if err != nil {
|
||||
// will throw error when we stop it even if it is a graceful stop
|
||||
log.Debugf("stopped SSH server with error %v", err)
|
||||
}
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
e.sshServer = nil
|
||||
log.Infof("stopped SSH server")
|
||||
}()
|
||||
} else {
|
||||
log.Debugf("SSH server is already running")
|
||||
}
|
||||
} else if !isNil(e.sshServer) {
|
||||
// Disable SSH server request, so stop it if it was running
|
||||
err := e.sshServer.Stop()
|
||||
if err != nil {
|
||||
log.Warnf("failed to stop SSH server %v", err)
|
||||
}
|
||||
e.sshServer = nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
||||
if e.wgInterface == nil {
|
||||
return errors.New("wireguard interface is not initialized")
|
||||
@@ -930,8 +924,7 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
||||
}
|
||||
|
||||
if conf.GetSshConfig() != nil {
|
||||
err := e.updateSSH(conf.GetSshConfig())
|
||||
if err != nil {
|
||||
if err := e.updateSSH(conf.GetSshConfig()); err != nil {
|
||||
log.Warnf("failed handling SSH server setup: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -946,11 +939,84 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
func (e *Engine) receiveJobEvents() {
|
||||
e.jobExecutorWG.Add(1)
|
||||
go func() {
|
||||
defer e.jobExecutorWG.Done()
|
||||
err := e.mgmClient.Job(e.ctx, func(msg *mgmProto.JobRequest) *mgmProto.JobResponse {
|
||||
resp := mgmProto.JobResponse{
|
||||
ID: msg.ID,
|
||||
Status: mgmProto.JobStatus_failed,
|
||||
}
|
||||
switch params := msg.WorkloadParameters.(type) {
|
||||
case *mgmProto.JobRequest_Bundle:
|
||||
bundleResult, err := e.handleBundle(params.Bundle)
|
||||
if err != nil {
|
||||
log.Errorf("handling bundle: %v", err)
|
||||
resp.Reason = []byte(err.Error())
|
||||
return &resp
|
||||
}
|
||||
resp.Status = mgmProto.JobStatus_succeeded
|
||||
resp.WorkloadResults = bundleResult
|
||||
return &resp
|
||||
default:
|
||||
resp.Reason = []byte(jobexec.ErrJobNotImplemented.Error())
|
||||
return &resp
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
// happens if management is unavailable for a long time.
|
||||
// We want to cancel the operation of the whole client
|
||||
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
|
||||
e.clientCancel()
|
||||
return
|
||||
}
|
||||
log.Info("stopped receiving jobs from Management Service")
|
||||
}()
|
||||
log.Info("connecting to Management Service jobs stream")
|
||||
}
|
||||
|
||||
func (e *Engine) handleBundle(params *mgmProto.BundleParameters) (*mgmProto.JobResponse_Bundle, error) {
|
||||
log.Infof("handle remote debug bundle request: %s", params.String())
|
||||
syncResponse, err := e.GetLatestSyncResponse()
|
||||
if err != nil {
|
||||
log.Warnf("get latest sync response: %v", err)
|
||||
}
|
||||
|
||||
bundleDeps := debug.GeneratorDependencies{
|
||||
InternalConfig: e.config.ProfileConfig,
|
||||
StatusRecorder: e.statusRecorder,
|
||||
SyncResponse: syncResponse,
|
||||
LogPath: e.config.LogPath,
|
||||
}
|
||||
|
||||
bundleJobParams := debug.BundleConfig{
|
||||
Anonymize: params.Anonymize,
|
||||
IncludeSystemInfo: true,
|
||||
LogFileCount: uint32(params.LogFileCount),
|
||||
}
|
||||
|
||||
waitFor := time.Duration(params.BundleForTime) * time.Minute
|
||||
|
||||
uploadKey, err := e.jobExecutor.BundleJob(e.ctx, bundleDeps, bundleJobParams, waitFor, e.config.ProfileConfig.ManagementURL.String())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
response := &mgmProto.JobResponse_Bundle{
|
||||
Bundle: &mgmProto.BundleResult{
|
||||
UploadKey: uploadKey,
|
||||
},
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// receiveManagementEvents connects to the Management Service event stream to receive updates from the management service
|
||||
// E.g. when a new peer has been registered and we are allowed to connect to it.
|
||||
func (e *Engine) receiveManagementEvents() {
|
||||
e.shutdownWg.Add(1)
|
||||
go func() {
|
||||
defer e.shutdownWg.Done()
|
||||
info, err := system.GetInfoWithChecks(e.ctx, e.checks)
|
||||
if err != nil {
|
||||
log.Warnf("failed to get system info with checks: %v", err)
|
||||
@@ -967,6 +1033,11 @@ func (e *Engine) receiveManagementEvents() {
|
||||
e.config.BlockLANAccess,
|
||||
e.config.BlockInbound,
|
||||
e.config.LazyConnectionEnabled,
|
||||
e.config.EnableSSHRoot,
|
||||
e.config.EnableSSHSFTP,
|
||||
e.config.EnableSSHLocalPortForwarding,
|
||||
e.config.EnableSSHRemotePortForwarding,
|
||||
e.config.DisableSSHAuth,
|
||||
)
|
||||
|
||||
err = e.mgmClient.Sync(e.ctx, info, e.handleSync)
|
||||
@@ -1060,10 +1131,14 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
protoDNSConfig = &mgmProto.DNSConfig{}
|
||||
}
|
||||
|
||||
if err := e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig, e.wgInterface.Address().Network)); err != nil {
|
||||
dnsConfig := toDNSConfig(protoDNSConfig, e.wgInterface.Address().Network)
|
||||
|
||||
if err := e.dnsServer.UpdateDNSServer(serial, dnsConfig); err != nil {
|
||||
log.Errorf("failed to update dns server, err: %v", err)
|
||||
}
|
||||
|
||||
e.routeManager.SetDNSForwarderPort(dnsConfig.ForwarderPort)
|
||||
|
||||
// apply routes first, route related actions might depend on routing being enabled
|
||||
routes := toRoutes(networkMap.GetRoutes())
|
||||
serverRoutes, clientRoutes := e.routeManager.ClassifyRoutes(routes)
|
||||
@@ -1084,7 +1159,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
}
|
||||
|
||||
fwdEntries := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), routes)
|
||||
e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries, uint16(protoDNSConfig.ForwarderPort))
|
||||
e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries)
|
||||
|
||||
// Ingress forward rules
|
||||
forwardingRules, err := e.updateForwardRules(networkMap.GetForwardingRules())
|
||||
@@ -1121,16 +1196,10 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
|
||||
e.statusRecorder.FinishPeerListModifications()
|
||||
|
||||
// update SSHServer by adding remote peer SSH keys
|
||||
if !isNil(e.sshServer) {
|
||||
for _, config := range networkMap.GetRemotePeers() {
|
||||
if config.GetSshConfig() != nil && config.GetSshConfig().GetSshPubKey() != nil {
|
||||
err := e.sshServer.AddAuthorizedKey(config.WgPubKey, string(config.GetSshConfig().GetSshPubKey()))
|
||||
if err != nil {
|
||||
log.Warnf("failed adding authorized key to SSH DefaultServer %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
e.updatePeerSSHHostKeys(networkMap.GetRemotePeers())
|
||||
|
||||
if err := e.updateSSHClientConfig(networkMap.GetRemotePeers()); err != nil {
|
||||
log.Warnf("failed to update SSH client config: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1208,10 +1277,16 @@ func toRouteDomains(myPubKey string, routes []*route.Route) []*dnsfwd.ForwarderE
|
||||
}
|
||||
|
||||
func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network netip.Prefix) nbdns.Config {
|
||||
forwarderPort := uint16(protoDNSConfig.GetForwarderPort()) //nolint
|
||||
if forwarderPort == 0 {
|
||||
forwarderPort = nbdns.ForwarderClientPort
|
||||
}
|
||||
|
||||
dnsUpdate := nbdns.Config{
|
||||
ServiceEnable: protoDNSConfig.GetServiceEnable(),
|
||||
CustomZones: make([]nbdns.CustomZone, 0),
|
||||
NameServerGroups: make([]*nbdns.NameServerGroup, 0),
|
||||
ForwarderPort: forwarderPort,
|
||||
}
|
||||
|
||||
for _, zone := range protoDNSConfig.GetCustomZones() {
|
||||
@@ -1368,7 +1443,9 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
|
||||
|
||||
// receiveSignalEvents connects to the Signal Service event stream to negotiate connection with remote peers
|
||||
func (e *Engine) receiveSignalEvents() {
|
||||
e.shutdownWg.Add(1)
|
||||
go func() {
|
||||
defer e.shutdownWg.Done()
|
||||
// connect to a stream of messages coming from the signal server
|
||||
err := e.signal.Receive(e.ctx, func(msg *sProto.Message) error {
|
||||
e.syncMsgMux.Lock()
|
||||
@@ -1485,13 +1562,6 @@ func (e *Engine) close() {
|
||||
e.statusRecorder.SetWgIface(nil)
|
||||
}
|
||||
|
||||
if !isNil(e.sshServer) {
|
||||
err := e.sshServer.Stop()
|
||||
if err != nil {
|
||||
log.Warnf("failed stopping the SSH server: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if e.firewall != nil {
|
||||
err := e.firewall.Close(e.stateManager)
|
||||
if err != nil {
|
||||
@@ -1522,6 +1592,11 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, err
|
||||
e.config.BlockLANAccess,
|
||||
e.config.BlockInbound,
|
||||
e.config.LazyConnectionEnabled,
|
||||
e.config.EnableSSHRoot,
|
||||
e.config.EnableSSHSFTP,
|
||||
e.config.EnableSSHLocalPortForwarding,
|
||||
e.config.EnableSSHRemotePortForwarding,
|
||||
e.config.DisableSSHAuth,
|
||||
)
|
||||
|
||||
netMap, err := e.mgmClient.GetNetworkMap(info)
|
||||
@@ -1667,7 +1742,7 @@ func (e *Engine) getRosenpassAddr() string {
|
||||
|
||||
// RunHealthProbes executes health checks for Signal, Management, Relay and WireGuard services
|
||||
// and updates the status recorder with the latest states.
|
||||
func (e *Engine) RunHealthProbes() bool {
|
||||
func (e *Engine) RunHealthProbes(waitForResult bool) bool {
|
||||
e.syncMsgMux.Lock()
|
||||
|
||||
signalHealthy := e.signal.IsHealthy()
|
||||
@@ -1699,8 +1774,12 @@ func (e *Engine) RunHealthProbes() bool {
|
||||
}
|
||||
|
||||
e.syncMsgMux.Unlock()
|
||||
|
||||
results := e.probeICE(stuns, turns)
|
||||
var results []relay.ProbeResult
|
||||
if waitForResult {
|
||||
results = e.probeStunTurn.ProbeAllWaitResult(e.ctx, stuns, turns)
|
||||
} else {
|
||||
results = e.probeStunTurn.ProbeAll(e.ctx, stuns, turns)
|
||||
}
|
||||
e.statusRecorder.UpdateRelayStates(results)
|
||||
|
||||
relayHealthy := true
|
||||
@@ -1717,15 +1796,10 @@ func (e *Engine) RunHealthProbes() bool {
|
||||
return allHealthy
|
||||
}
|
||||
|
||||
func (e *Engine) probeICE(stuns, turns []*stun.URI) []relay.ProbeResult {
|
||||
return append(
|
||||
relay.ProbeAll(e.ctx, relay.ProbeSTUN, stuns),
|
||||
relay.ProbeAll(e.ctx, relay.ProbeTURN, turns)...,
|
||||
)
|
||||
}
|
||||
|
||||
// restartEngine restarts the engine by cancelling the client context
|
||||
func (e *Engine) restartEngine() {
|
||||
// triggerClientRestart triggers a full client restart by cancelling the client context.
|
||||
// Note: This does NOT just restart the engine - it cancels the entire client context,
|
||||
// which causes the connect client's retry loop to create a completely new engine.
|
||||
func (e *Engine) triggerClientRestart() {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
|
||||
@@ -1747,7 +1821,9 @@ func (e *Engine) startNetworkMonitor() {
|
||||
}
|
||||
|
||||
e.networkMonitor = networkmonitor.New()
|
||||
e.shutdownWg.Add(1)
|
||||
go func() {
|
||||
defer e.shutdownWg.Done()
|
||||
if err := e.networkMonitor.Listen(e.ctx); err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
log.Infof("network monitor stopped")
|
||||
@@ -1757,8 +1833,8 @@ func (e *Engine) startNetworkMonitor() {
|
||||
return
|
||||
}
|
||||
|
||||
log.Infof("Network monitor: detected network change, restarting engine")
|
||||
e.restartEngine()
|
||||
log.Infof("Network monitor: detected network change, triggering client restart")
|
||||
e.triggerClientRestart()
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -1794,8 +1870,8 @@ func (e *Engine) stopDNSServer() {
|
||||
|
||||
// SetSyncResponsePersistence enables or disables sync response persistence
|
||||
func (e *Engine) SetSyncResponsePersistence(enabled bool) {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
e.syncRespMux.Lock()
|
||||
defer e.syncRespMux.Unlock()
|
||||
|
||||
if enabled == e.persistSyncResponse {
|
||||
return
|
||||
@@ -1810,20 +1886,22 @@ func (e *Engine) SetSyncResponsePersistence(enabled bool) {
|
||||
|
||||
// GetLatestSyncResponse returns the stored sync response if persistence is enabled
|
||||
func (e *Engine) GetLatestSyncResponse() (*mgmProto.SyncResponse, error) {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
e.syncRespMux.RLock()
|
||||
enabled := e.persistSyncResponse
|
||||
latest := e.latestSyncResponse
|
||||
e.syncRespMux.RUnlock()
|
||||
|
||||
if !e.persistSyncResponse {
|
||||
if !enabled {
|
||||
return nil, errors.New("sync response persistence is disabled")
|
||||
}
|
||||
|
||||
if e.latestSyncResponse == nil {
|
||||
if latest == nil {
|
||||
//nolint:nilnil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
log.Debugf("Retrieving latest sync response with size %d bytes", proto.Size(e.latestSyncResponse))
|
||||
sr, ok := proto.Clone(e.latestSyncResponse).(*mgmProto.SyncResponse)
|
||||
log.Debugf("Retrieving latest sync response with size %d bytes", proto.Size(latest))
|
||||
sr, ok := proto.Clone(latest).(*mgmProto.SyncResponse)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to clone sync response")
|
||||
}
|
||||
@@ -1843,60 +1921,50 @@ func (e *Engine) GetWgAddr() netip.Addr {
|
||||
func (e *Engine) updateDNSForwarder(
|
||||
enabled bool,
|
||||
fwdEntries []*dnsfwd.ForwarderEntry,
|
||||
forwarderPort uint16,
|
||||
) {
|
||||
if e.config.DisableServerRoutes {
|
||||
return
|
||||
}
|
||||
|
||||
if !enabled {
|
||||
if e.dnsForwardMgr == nil {
|
||||
return
|
||||
}
|
||||
if err := e.dnsForwardMgr.Stop(context.Background()); err != nil {
|
||||
log.Errorf("failed to stop DNS forward: %v", err)
|
||||
}
|
||||
e.stopDNSForwarder()
|
||||
return
|
||||
}
|
||||
|
||||
if len(fwdEntries) > 0 {
|
||||
switch {
|
||||
case e.dnsForwardMgr == nil:
|
||||
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, forwarderPort)
|
||||
if err := e.dnsForwardMgr.Start(fwdEntries); err != nil {
|
||||
log.Errorf("failed to start DNS forward: %v", err)
|
||||
e.dnsForwardMgr = nil
|
||||
}
|
||||
log.Infof("started domain router service with %d entries", len(fwdEntries))
|
||||
case e.dnsFwdPort != forwarderPort:
|
||||
log.Infof("updating domain router service port from %d to %d", e.dnsFwdPort, forwarderPort)
|
||||
e.restartDnsFwd(fwdEntries, forwarderPort)
|
||||
e.dnsFwdPort = forwarderPort
|
||||
|
||||
default:
|
||||
if e.dnsForwardMgr == nil {
|
||||
e.startDNSForwarder(fwdEntries)
|
||||
} else {
|
||||
e.dnsForwardMgr.UpdateDomains(fwdEntries)
|
||||
}
|
||||
} else if e.dnsForwardMgr != nil {
|
||||
log.Infof("disable domain router service")
|
||||
if err := e.dnsForwardMgr.Stop(context.Background()); err != nil {
|
||||
log.Errorf("failed to stop DNS forward: %v", err)
|
||||
}
|
||||
e.dnsForwardMgr = nil
|
||||
e.stopDNSForwarder()
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (e *Engine) restartDnsFwd(fwdEntries []*dnsfwd.ForwarderEntry, forwarderPort uint16) {
|
||||
log.Infof("updating domain router service port from %d to %d", e.dnsFwdPort, forwarderPort)
|
||||
// stop and start the forwarder to apply the new port
|
||||
if err := e.dnsForwardMgr.Stop(context.Background()); err != nil {
|
||||
log.Errorf("failed to stop DNS forward: %v", err)
|
||||
}
|
||||
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, forwarderPort)
|
||||
func (e *Engine) startDNSForwarder(fwdEntries []*dnsfwd.ForwarderEntry) {
|
||||
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, e.wgInterface)
|
||||
|
||||
if err := e.dnsForwardMgr.Start(fwdEntries); err != nil {
|
||||
log.Errorf("failed to start DNS forward: %v", err)
|
||||
e.dnsForwardMgr = nil
|
||||
return
|
||||
}
|
||||
|
||||
log.Infof("started domain router service with %d entries", len(fwdEntries))
|
||||
}
|
||||
|
||||
func (e *Engine) stopDNSForwarder() {
|
||||
if e.dnsForwardMgr == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err := e.dnsForwardMgr.Stop(context.Background()); err != nil {
|
||||
log.Errorf("failed to stop DNS forward: %v", err)
|
||||
}
|
||||
|
||||
e.dnsForwardMgr = nil
|
||||
}
|
||||
|
||||
func (e *Engine) GetNet() (*netstack.Net, error) {
|
||||
|
||||
355
client/internal/engine_ssh.go
Normal file
355
client/internal/engine_ssh.go
Normal file
@@ -0,0 +1,355 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
sshconfig "github.com/netbirdio/netbird/client/ssh/config"
|
||||
sshserver "github.com/netbirdio/netbird/client/ssh/server"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
type sshServer interface {
|
||||
Start(ctx context.Context, addr netip.AddrPort) error
|
||||
Stop() error
|
||||
GetStatus() (bool, []sshserver.SessionInfo)
|
||||
}
|
||||
|
||||
func (e *Engine) setupSSHPortRedirection() error {
|
||||
if e.firewall == nil || e.wgInterface == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
localAddr := e.wgInterface.Address().IP
|
||||
if !localAddr.IsValid() {
|
||||
return errors.New("invalid local NetBird address")
|
||||
}
|
||||
|
||||
if err := e.firewall.AddInboundDNAT(localAddr, firewallManager.ProtocolTCP, 22, 22022); err != nil {
|
||||
return fmt.Errorf("add SSH port redirection: %w", err)
|
||||
}
|
||||
log.Infof("SSH port redirection enabled: %s:22 -> %s:22022", localAddr, localAddr)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
|
||||
if e.config.BlockInbound {
|
||||
log.Info("SSH server is disabled because inbound connections are blocked")
|
||||
return e.stopSSHServer()
|
||||
}
|
||||
|
||||
if !e.config.ServerSSHAllowed {
|
||||
log.Info("SSH server is disabled in config")
|
||||
return e.stopSSHServer()
|
||||
}
|
||||
|
||||
if !sshConf.GetSshEnabled() {
|
||||
if e.config.ServerSSHAllowed {
|
||||
log.Info("SSH server is locally allowed but disabled by management server")
|
||||
}
|
||||
return e.stopSSHServer()
|
||||
}
|
||||
|
||||
if e.sshServer != nil {
|
||||
log.Debug("SSH server is already running")
|
||||
return nil
|
||||
}
|
||||
|
||||
if e.config.DisableSSHAuth != nil && *e.config.DisableSSHAuth {
|
||||
log.Info("starting SSH server without JWT authentication (authentication disabled by config)")
|
||||
return e.startSSHServer(nil)
|
||||
}
|
||||
|
||||
if protoJWT := sshConf.GetJwtConfig(); protoJWT != nil {
|
||||
jwtConfig := &sshserver.JWTConfig{
|
||||
Issuer: protoJWT.GetIssuer(),
|
||||
Audience: protoJWT.GetAudience(),
|
||||
KeysLocation: protoJWT.GetKeysLocation(),
|
||||
MaxTokenAge: protoJWT.GetMaxTokenAge(),
|
||||
}
|
||||
|
||||
return e.startSSHServer(jwtConfig)
|
||||
}
|
||||
|
||||
return errors.New("SSH server requires valid JWT configuration")
|
||||
}
|
||||
|
||||
// updateSSHClientConfig updates the SSH client configuration with peer information
|
||||
func (e *Engine) updateSSHClientConfig(remotePeers []*mgmProto.RemotePeerConfig) error {
|
||||
peerInfo := e.extractPeerSSHInfo(remotePeers)
|
||||
if len(peerInfo) == 0 {
|
||||
log.Debug("no SSH-enabled peers found, skipping SSH config update")
|
||||
return nil
|
||||
}
|
||||
|
||||
configMgr := sshconfig.New()
|
||||
if err := configMgr.SetupSSHClientConfig(peerInfo); err != nil {
|
||||
log.Warnf("failed to update SSH client config: %v", err)
|
||||
return nil // Don't fail engine startup on SSH config issues
|
||||
}
|
||||
|
||||
log.Debugf("updated SSH client config with %d peers", len(peerInfo))
|
||||
|
||||
if err := e.stateManager.UpdateState(&sshconfig.ShutdownState{
|
||||
SSHConfigDir: configMgr.GetSSHConfigDir(),
|
||||
SSHConfigFile: configMgr.GetSSHConfigFile(),
|
||||
}); err != nil {
|
||||
log.Warnf("failed to update SSH config state: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// extractPeerSSHInfo extracts SSH information from peer configurations
|
||||
func (e *Engine) extractPeerSSHInfo(remotePeers []*mgmProto.RemotePeerConfig) []sshconfig.PeerSSHInfo {
|
||||
var peerInfo []sshconfig.PeerSSHInfo
|
||||
|
||||
for _, peerConfig := range remotePeers {
|
||||
if peerConfig.GetSshConfig() == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
sshPubKeyBytes := peerConfig.GetSshConfig().GetSshPubKey()
|
||||
if len(sshPubKeyBytes) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
peerIP := e.extractPeerIP(peerConfig)
|
||||
hostname := e.extractHostname(peerConfig)
|
||||
|
||||
peerInfo = append(peerInfo, sshconfig.PeerSSHInfo{
|
||||
Hostname: hostname,
|
||||
IP: peerIP,
|
||||
FQDN: peerConfig.GetFqdn(),
|
||||
})
|
||||
}
|
||||
|
||||
return peerInfo
|
||||
}
|
||||
|
||||
// extractPeerIP extracts IP address from peer's allowed IPs
|
||||
func (e *Engine) extractPeerIP(peerConfig *mgmProto.RemotePeerConfig) string {
|
||||
if len(peerConfig.GetAllowedIps()) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
if prefix, err := netip.ParsePrefix(peerConfig.GetAllowedIps()[0]); err == nil {
|
||||
return prefix.Addr().String()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// extractHostname extracts short hostname from FQDN
|
||||
func (e *Engine) extractHostname(peerConfig *mgmProto.RemotePeerConfig) string {
|
||||
fqdn := peerConfig.GetFqdn()
|
||||
if fqdn == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
parts := strings.Split(fqdn, ".")
|
||||
if len(parts) > 0 && parts[0] != "" {
|
||||
return parts[0]
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// updatePeerSSHHostKeys updates peer SSH host keys in the status recorder for daemon API access
|
||||
func (e *Engine) updatePeerSSHHostKeys(remotePeers []*mgmProto.RemotePeerConfig) {
|
||||
for _, peerConfig := range remotePeers {
|
||||
if peerConfig.GetSshConfig() == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
sshPubKeyBytes := peerConfig.GetSshConfig().GetSshPubKey()
|
||||
if len(sshPubKeyBytes) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := e.statusRecorder.UpdatePeerSSHHostKey(peerConfig.GetWgPubKey(), sshPubKeyBytes); err != nil {
|
||||
log.Warnf("failed to update SSH host key for peer %s: %v", peerConfig.GetWgPubKey(), err)
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("updated peer SSH host keys for daemon API access")
|
||||
}
|
||||
|
||||
// GetPeerSSHKey returns the SSH host key for a specific peer by IP or FQDN
|
||||
func (e *Engine) GetPeerSSHKey(peerAddress string) ([]byte, bool) {
|
||||
e.syncMsgMux.Lock()
|
||||
statusRecorder := e.statusRecorder
|
||||
e.syncMsgMux.Unlock()
|
||||
|
||||
if statusRecorder == nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
fullStatus := statusRecorder.GetFullStatus()
|
||||
for _, peerState := range fullStatus.Peers {
|
||||
if peerState.IP == peerAddress || peerState.FQDN == peerAddress {
|
||||
if len(peerState.SSHHostKey) > 0 {
|
||||
return peerState.SSHHostKey, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// cleanupSSHConfig removes NetBird SSH client configuration on shutdown
|
||||
func (e *Engine) cleanupSSHConfig() {
|
||||
configMgr := sshconfig.New()
|
||||
|
||||
if err := configMgr.RemoveSSHClientConfig(); err != nil {
|
||||
log.Warnf("failed to remove SSH client config: %v", err)
|
||||
} else {
|
||||
log.Debugf("SSH client config cleanup completed")
|
||||
}
|
||||
}
|
||||
|
||||
// startSSHServer initializes and starts the SSH server with proper configuration.
|
||||
func (e *Engine) startSSHServer(jwtConfig *sshserver.JWTConfig) error {
|
||||
if e.wgInterface == nil {
|
||||
return errors.New("wg interface not initialized")
|
||||
}
|
||||
|
||||
serverConfig := &sshserver.Config{
|
||||
HostKeyPEM: e.config.SSHKey,
|
||||
JWT: jwtConfig,
|
||||
}
|
||||
server := sshserver.New(serverConfig)
|
||||
|
||||
wgAddr := e.wgInterface.Address()
|
||||
server.SetNetworkValidation(wgAddr)
|
||||
|
||||
netbirdIP := wgAddr.IP
|
||||
listenAddr := netip.AddrPortFrom(netbirdIP, sshserver.InternalSSHPort)
|
||||
|
||||
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
|
||||
server.SetNetstackNet(netstackNet)
|
||||
}
|
||||
|
||||
e.configureSSHServer(server)
|
||||
|
||||
if err := server.Start(e.ctx, listenAddr); err != nil {
|
||||
return fmt.Errorf("start SSH server: %w", err)
|
||||
}
|
||||
|
||||
e.sshServer = server
|
||||
|
||||
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
|
||||
if registrar, ok := e.firewall.(interface {
|
||||
RegisterNetstackService(protocol nftypes.Protocol, port uint16)
|
||||
}); ok {
|
||||
registrar.RegisterNetstackService(nftypes.TCP, sshserver.InternalSSHPort)
|
||||
log.Debugf("registered SSH service with netstack for TCP:%d", sshserver.InternalSSHPort)
|
||||
}
|
||||
}
|
||||
|
||||
if err := e.setupSSHPortRedirection(); err != nil {
|
||||
log.Warnf("failed to setup SSH port redirection: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// configureSSHServer applies SSH configuration options to the server.
|
||||
func (e *Engine) configureSSHServer(server *sshserver.Server) {
|
||||
if e.config.EnableSSHRoot != nil && *e.config.EnableSSHRoot {
|
||||
server.SetAllowRootLogin(true)
|
||||
log.Info("SSH root login enabled")
|
||||
} else {
|
||||
server.SetAllowRootLogin(false)
|
||||
log.Info("SSH root login disabled (default)")
|
||||
}
|
||||
|
||||
if e.config.EnableSSHSFTP != nil && *e.config.EnableSSHSFTP {
|
||||
server.SetAllowSFTP(true)
|
||||
log.Info("SSH SFTP subsystem enabled")
|
||||
} else {
|
||||
server.SetAllowSFTP(false)
|
||||
log.Info("SSH SFTP subsystem disabled (default)")
|
||||
}
|
||||
|
||||
if e.config.EnableSSHLocalPortForwarding != nil && *e.config.EnableSSHLocalPortForwarding {
|
||||
server.SetAllowLocalPortForwarding(true)
|
||||
log.Info("SSH local port forwarding enabled")
|
||||
} else {
|
||||
server.SetAllowLocalPortForwarding(false)
|
||||
log.Info("SSH local port forwarding disabled (default)")
|
||||
}
|
||||
|
||||
if e.config.EnableSSHRemotePortForwarding != nil && *e.config.EnableSSHRemotePortForwarding {
|
||||
server.SetAllowRemotePortForwarding(true)
|
||||
log.Info("SSH remote port forwarding enabled")
|
||||
} else {
|
||||
server.SetAllowRemotePortForwarding(false)
|
||||
log.Info("SSH remote port forwarding disabled (default)")
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Engine) cleanupSSHPortRedirection() error {
|
||||
if e.firewall == nil || e.wgInterface == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
localAddr := e.wgInterface.Address().IP
|
||||
if !localAddr.IsValid() {
|
||||
return errors.New("invalid local NetBird address")
|
||||
}
|
||||
|
||||
if err := e.firewall.RemoveInboundDNAT(localAddr, firewallManager.ProtocolTCP, 22, 22022); err != nil {
|
||||
return fmt.Errorf("remove SSH port redirection: %w", err)
|
||||
}
|
||||
log.Debugf("SSH port redirection removed: %s:22 -> %s:22022", localAddr, localAddr)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Engine) stopSSHServer() error {
|
||||
if e.sshServer == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := e.cleanupSSHPortRedirection(); err != nil {
|
||||
log.Warnf("failed to cleanup SSH port redirection: %v", err)
|
||||
}
|
||||
|
||||
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
|
||||
if registrar, ok := e.firewall.(interface {
|
||||
UnregisterNetstackService(protocol nftypes.Protocol, port uint16)
|
||||
}); ok {
|
||||
registrar.UnregisterNetstackService(nftypes.TCP, sshserver.InternalSSHPort)
|
||||
log.Debugf("unregistered SSH service from netstack for TCP:%d", sshserver.InternalSSHPort)
|
||||
}
|
||||
}
|
||||
|
||||
log.Info("stopping SSH server")
|
||||
err := e.sshServer.Stop()
|
||||
e.sshServer = nil
|
||||
if err != nil {
|
||||
return fmt.Errorf("stop: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetSSHServerStatus returns the SSH server status and active sessions
|
||||
func (e *Engine) GetSSHServerStatus() (enabled bool, sessions []sshserver.SessionInfo) {
|
||||
e.syncMsgMux.Lock()
|
||||
sshServer := e.sshServer
|
||||
e.syncMsgMux.Unlock()
|
||||
|
||||
if sshServer == nil {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return sshServer.GetStatus()
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user