mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-12 04:06:21 -04:00
Compare commits
47 Commits
feature/fl
...
transparen
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
afbddae472 | ||
|
|
5259e5df51 | ||
|
|
ebd78e0122 | ||
|
|
cf86b9a528 | ||
|
|
ee588e1536 | ||
|
|
2a8aacc5c9 | ||
|
|
15709bc666 | ||
|
|
789b4113fe | ||
|
|
d2cdc0efec | ||
|
|
ee343d5d77 | ||
|
|
099c493b18 | ||
|
|
c1d1229ae0 | ||
|
|
94a36cb53e | ||
|
|
c7ba931466 | ||
|
|
413d95b740 | ||
|
|
332c624c55 | ||
|
|
dc160aff36 | ||
|
|
96806bf55f | ||
|
|
d33cd4c95b | ||
|
|
e2c2f64be7 | ||
|
|
cb73b94ffb | ||
|
|
1d920d700c | ||
|
|
bb85eee40a | ||
|
|
aba5d6f0d2 | ||
|
|
0588d2dbe1 | ||
|
|
14b3b77bda | ||
|
|
6da34e483c | ||
|
|
0efef671d7 | ||
|
|
435203b13b | ||
|
|
decb5dd3af | ||
|
|
28fbf96b2a | ||
|
|
9d1a37c644 | ||
|
|
5bf2372c4d | ||
|
|
c2c6396a04 | ||
|
|
aaf813fc0c | ||
|
|
d97fe84296 | ||
|
|
81f45dab21 | ||
|
|
d670e7382a | ||
|
|
cd8c686339 | ||
|
|
f5c41e3018 | ||
|
|
2477f99d89 | ||
|
|
940f530ac2 | ||
|
|
4d3e2f8ad3 | ||
|
|
5ae986e1c4 | ||
|
|
e5914e4e8b | ||
|
|
c238f5425f | ||
|
|
3c3097ea74 |
@@ -31,7 +31,7 @@ jobs:
|
||||
while IFS= read -r dir; do
|
||||
echo "=== Checking $dir ==="
|
||||
# Search for problematic imports, excluding test files
|
||||
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
|
||||
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" | grep -v "tools/idp-migrate/" || true)
|
||||
if [ -n "$RESULTS" ]; then
|
||||
echo "❌ Found problematic dependencies:"
|
||||
echo "$RESULTS"
|
||||
@@ -88,7 +88,7 @@ jobs:
|
||||
IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath")
|
||||
|
||||
# Check if any importer is NOT in management/signal/relay
|
||||
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\|combined\)" | head -1)
|
||||
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\|combined\|tools/idp-migrate\)" | head -1)
|
||||
|
||||
if [ -n "$BSD_IMPORTER" ]; then
|
||||
echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER"
|
||||
|
||||
2
.github/workflows/golangci-lint.yml
vendored
2
.github/workflows/golangci-lint.yml
vendored
@@ -19,7 +19,7 @@ jobs:
|
||||
- name: codespell
|
||||
uses: codespell-project/actions-codespell@v2
|
||||
with:
|
||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te
|
||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te,userA
|
||||
skip: go.mod,go.sum,**/proxy/web/**
|
||||
golangci:
|
||||
strategy:
|
||||
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -33,3 +33,5 @@ infrastructure_files/setup-*.env
|
||||
vendor/
|
||||
/netbird
|
||||
client/netbird-electron/
|
||||
management/server/types/testdata/comparison/
|
||||
management/server/types/testdata/*.json
|
||||
|
||||
@@ -154,6 +154,26 @@ builds:
|
||||
- -s -w -X main.Version={{.Version}} -X main.Commit={{.Commit}} -X main.BuildDate={{.CommitDate}}
|
||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||
|
||||
- id: netbird-idp-migrate
|
||||
dir: tools/idp-migrate
|
||||
env:
|
||||
- CGO_ENABLED=1
|
||||
- >-
|
||||
{{- if eq .Runtime.Goos "linux" }}
|
||||
{{- if eq .Arch "arm64"}}CC=aarch64-linux-gnu-gcc{{- end }}
|
||||
{{- if eq .Arch "arm"}}CC=arm-linux-gnueabihf-gcc{{- end }}
|
||||
{{- end }}
|
||||
binary: netbird-idp-migrate
|
||||
goos:
|
||||
- linux
|
||||
goarch:
|
||||
- amd64
|
||||
- arm64
|
||||
- arm
|
||||
ldflags:
|
||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||
|
||||
universal_binaries:
|
||||
- id: netbird
|
||||
|
||||
@@ -166,6 +186,10 @@ archives:
|
||||
- netbird-wasm
|
||||
name_template: "{{ .ProjectName }}_{{ .Version }}"
|
||||
format: binary
|
||||
- id: netbird-idp-migrate
|
||||
builds:
|
||||
- netbird-idp-migrate
|
||||
name_template: "netbird-idp-migrate_{{ .Version }}_{{ .Os }}_{{ .Arch }}"
|
||||
|
||||
nfpms:
|
||||
- maintainer: Netbird <dev@netbird.io>
|
||||
|
||||
@@ -199,9 +199,11 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
cmd.Println("Log level set to trace.")
|
||||
}
|
||||
|
||||
needsRestoreUp := false
|
||||
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
|
||||
cmd.PrintErrf("Failed to bring service down: %v\n", status.Convert(err).Message())
|
||||
} else {
|
||||
needsRestoreUp = !stateWasDown
|
||||
cmd.Println("netbird down")
|
||||
}
|
||||
|
||||
@@ -217,6 +219,7 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
|
||||
cmd.PrintErrf("Failed to bring service up: %v\n", status.Convert(err).Message())
|
||||
} else {
|
||||
needsRestoreUp = false
|
||||
cmd.Println("netbird up")
|
||||
}
|
||||
|
||||
@@ -264,6 +267,14 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
if needsRestoreUp {
|
||||
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
|
||||
cmd.PrintErrf("Failed to restore service up state: %v\n", status.Convert(err).Message())
|
||||
} else {
|
||||
cmd.Println("netbird up (restored)")
|
||||
}
|
||||
}
|
||||
|
||||
if stateWasDown {
|
||||
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
|
||||
cmd.PrintErrf("Failed to restore service down state: %v\n", status.Convert(err).Message())
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/expose"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
@@ -201,7 +202,7 @@ func exposeFn(cmd *cobra.Command, args []string) error {
|
||||
|
||||
stream, err := client.ExposeService(ctx, req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("expose service: %w", err)
|
||||
return fmt.Errorf("expose service: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
if err := handleExposeReady(cmd, stream, port); err != nil {
|
||||
@@ -236,7 +237,7 @@ func toExposeProtocol(exposeProtocol string) (proto.ExposeProtocol, error) {
|
||||
func handleExposeReady(cmd *cobra.Command, stream proto.DaemonService_ExposeServiceClient, port uint64) error {
|
||||
event, err := stream.Recv()
|
||||
if err != nil {
|
||||
return fmt.Errorf("receive expose event: %w", err)
|
||||
return fmt.Errorf("receive expose event: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
ready, ok := event.Event.(*proto.ExposeServiceEvent_Ready)
|
||||
|
||||
@@ -25,10 +25,10 @@ func TestServiceParamsPath(t *testing.T) {
|
||||
t.Cleanup(func() { configs.StateDir = original })
|
||||
|
||||
configs.StateDir = "/var/lib/netbird"
|
||||
assert.Equal(t, "/var/lib/netbird/service.json", serviceParamsPath())
|
||||
assert.Equal(t, filepath.Join("/var/lib/netbird", "service.json"), serviceParamsPath())
|
||||
|
||||
configs.StateDir = "/custom/state"
|
||||
assert.Equal(t, "/custom/state/service.json", serviceParamsPath())
|
||||
assert.Equal(t, filepath.Join("/custom/state", "service.json"), serviceParamsPath())
|
||||
}
|
||||
|
||||
func TestSaveAndLoadServiceParams(t *testing.T) {
|
||||
|
||||
@@ -4,7 +4,9 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
"runtime"
|
||||
"syscall"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -13,6 +15,22 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestMain intercepts when this test binary is run as a daemon subprocess.
|
||||
// On FreeBSD, the rc.d service script runs the binary via daemon(8) -r with
|
||||
// "service run ..." arguments. Since the test binary can't handle cobra CLI
|
||||
// args, it exits immediately, causing daemon -r to respawn rapidly until
|
||||
// hitting the rate limit and exiting. This makes service restart unreliable.
|
||||
// Blocking here keeps the subprocess alive until the init system sends SIGTERM.
|
||||
func TestMain(m *testing.M) {
|
||||
if len(os.Args) > 2 && os.Args[1] == "service" && os.Args[2] == "run" {
|
||||
sig := make(chan os.Signal, 1)
|
||||
signal.Notify(sig, syscall.SIGTERM, os.Interrupt)
|
||||
<-sig
|
||||
return
|
||||
}
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
const (
|
||||
serviceStartTimeout = 10 * time.Second
|
||||
serviceStopTimeout = 5 * time.Second
|
||||
@@ -79,6 +97,34 @@ func TestServiceLifecycle(t *testing.T) {
|
||||
logLevel = "info"
|
||||
daemonAddr = fmt.Sprintf("unix://%s/netbird-test.sock", tempDir)
|
||||
|
||||
// Ensure cleanup even if a subtest fails and Stop/Uninstall subtests don't run.
|
||||
t.Cleanup(func() {
|
||||
cfg, err := newSVCConfig()
|
||||
if err != nil {
|
||||
t.Errorf("cleanup: create service config: %v", err)
|
||||
return
|
||||
}
|
||||
ctxSvc, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
|
||||
if err != nil {
|
||||
t.Errorf("cleanup: create service: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// If the subtests already cleaned up, there's nothing to do.
|
||||
if _, err := s.Status(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.Stop(); err != nil {
|
||||
t.Errorf("cleanup: stop service: %v", err)
|
||||
}
|
||||
if err := s.Uninstall(); err != nil {
|
||||
t.Errorf("cleanup: uninstall service: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("Install", func(t *testing.T) {
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
"github.com/google/nftables"
|
||||
@@ -35,20 +36,27 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
|
||||
type FWType int
|
||||
|
||||
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool, mtu uint16) (firewall.Manager, error) {
|
||||
// on the linux system we try to user nftables or iptables
|
||||
// in any case, because we need to allow netbird interface traffic
|
||||
// so we use AllowNetbird traffic from these firewall managers
|
||||
// for the userspace packet filtering firewall
|
||||
// We run in userspace mode and force userspace firewall was requested. We don't attempt native firewall.
|
||||
if iface.IsUserspaceBind() && forceUserspaceFirewall() {
|
||||
log.Info("forcing userspace firewall")
|
||||
return createUserspaceFirewall(iface, nil, disableServerRoutes, flowLogger, mtu)
|
||||
}
|
||||
|
||||
// Use native firewall for either kernel or userspace, the interface appears identical to netfilter
|
||||
fm, err := createNativeFirewall(iface, stateManager, disableServerRoutes, mtu)
|
||||
|
||||
// Kernel cannot fall back to anything else, need to return error
|
||||
if !iface.IsUserspaceBind() {
|
||||
return fm, err
|
||||
}
|
||||
|
||||
// Fall back to the userspace packet filter if native is unavailable
|
||||
if err != nil {
|
||||
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
|
||||
return createUserspaceFirewall(iface, nil, disableServerRoutes, flowLogger, mtu)
|
||||
}
|
||||
return createUserspaceFirewall(iface, fm, disableServerRoutes, flowLogger, mtu)
|
||||
|
||||
return fm, nil
|
||||
}
|
||||
|
||||
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool, mtu uint16) (firewall.Manager, error) {
|
||||
@@ -160,3 +168,17 @@ func isIptablesClientAvailable(client *iptables.IPTables) bool {
|
||||
_, err := client.ListChains("filter")
|
||||
return err == nil
|
||||
}
|
||||
|
||||
func forceUserspaceFirewall() bool {
|
||||
val := os.Getenv(EnvForceUserspaceFirewall)
|
||||
if val == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
force, err := strconv.ParseBool(val)
|
||||
if err != nil {
|
||||
log.Warnf("failed to parse %s: %v", EnvForceUserspaceFirewall, err)
|
||||
return false
|
||||
}
|
||||
return force
|
||||
}
|
||||
|
||||
@@ -7,6 +7,12 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
// EnvForceUserspaceFirewall forces the use of the userspace packet filter even when
|
||||
// native iptables/nftables is available. This only applies when the WireGuard interface
|
||||
// runs in userspace mode. When set, peer ACLs are handled by USPFilter instead of
|
||||
// kernel netfilter rules.
|
||||
const EnvForceUserspaceFirewall = "NB_FORCE_USERSPACE_FIREWALL"
|
||||
|
||||
// IFaceMapper defines subset methods of interface required for manager
|
||||
type IFaceMapper interface {
|
||||
Name() string
|
||||
|
||||
@@ -33,7 +33,6 @@ type Manager struct {
|
||||
type iFaceMapper interface {
|
||||
Name() string
|
||||
Address() wgaddr.Address
|
||||
IsUserspaceBind() bool
|
||||
}
|
||||
|
||||
// Create iptables firewall manager
|
||||
@@ -64,10 +63,9 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
|
||||
func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||
state := &ShutdownState{
|
||||
InterfaceState: &InterfaceState{
|
||||
NameStr: m.wgIface.Name(),
|
||||
WGAddress: m.wgIface.Address(),
|
||||
UserspaceBind: m.wgIface.IsUserspaceBind(),
|
||||
MTU: m.router.mtu,
|
||||
NameStr: m.wgIface.Name(),
|
||||
WGAddress: m.wgIface.Address(),
|
||||
MTU: m.router.mtu,
|
||||
},
|
||||
}
|
||||
stateManager.RegisterState(state)
|
||||
@@ -203,12 +201,10 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
// AllowNetbird allows netbird interface traffic
|
||||
// AllowNetbird allows netbird interface traffic.
|
||||
// This is called when USPFilter wraps the native firewall, adding blanket accept
|
||||
// rules so that packet filtering is handled in userspace instead of by netfilter.
|
||||
func (m *Manager) AllowNetbird() error {
|
||||
if !m.wgIface.IsUserspaceBind() {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err := m.AddPeerFiltering(
|
||||
nil,
|
||||
net.IP{0, 0, 0, 0},
|
||||
@@ -286,6 +282,22 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
|
||||
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
||||
func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.AddOutputDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
||||
func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.RemoveOutputDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
const (
|
||||
chainNameRaw = "NETBIRD-RAW"
|
||||
chainOUTPUT = "OUTPUT"
|
||||
@@ -352,6 +364,28 @@ func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddTProxyRule adds TPROXY redirect rules for the transparent proxy.
|
||||
func (m *Manager) AddTProxyRule(ruleID string, sources []netip.Prefix, dstPorts []uint16, redirectPort uint16) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.AddTProxyRule(ruleID, sources, dstPorts, redirectPort)
|
||||
}
|
||||
|
||||
// RemoveTProxyRule removes TPROXY redirect rules by ID.
|
||||
func (m *Manager) RemoveTProxyRule(ruleID string) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.RemoveTProxyRule(ruleID)
|
||||
}
|
||||
|
||||
// AddUDPInspectionHook is a no-op for iptables (kernel-mode firewall has no userspace packet hooks).
|
||||
func (m *Manager) AddUDPInspectionHook(_ uint16, _ func([]byte) bool) string { return "" }
|
||||
|
||||
// RemoveUDPInspectionHook is a no-op for iptables.
|
||||
func (m *Manager) RemoveUDPInspectionHook(_ string) {}
|
||||
|
||||
func (m *Manager) initNoTrackChain() error {
|
||||
if err := m.cleanupNoTrackChain(); err != nil {
|
||||
log.Debugf("cleanup notrack chain: %v", err)
|
||||
|
||||
@@ -47,8 +47,6 @@ func (i *iFaceMock) Address() wgaddr.Address {
|
||||
panic("AddressFunc is not set")
|
||||
}
|
||||
|
||||
func (i *iFaceMock) IsUserspaceBind() bool { return false }
|
||||
|
||||
func TestIptablesManager(t *testing.T) {
|
||||
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -36,6 +36,7 @@ const (
|
||||
chainRTFWDOUT = "NETBIRD-RT-FWD-OUT"
|
||||
chainRTPRE = "NETBIRD-RT-PRE"
|
||||
chainRTRDR = "NETBIRD-RT-RDR"
|
||||
chainNATOutput = "NETBIRD-NAT-OUTPUT"
|
||||
chainRTMSSCLAMP = "NETBIRD-RT-MSSCLAMP"
|
||||
routingFinalForwardJump = "ACCEPT"
|
||||
routingFinalNatJump = "MASQUERADE"
|
||||
@@ -43,6 +44,7 @@ const (
|
||||
jumpManglePre = "jump-mangle-pre"
|
||||
jumpNatPre = "jump-nat-pre"
|
||||
jumpNatPost = "jump-nat-post"
|
||||
jumpNatOutput = "jump-nat-output"
|
||||
jumpMSSClamp = "jump-mss-clamp"
|
||||
markManglePre = "mark-mangle-pre"
|
||||
markManglePost = "mark-mangle-post"
|
||||
@@ -87,6 +89,8 @@ type router struct {
|
||||
|
||||
stateManager *statemanager.Manager
|
||||
ipFwdState *ipfwdstate.IPForwardingState
|
||||
|
||||
tproxyRules []tproxyRuleEntry
|
||||
}
|
||||
|
||||
func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper, mtu uint16) (*router, error) {
|
||||
@@ -387,6 +391,14 @@ func (r *router) cleanUpDefaultForwardRules() error {
|
||||
}
|
||||
|
||||
log.Debug("flushing routing related tables")
|
||||
|
||||
// Remove jump rules from built-in chains before deleting custom chains,
|
||||
// otherwise the chain deletion fails with "device or resource busy".
|
||||
jumpRule := []string{"-j", chainNATOutput}
|
||||
if err := r.iptablesClient.Delete(tableNat, "OUTPUT", jumpRule...); err != nil {
|
||||
log.Debugf("clean OUTPUT jump rule: %v", err)
|
||||
}
|
||||
|
||||
for _, chainInfo := range []struct {
|
||||
chain string
|
||||
table string
|
||||
@@ -396,6 +408,7 @@ func (r *router) cleanUpDefaultForwardRules() error {
|
||||
{chainRTPRE, tableMangle},
|
||||
{chainRTNAT, tableNat},
|
||||
{chainRTRDR, tableNat},
|
||||
{chainNATOutput, tableNat},
|
||||
{chainRTMSSCLAMP, tableMangle},
|
||||
} {
|
||||
ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain)
|
||||
@@ -970,6 +983,81 @@ func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Proto
|
||||
return nil
|
||||
}
|
||||
|
||||
// ensureNATOutputChain lazily creates the OUTPUT NAT chain and jump rule on first use.
|
||||
func (r *router) ensureNATOutputChain() error {
|
||||
if _, exists := r.rules[jumpNatOutput]; exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
chainExists, err := r.iptablesClient.ChainExists(tableNat, chainNATOutput)
|
||||
if err != nil {
|
||||
return fmt.Errorf("check chain %s: %w", chainNATOutput, err)
|
||||
}
|
||||
if !chainExists {
|
||||
if err := r.iptablesClient.NewChain(tableNat, chainNATOutput); err != nil {
|
||||
return fmt.Errorf("create chain %s: %w", chainNATOutput, err)
|
||||
}
|
||||
}
|
||||
|
||||
jumpRule := []string{"-j", chainNATOutput}
|
||||
if err := r.iptablesClient.Insert(tableNat, "OUTPUT", 1, jumpRule...); err != nil {
|
||||
if !chainExists {
|
||||
if delErr := r.iptablesClient.ClearAndDeleteChain(tableNat, chainNATOutput); delErr != nil {
|
||||
log.Warnf("failed to rollback chain %s: %v", chainNATOutput, delErr)
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("add OUTPUT jump rule: %w", err)
|
||||
}
|
||||
r.rules[jumpNatOutput] = jumpRule
|
||||
|
||||
r.updateState()
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
||||
func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
||||
|
||||
if _, exists := r.rules[ruleID]; exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := r.ensureNATOutputChain(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dnatRule := []string{
|
||||
"-p", strings.ToLower(string(protocol)),
|
||||
"--dport", strconv.Itoa(int(sourcePort)),
|
||||
"-d", localAddr.String(),
|
||||
"-j", "DNAT",
|
||||
"--to-destination", ":" + strconv.Itoa(int(targetPort)),
|
||||
}
|
||||
|
||||
if err := r.iptablesClient.Append(tableNat, chainNATOutput, dnatRule...); err != nil {
|
||||
return fmt.Errorf("add output DNAT rule: %w", err)
|
||||
}
|
||||
r.rules[ruleID] = dnatRule
|
||||
|
||||
r.updateState()
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
||||
func (r *router) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
||||
|
||||
if dnatRule, exists := r.rules[ruleID]; exists {
|
||||
if err := r.iptablesClient.Delete(tableNat, chainNATOutput, dnatRule...); err != nil {
|
||||
return fmt.Errorf("delete output DNAT rule: %w", err)
|
||||
}
|
||||
delete(r.rules, ruleID)
|
||||
}
|
||||
|
||||
r.updateState()
|
||||
return nil
|
||||
}
|
||||
|
||||
func applyPort(flag string, port *firewall.Port) []string {
|
||||
if port == nil {
|
||||
return nil
|
||||
@@ -1023,3 +1111,92 @@ func (r *router) addPrefixToIPSet(name string, prefix netip.Prefix) error {
|
||||
func (r *router) destroyIPSet(name string) error {
|
||||
return ipset.Destroy(name)
|
||||
}
|
||||
|
||||
// AddTProxyRule adds iptables nat PREROUTING REDIRECT rules for transparent proxy interception.
|
||||
// Traffic from sources on dstPorts arriving on the WG interface is redirected
|
||||
// to the transparent proxy listener on redirectPort.
|
||||
func (r *router) AddTProxyRule(ruleID string, sources []netip.Prefix, dstPorts []uint16, redirectPort uint16) error {
|
||||
portStr := fmt.Sprintf("%d", redirectPort)
|
||||
|
||||
for _, proto := range []string{"tcp", "udp"} {
|
||||
srcSpecs := r.buildSourceSpecs(sources)
|
||||
|
||||
for _, srcSpec := range srcSpecs {
|
||||
if len(dstPorts) == 0 {
|
||||
rule := append(srcSpec,
|
||||
"-i", r.wgIface.Name(),
|
||||
"-p", proto,
|
||||
"-j", "REDIRECT",
|
||||
"--to-ports", portStr,
|
||||
)
|
||||
if err := r.iptablesClient.AppendUnique(tableNat, chainRTRDR, rule...); err != nil {
|
||||
return fmt.Errorf("add redirect rule %s/%s: %w", ruleID, proto, err)
|
||||
}
|
||||
r.tproxyRules = append(r.tproxyRules, tproxyRuleEntry{
|
||||
ruleID: ruleID,
|
||||
table: tableNat,
|
||||
chain: chainRTRDR,
|
||||
spec: rule,
|
||||
})
|
||||
} else {
|
||||
for _, port := range dstPorts {
|
||||
rule := append(srcSpec,
|
||||
"-i", r.wgIface.Name(),
|
||||
"-p", proto,
|
||||
"--dport", fmt.Sprintf("%d", port),
|
||||
"-j", "REDIRECT",
|
||||
"--to-ports", portStr,
|
||||
)
|
||||
if err := r.iptablesClient.AppendUnique(tableNat, chainRTRDR, rule...); err != nil {
|
||||
return fmt.Errorf("add redirect rule %s/%s/%d: %w", ruleID, proto, port, err)
|
||||
}
|
||||
r.tproxyRules = append(r.tproxyRules, tproxyRuleEntry{
|
||||
ruleID: ruleID,
|
||||
table: tableNat,
|
||||
chain: chainRTRDR,
|
||||
spec: rule,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveTProxyRule removes all iptables REDIRECT rules for the given ruleID.
|
||||
func (r *router) RemoveTProxyRule(ruleID string) error {
|
||||
var remaining []tproxyRuleEntry
|
||||
for _, entry := range r.tproxyRules {
|
||||
if entry.ruleID != ruleID {
|
||||
remaining = append(remaining, entry)
|
||||
continue
|
||||
}
|
||||
if err := r.iptablesClient.DeleteIfExists(entry.table, entry.chain, entry.spec...); err != nil {
|
||||
log.Debugf("remove tproxy rule %s: %v", ruleID, err)
|
||||
}
|
||||
}
|
||||
r.tproxyRules = remaining
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type tproxyRuleEntry struct {
|
||||
ruleID string
|
||||
table string
|
||||
chain string
|
||||
spec []string
|
||||
}
|
||||
|
||||
func (r *router) buildSourceSpecs(sources []netip.Prefix) [][]string {
|
||||
if len(sources) == 0 {
|
||||
return [][]string{{}} // empty spec = match any source
|
||||
}
|
||||
|
||||
specs := make([][]string, 0, len(sources))
|
||||
for _, src := range sources {
|
||||
specs = append(specs, []string{"-s", src.String()})
|
||||
}
|
||||
return specs
|
||||
}
|
||||
|
||||
|
||||
@@ -9,10 +9,9 @@ import (
|
||||
)
|
||||
|
||||
type InterfaceState struct {
|
||||
NameStr string `json:"name"`
|
||||
WGAddress wgaddr.Address `json:"wg_address"`
|
||||
UserspaceBind bool `json:"userspace_bind"`
|
||||
MTU uint16 `json:"mtu"`
|
||||
NameStr string `json:"name"`
|
||||
WGAddress wgaddr.Address `json:"wg_address"`
|
||||
MTU uint16 `json:"mtu"`
|
||||
}
|
||||
|
||||
func (i *InterfaceState) Name() string {
|
||||
@@ -23,10 +22,6 @@ func (i *InterfaceState) Address() wgaddr.Address {
|
||||
return i.WGAddress
|
||||
}
|
||||
|
||||
func (i *InterfaceState) IsUserspaceBind() bool {
|
||||
return i.UserspaceBind
|
||||
}
|
||||
|
||||
type ShutdownState struct {
|
||||
sync.Mutex
|
||||
|
||||
|
||||
@@ -169,9 +169,33 @@ type Manager interface {
|
||||
// RemoveInboundDNAT removes inbound DNAT rule
|
||||
RemoveInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
|
||||
|
||||
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
||||
// localAddr must be IPv4; the underlying iptables/nftables backends are IPv4-only.
|
||||
AddOutputDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
|
||||
|
||||
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
||||
// localAddr must be IPv4; the underlying iptables/nftables backends are IPv4-only.
|
||||
RemoveOutputDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
|
||||
|
||||
// SetupEBPFProxyNoTrack creates static notrack rules for eBPF proxy loopback traffic.
|
||||
// This prevents conntrack from interfering with WireGuard proxy communication.
|
||||
SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error
|
||||
|
||||
// AddTProxyRule adds TPROXY redirect rules for specific source CIDRs and destination ports.
|
||||
// Traffic from sources on dstPorts is redirected to the transparent proxy on redirectPort.
|
||||
// Empty dstPorts means redirect all ports.
|
||||
AddTProxyRule(ruleID string, sources []netip.Prefix, dstPorts []uint16, redirectPort uint16) error
|
||||
|
||||
// RemoveTProxyRule removes TPROXY redirect rules by ID.
|
||||
RemoveTProxyRule(ruleID string) error
|
||||
|
||||
// AddUDPInspectionHook registers a hook that inspects UDP packets before forwarding.
|
||||
// The hook receives the raw packet and returns true to drop it.
|
||||
// Used for QUIC SNI-based blocking. Returns a hook ID for removal.
|
||||
AddUDPInspectionHook(dstPort uint16, hook func(packet []byte) bool) string
|
||||
|
||||
// RemoveUDPInspectionHook removes a previously registered inspection hook.
|
||||
RemoveUDPInspectionHook(hookID string)
|
||||
}
|
||||
|
||||
func GenKey(format string, pair RouterPair) string {
|
||||
|
||||
@@ -40,7 +40,6 @@ func getTableName() string {
|
||||
type iFaceMapper interface {
|
||||
Name() string
|
||||
Address() wgaddr.Address
|
||||
IsUserspaceBind() bool
|
||||
}
|
||||
|
||||
// Manager of iptables firewall
|
||||
@@ -106,10 +105,9 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||
// cleanup using Close() without needing to store specific rules.
|
||||
if err := stateManager.UpdateState(&ShutdownState{
|
||||
InterfaceState: &InterfaceState{
|
||||
NameStr: m.wgIface.Name(),
|
||||
WGAddress: m.wgIface.Address(),
|
||||
UserspaceBind: m.wgIface.IsUserspaceBind(),
|
||||
MTU: m.router.mtu,
|
||||
NameStr: m.wgIface.Name(),
|
||||
WGAddress: m.wgIface.Address(),
|
||||
MTU: m.router.mtu,
|
||||
},
|
||||
}); err != nil {
|
||||
log.Errorf("failed to update state: %v", err)
|
||||
@@ -205,12 +203,10 @@ func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
return m.router.RemoveNatRule(pair)
|
||||
}
|
||||
|
||||
// AllowNetbird allows netbird interface traffic
|
||||
// AllowNetbird allows netbird interface traffic.
|
||||
// This is called when USPFilter wraps the native firewall, adding blanket accept
|
||||
// rules so that packet filtering is handled in userspace instead of by netfilter.
|
||||
func (m *Manager) AllowNetbird() error {
|
||||
if !m.wgIface.IsUserspaceBind() {
|
||||
return nil
|
||||
}
|
||||
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
@@ -346,6 +342,22 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
|
||||
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
||||
func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.AddOutputDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
||||
func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.RemoveOutputDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
const (
|
||||
chainNameRawOutput = "netbird-raw-out"
|
||||
chainNameRawPrerouting = "netbird-raw-pre"
|
||||
@@ -470,6 +482,28 @@ func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddTProxyRule adds TPROXY redirect rules for the transparent proxy.
|
||||
func (m *Manager) AddTProxyRule(ruleID string, sources []netip.Prefix, dstPorts []uint16, redirectPort uint16) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.AddTProxyRule(ruleID, sources, dstPorts, redirectPort)
|
||||
}
|
||||
|
||||
// RemoveTProxyRule removes TPROXY redirect rules by ID.
|
||||
func (m *Manager) RemoveTProxyRule(ruleID string) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.RemoveTProxyRule(ruleID)
|
||||
}
|
||||
|
||||
// AddUDPInspectionHook is a no-op for nftables (kernel-mode firewall has no userspace packet hooks).
|
||||
func (m *Manager) AddUDPInspectionHook(_ uint16, _ func([]byte) bool) string { return "" }
|
||||
|
||||
// RemoveUDPInspectionHook is a no-op for nftables.
|
||||
func (m *Manager) RemoveUDPInspectionHook(_ string) {}
|
||||
|
||||
func (m *Manager) initNoTrackChains(table *nftables.Table) error {
|
||||
m.notrackOutputChain = m.rConn.AddChain(&nftables.Chain{
|
||||
Name: chainNameRawOutput,
|
||||
|
||||
@@ -52,8 +52,6 @@ func (i *iFaceMock) Address() wgaddr.Address {
|
||||
panic("AddressFunc is not set")
|
||||
}
|
||||
|
||||
func (i *iFaceMock) IsUserspaceBind() bool { return false }
|
||||
|
||||
func TestNftablesManager(t *testing.T) {
|
||||
|
||||
// just check on the local interface
|
||||
|
||||
@@ -36,6 +36,7 @@ const (
|
||||
chainNameRoutingFw = "netbird-rt-fwd"
|
||||
chainNameRoutingNat = "netbird-rt-postrouting"
|
||||
chainNameRoutingRdr = "netbird-rt-redirect"
|
||||
chainNameNATOutput = "netbird-nat-output"
|
||||
chainNameForward = "FORWARD"
|
||||
chainNameMangleForward = "netbird-mangle-forward"
|
||||
|
||||
@@ -76,6 +77,7 @@ type router struct {
|
||||
ipFwdState *ipfwdstate.IPForwardingState
|
||||
legacyManagement bool
|
||||
mtu uint16
|
||||
|
||||
}
|
||||
|
||||
func newRouter(workTable *nftables.Table, wgIface iFaceMapper, mtu uint16) (*router, error) {
|
||||
@@ -1853,6 +1855,130 @@ func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Proto
|
||||
return nil
|
||||
}
|
||||
|
||||
// ensureNATOutputChain lazily creates the OUTPUT NAT chain on first use.
|
||||
func (r *router) ensureNATOutputChain() error {
|
||||
if _, exists := r.chains[chainNameNATOutput]; exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
r.chains[chainNameNATOutput] = r.conn.AddChain(&nftables.Chain{
|
||||
Name: chainNameNATOutput,
|
||||
Table: r.workTable,
|
||||
Hooknum: nftables.ChainHookOutput,
|
||||
Priority: nftables.ChainPriorityNATDest,
|
||||
Type: nftables.ChainTypeNAT,
|
||||
})
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
delete(r.chains, chainNameNATOutput)
|
||||
return fmt.Errorf("create NAT output chain: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
||||
func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
||||
|
||||
if _, exists := r.rules[ruleID]; exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := r.ensureNATOutputChain(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
protoNum, err := protoToInt(protocol)
|
||||
if err != nil {
|
||||
return fmt.Errorf("convert protocol to number: %w", err)
|
||||
}
|
||||
|
||||
exprs := []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: []byte{protoNum},
|
||||
},
|
||||
&expr.Payload{
|
||||
DestRegister: 2,
|
||||
Base: expr.PayloadBaseTransportHeader,
|
||||
Offset: 2,
|
||||
Len: 2,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 2,
|
||||
Data: binaryutil.BigEndian.PutUint16(sourcePort),
|
||||
},
|
||||
}
|
||||
|
||||
exprs = append(exprs, applyPrefix(netip.PrefixFrom(localAddr, 32), false)...)
|
||||
|
||||
exprs = append(exprs,
|
||||
&expr.Immediate{
|
||||
Register: 1,
|
||||
Data: localAddr.AsSlice(),
|
||||
},
|
||||
&expr.Immediate{
|
||||
Register: 2,
|
||||
Data: binaryutil.BigEndian.PutUint16(targetPort),
|
||||
},
|
||||
&expr.NAT{
|
||||
Type: expr.NATTypeDestNAT,
|
||||
Family: uint32(nftables.TableFamilyIPv4),
|
||||
RegAddrMin: 1,
|
||||
RegProtoMin: 2,
|
||||
},
|
||||
)
|
||||
|
||||
dnatRule := &nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameNATOutput],
|
||||
Exprs: exprs,
|
||||
UserData: []byte(ruleID),
|
||||
}
|
||||
r.conn.AddRule(dnatRule)
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("add output DNAT rule: %w", err)
|
||||
}
|
||||
|
||||
r.rules[ruleID] = dnatRule
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
||||
func (r *router) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
return fmt.Errorf(refreshRulesMapError, err)
|
||||
}
|
||||
|
||||
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
||||
|
||||
rule, exists := r.rules[ruleID]
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
if rule.Handle == 0 {
|
||||
log.Warnf("output DNAT rule %s has no handle, removing stale entry", ruleID)
|
||||
delete(r.rules, ruleID)
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
return fmt.Errorf("delete output DNAT rule %s: %w", ruleID, err)
|
||||
}
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("flush delete output DNAT rule: %w", err)
|
||||
}
|
||||
delete(r.rules, ruleID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// applyNetwork generates nftables expressions for networks (CIDR) or sets
|
||||
func (r *router) applyNetwork(
|
||||
network firewall.Network,
|
||||
@@ -2012,3 +2138,227 @@ func getIpSetExprs(ref refcounter.Ref[*nftables.Set], isSource bool) ([]expr.Any
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// AddTProxyRule adds nftables TPROXY redirect rules in the mangle prerouting chain.
|
||||
// Traffic from sources on dstPorts arriving on the WG interface is redirected to
|
||||
// the transparent proxy listener on redirectPort.
|
||||
// Separate rules are created for TCP and UDP protocols.
|
||||
func (r *router) AddTProxyRule(ruleID string, sources []netip.Prefix, dstPorts []uint16, redirectPort uint16) error {
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
return fmt.Errorf(refreshRulesMapError, err)
|
||||
}
|
||||
|
||||
// Use the nat redirect chain for DNAT rules.
|
||||
// TPROXY doesn't work on WG kernel interfaces (socket assignment silently fails),
|
||||
// so we use DNAT to 127.0.0.1:proxy_port instead. The proxy reads the original
|
||||
// destination via SO_ORIGINAL_DST (conntrack).
|
||||
chain := r.chains[chainNameRoutingRdr]
|
||||
if chain == nil {
|
||||
return fmt.Errorf("nat redirect chain not initialized")
|
||||
}
|
||||
|
||||
for _, proto := range []uint8{unix.IPPROTO_TCP, unix.IPPROTO_UDP} {
|
||||
protoName := "tcp"
|
||||
if proto == unix.IPPROTO_UDP {
|
||||
protoName = "udp"
|
||||
}
|
||||
|
||||
ruleKey := fmt.Sprintf("tproxy-%s-%s", ruleID, protoName)
|
||||
|
||||
if existing, ok := r.rules[ruleKey]; ok && existing.Handle != 0 {
|
||||
if err := r.decrementSetCounter(existing); err != nil {
|
||||
log.Debugf("decrement set counter for %s: %v", ruleKey, err)
|
||||
}
|
||||
if err := r.conn.DelRule(existing); err != nil {
|
||||
log.Debugf("remove existing tproxy rule %s: %v", ruleKey, err)
|
||||
}
|
||||
delete(r.rules, ruleKey)
|
||||
}
|
||||
|
||||
exprs, err := r.buildRedirectExprs(proto, sources, dstPorts, redirectPort)
|
||||
if err != nil {
|
||||
return fmt.Errorf("build redirect exprs for %s: %w", protoName, err)
|
||||
}
|
||||
|
||||
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: chain,
|
||||
Exprs: exprs,
|
||||
UserData: []byte(ruleKey),
|
||||
})
|
||||
}
|
||||
|
||||
// Accept redirected packets in the ACL input chain. After REDIRECT, the
|
||||
// destination port becomes the proxy port. Without this rule, the ACL filter
|
||||
// drops the packet. We match on ct state dnat so only REDIRECT'd connections
|
||||
// are accepted: direct connections to the proxy port are blocked.
|
||||
inputAcceptKey := fmt.Sprintf("tproxy-%s-input", ruleID)
|
||||
if _, ok := r.rules[inputAcceptKey]; !ok {
|
||||
inputChain := &nftables.Chain{
|
||||
Name: "netbird-acl-input-rules",
|
||||
Table: r.workTable,
|
||||
}
|
||||
r.rules[inputAcceptKey] = r.conn.InsertRule(&nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: inputChain,
|
||||
Exprs: []expr.Any{
|
||||
// Only accept connections that were REDIRECT'd (ct status dnat)
|
||||
&expr.Ct{Register: 1, Key: expr.CtKeySTATUS},
|
||||
&expr.Bitwise{
|
||||
SourceRegister: 1,
|
||||
DestRegister: 1,
|
||||
Len: 4,
|
||||
Mask: binaryutil.NativeEndian.PutUint32(0x20), // IPS_DST_NAT
|
||||
Xor: binaryutil.NativeEndian.PutUint32(0),
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpNeq,
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(0),
|
||||
},
|
||||
// Accept both TCP and UDP redirected to the proxy port.
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseTransportHeader,
|
||||
Offset: 2,
|
||||
Len: 2,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: binaryutil.BigEndian.PutUint16(redirectPort),
|
||||
},
|
||||
&expr.Verdict{Kind: expr.VerdictAccept},
|
||||
},
|
||||
UserData: []byte(inputAcceptKey),
|
||||
})
|
||||
}
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("flush tproxy rules for %s: %w", ruleID, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveTProxyRule removes TPROXY redirect rules by ID (both TCP and UDP variants).
|
||||
func (r *router) RemoveTProxyRule(ruleID string) error {
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
return fmt.Errorf(refreshRulesMapError, err)
|
||||
}
|
||||
|
||||
var removed int
|
||||
for _, suffix := range []string{"tcp", "udp", "input"} {
|
||||
ruleKey := fmt.Sprintf("tproxy-%s-%s", ruleID, suffix)
|
||||
|
||||
rule, ok := r.rules[ruleKey]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if rule.Handle == 0 {
|
||||
delete(r.rules, ruleKey)
|
||||
continue
|
||||
}
|
||||
|
||||
if err := r.decrementSetCounter(rule); err != nil {
|
||||
log.Debugf("decrement set counter for %s: %v", ruleKey, err)
|
||||
}
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
log.Debugf("delete tproxy rule %s: %v", ruleKey, err)
|
||||
}
|
||||
delete(r.rules, ruleKey)
|
||||
removed++
|
||||
}
|
||||
|
||||
if removed > 0 {
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("flush tproxy rule removal for %s: %w", ruleID, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// buildRedirectExprs builds nftables expressions for a REDIRECT rule.
|
||||
// Matches WG interface ingress, source CIDRs, destination ports, then REDIRECTs to the proxy port.
|
||||
func (r *router) buildRedirectExprs(proto uint8, sources []netip.Prefix, dstPorts []uint16, redirectPort uint16) ([]expr.Any, error) {
|
||||
var exprs []expr.Any
|
||||
|
||||
exprs = append(exprs,
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: ifname(r.wgIface.Name())},
|
||||
)
|
||||
|
||||
exprs = append(exprs,
|
||||
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{proto}},
|
||||
)
|
||||
|
||||
// Source CIDRs use the named ipset shared with route rules.
|
||||
if len(sources) > 0 {
|
||||
srcSet := firewall.NewPrefixSet(sources)
|
||||
srcExprs, err := r.getIpSet(srcSet, sources, true)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get source ipset: %w", err)
|
||||
}
|
||||
exprs = append(exprs, srcExprs...)
|
||||
}
|
||||
|
||||
if len(dstPorts) == 1 {
|
||||
exprs = append(exprs,
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseTransportHeader,
|
||||
Offset: 2,
|
||||
Len: 2,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: binaryutil.BigEndian.PutUint16(dstPorts[0]),
|
||||
},
|
||||
)
|
||||
} else if len(dstPorts) > 1 {
|
||||
setElements := make([]nftables.SetElement, len(dstPorts))
|
||||
for i, p := range dstPorts {
|
||||
setElements[i] = nftables.SetElement{Key: binaryutil.BigEndian.PutUint16(p)}
|
||||
}
|
||||
portSet := &nftables.Set{
|
||||
Table: r.workTable,
|
||||
Anonymous: true,
|
||||
Constant: true,
|
||||
KeyType: nftables.TypeInetService,
|
||||
}
|
||||
if err := r.conn.AddSet(portSet, setElements); err != nil {
|
||||
return nil, fmt.Errorf("create port set: %w", err)
|
||||
}
|
||||
exprs = append(exprs,
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseTransportHeader,
|
||||
Offset: 2,
|
||||
Len: 2,
|
||||
},
|
||||
&expr.Lookup{
|
||||
SourceRegister: 1,
|
||||
SetName: portSet.Name,
|
||||
SetID: portSet.ID,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// REDIRECT to local proxy port. Changes the destination to the interface's
|
||||
// primary address + specified port. Conntrack tracks the original destination,
|
||||
// readable via SO_ORIGINAL_DST.
|
||||
exprs = append(exprs,
|
||||
&expr.Immediate{Register: 1, Data: binaryutil.BigEndian.PutUint16(redirectPort)},
|
||||
&expr.Redir{
|
||||
RegisterProtoMin: 1,
|
||||
},
|
||||
)
|
||||
|
||||
return exprs, nil
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -8,10 +8,9 @@ import (
|
||||
)
|
||||
|
||||
type InterfaceState struct {
|
||||
NameStr string `json:"name"`
|
||||
WGAddress wgaddr.Address `json:"wg_address"`
|
||||
UserspaceBind bool `json:"userspace_bind"`
|
||||
MTU uint16 `json:"mtu"`
|
||||
NameStr string `json:"name"`
|
||||
WGAddress wgaddr.Address `json:"wg_address"`
|
||||
MTU uint16 `json:"mtu"`
|
||||
}
|
||||
|
||||
func (i *InterfaceState) Name() string {
|
||||
@@ -22,10 +21,6 @@ func (i *InterfaceState) Address() wgaddr.Address {
|
||||
return i.WGAddress
|
||||
}
|
||||
|
||||
func (i *InterfaceState) IsUserspaceBind() bool {
|
||||
return i.UserspaceBind
|
||||
}
|
||||
|
||||
type ShutdownState struct {
|
||||
InterfaceState *InterfaceState `json:"interface_state,omitempty"`
|
||||
}
|
||||
|
||||
@@ -140,6 +140,17 @@ type Manager struct {
|
||||
mtu uint16
|
||||
mssClampValue uint16
|
||||
mssClampEnabled bool
|
||||
|
||||
// Only one hook per protocol is supported. Outbound direction only.
|
||||
udpHookOut atomic.Pointer[packetHook]
|
||||
tcpHookOut atomic.Pointer[packetHook]
|
||||
}
|
||||
|
||||
// packetHook stores a registered hook for a specific IP:port.
|
||||
type packetHook struct {
|
||||
ip netip.Addr
|
||||
port uint16
|
||||
fn func([]byte) bool
|
||||
}
|
||||
|
||||
// decoder for packages
|
||||
@@ -594,6 +605,8 @@ func (m *Manager) resetState() {
|
||||
maps.Clear(m.incomingRules)
|
||||
maps.Clear(m.routeRulesMap)
|
||||
m.routeRules = m.routeRules[:0]
|
||||
m.udpHookOut.Store(nil)
|
||||
m.tcpHookOut.Store(nil)
|
||||
|
||||
if m.udpTracker != nil {
|
||||
m.udpTracker.Close()
|
||||
@@ -628,6 +641,45 @@ func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
|
||||
return m.nativeFirewall.SetupEBPFProxyNoTrack(proxyPort, wgPort)
|
||||
}
|
||||
|
||||
// AddTProxyRule delegates to the native firewall for TPROXY rules.
|
||||
// In userspace mode (no native firewall), this is a no-op since the
|
||||
// forwarder intercepts traffic directly.
|
||||
func (m *Manager) AddTProxyRule(ruleID string, sources []netip.Prefix, dstPorts []uint16, redirectPort uint16) error {
|
||||
if m.nativeFirewall == nil {
|
||||
return nil
|
||||
}
|
||||
return m.nativeFirewall.AddTProxyRule(ruleID, sources, dstPorts, redirectPort)
|
||||
}
|
||||
|
||||
// AddUDPInspectionHook registers a hook for QUIC/UDP inspection via the packet filter.
|
||||
func (m *Manager) AddUDPInspectionHook(dstPort uint16, hook func(packet []byte) bool) string {
|
||||
m.SetUDPPacketHook(netip.Addr{}, dstPort, hook)
|
||||
return "udp-inspection"
|
||||
}
|
||||
|
||||
// RemoveUDPInspectionHook removes a previously registered inspection hook.
|
||||
func (m *Manager) RemoveUDPInspectionHook(_ string) {
|
||||
m.SetUDPPacketHook(netip.Addr{}, 0, nil)
|
||||
}
|
||||
|
||||
// RemoveTProxyRule delegates to the native firewall for TPROXY rules.
|
||||
func (m *Manager) RemoveTProxyRule(ruleID string) error {
|
||||
if m.nativeFirewall == nil {
|
||||
return nil
|
||||
}
|
||||
return m.nativeFirewall.RemoveTProxyRule(ruleID)
|
||||
}
|
||||
|
||||
// IsLocalIP reports whether the given IP belongs to the local machine.
|
||||
func (m *Manager) IsLocalIP(ip netip.Addr) bool {
|
||||
return m.localipmanager.IsLocalIP(ip)
|
||||
}
|
||||
|
||||
// GetForwarder returns the userspace packet forwarder, or nil if not initialized.
|
||||
func (m *Manager) GetForwarder() *forwarder.Forwarder {
|
||||
return m.forwarder.Load()
|
||||
}
|
||||
|
||||
// UpdateSet updates the rule destinations associated with the given set
|
||||
// by merging the existing prefixes with the new ones, then deduplicating.
|
||||
func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||
@@ -713,6 +765,9 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool {
|
||||
return true
|
||||
}
|
||||
case layers.LayerTypeTCP:
|
||||
if m.tcpHooksDrop(uint16(d.tcp.DstPort), dstIP, packetData) {
|
||||
return true
|
||||
}
|
||||
// Clamp MSS on all TCP SYN packets, including those from local IPs.
|
||||
// SNATed routed traffic may appear as local IP but still requires clamping.
|
||||
if m.mssClampEnabled {
|
||||
@@ -895,38 +950,21 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byt
|
||||
d.dnatOrigPort = 0
|
||||
}
|
||||
|
||||
// udpHooksDrop checks if any UDP hooks should drop the packet
|
||||
func (m *Manager) udpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte) bool {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
return hookMatches(m.udpHookOut.Load(), dstIP, dport, packetData)
|
||||
}
|
||||
|
||||
// Check specific destination IP first
|
||||
if rules, exists := m.outgoingRules[dstIP]; exists {
|
||||
for _, rule := range rules {
|
||||
if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
|
||||
return rule.udpHook(packetData)
|
||||
}
|
||||
}
|
||||
func (m *Manager) tcpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte) bool {
|
||||
return hookMatches(m.tcpHookOut.Load(), dstIP, dport, packetData)
|
||||
}
|
||||
|
||||
func hookMatches(h *packetHook, dstIP netip.Addr, dport uint16, packetData []byte) bool {
|
||||
if h == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check IPv4 unspecified address
|
||||
if rules, exists := m.outgoingRules[netip.IPv4Unspecified()]; exists {
|
||||
for _, rule := range rules {
|
||||
if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
|
||||
return rule.udpHook(packetData)
|
||||
}
|
||||
}
|
||||
if h.ip == dstIP && h.port == dport {
|
||||
return h.fn(packetData)
|
||||
}
|
||||
|
||||
// Check IPv6 unspecified address
|
||||
if rules, exists := m.outgoingRules[netip.IPv6Unspecified()]; exists {
|
||||
for _, rule := range rules {
|
||||
if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
|
||||
return rule.udpHook(packetData)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -1278,12 +1316,6 @@ func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d
|
||||
return rule.mgmtId, rule.drop, true
|
||||
}
|
||||
case layers.LayerTypeUDP:
|
||||
// if rule has UDP hook (and if we are here we match this rule)
|
||||
// we ignore rule.drop and call this hook
|
||||
if rule.udpHook != nil {
|
||||
return rule.mgmtId, rule.udpHook(packetData), true
|
||||
}
|
||||
|
||||
if portsMatch(rule.sPort, uint16(d.udp.SrcPort)) && portsMatch(rule.dPort, uint16(d.udp.DstPort)) {
|
||||
return rule.mgmtId, rule.drop, true
|
||||
}
|
||||
@@ -1342,65 +1374,30 @@ func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, prot
|
||||
return sourceMatched
|
||||
}
|
||||
|
||||
// AddUDPPacketHook calls hook when UDP packet from given direction matched
|
||||
//
|
||||
// Hook function returns flag which indicates should be the matched package dropped or not
|
||||
func (m *Manager) AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook func(packet []byte) bool) string {
|
||||
r := PeerRule{
|
||||
id: uuid.New().String(),
|
||||
ip: ip,
|
||||
protoLayer: layers.LayerTypeUDP,
|
||||
dPort: &firewall.Port{Values: []uint16{dPort}},
|
||||
ipLayer: layers.LayerTypeIPv6,
|
||||
udpHook: hook,
|
||||
// SetUDPPacketHook sets the outbound UDP packet hook. Pass nil hook to remove.
|
||||
func (m *Manager) SetUDPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool) {
|
||||
if hook == nil {
|
||||
m.udpHookOut.Store(nil)
|
||||
return
|
||||
}
|
||||
|
||||
if ip.Is4() {
|
||||
r.ipLayer = layers.LayerTypeIPv4
|
||||
}
|
||||
|
||||
m.mutex.Lock()
|
||||
if in {
|
||||
// Incoming UDP hooks are stored in allow rules map
|
||||
if _, ok := m.incomingRules[r.ip]; !ok {
|
||||
m.incomingRules[r.ip] = make(map[string]PeerRule)
|
||||
}
|
||||
m.incomingRules[r.ip][r.id] = r
|
||||
} else {
|
||||
if _, ok := m.outgoingRules[r.ip]; !ok {
|
||||
m.outgoingRules[r.ip] = make(map[string]PeerRule)
|
||||
}
|
||||
m.outgoingRules[r.ip][r.id] = r
|
||||
}
|
||||
m.mutex.Unlock()
|
||||
|
||||
return r.id
|
||||
m.udpHookOut.Store(&packetHook{
|
||||
ip: ip,
|
||||
port: dPort,
|
||||
fn: hook,
|
||||
})
|
||||
}
|
||||
|
||||
// RemovePacketHook removes packet hook by given ID
|
||||
func (m *Manager) RemovePacketHook(hookID string) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
// Check incoming hooks (stored in allow rules)
|
||||
for _, arr := range m.incomingRules {
|
||||
for _, r := range arr {
|
||||
if r.id == hookID {
|
||||
delete(arr, r.id)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
// SetTCPPacketHook sets the outbound TCP packet hook. Pass nil hook to remove.
|
||||
func (m *Manager) SetTCPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool) {
|
||||
if hook == nil {
|
||||
m.tcpHookOut.Store(nil)
|
||||
return
|
||||
}
|
||||
// Check outgoing hooks
|
||||
for _, arr := range m.outgoingRules {
|
||||
for _, r := range arr {
|
||||
if r.id == hookID {
|
||||
delete(arr, r.id)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("hook with given id not found")
|
||||
m.tcpHookOut.Store(&packetHook{
|
||||
ip: ip,
|
||||
port: dPort,
|
||||
fn: hook,
|
||||
})
|
||||
}
|
||||
|
||||
// SetLogLevel sets the log level for the firewall manager
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
|
||||
@@ -186,81 +187,52 @@ func TestManagerDeleteRule(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddUDPPacketHook(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
in bool
|
||||
expDir fw.RuleDirection
|
||||
ip netip.Addr
|
||||
dPort uint16
|
||||
hook func([]byte) bool
|
||||
expectedID string
|
||||
}{
|
||||
{
|
||||
name: "Test Outgoing UDP Packet Hook",
|
||||
in: false,
|
||||
expDir: fw.RuleDirectionOUT,
|
||||
ip: netip.MustParseAddr("10.168.0.1"),
|
||||
dPort: 8000,
|
||||
hook: func([]byte) bool { return true },
|
||||
},
|
||||
{
|
||||
name: "Test Incoming UDP Packet Hook",
|
||||
in: true,
|
||||
expDir: fw.RuleDirectionIN,
|
||||
ip: netip.MustParseAddr("::1"),
|
||||
dPort: 9000,
|
||||
hook: func([]byte) bool { return false },
|
||||
},
|
||||
}
|
||||
func TestSetUDPPacketHook(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { require.NoError(t, manager.Close(nil)) })
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
var called bool
|
||||
manager.SetUDPPacketHook(netip.MustParseAddr("10.168.0.1"), 8000, func([]byte) bool {
|
||||
called = true
|
||||
return true
|
||||
})
|
||||
|
||||
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
|
||||
h := manager.udpHookOut.Load()
|
||||
require.NotNil(t, h)
|
||||
assert.Equal(t, netip.MustParseAddr("10.168.0.1"), h.ip)
|
||||
assert.Equal(t, uint16(8000), h.port)
|
||||
assert.True(t, h.fn(nil))
|
||||
assert.True(t, called)
|
||||
|
||||
var addedRule PeerRule
|
||||
if tt.in {
|
||||
// Incoming UDP hooks are stored in allow rules map
|
||||
if len(manager.incomingRules[tt.ip]) != 1 {
|
||||
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules[tt.ip]))
|
||||
return
|
||||
}
|
||||
for _, rule := range manager.incomingRules[tt.ip] {
|
||||
addedRule = rule
|
||||
}
|
||||
} else {
|
||||
if len(manager.outgoingRules[tt.ip]) != 1 {
|
||||
t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules[tt.ip]))
|
||||
return
|
||||
}
|
||||
for _, rule := range manager.outgoingRules[tt.ip] {
|
||||
addedRule = rule
|
||||
}
|
||||
}
|
||||
manager.SetUDPPacketHook(netip.MustParseAddr("10.168.0.1"), 8000, nil)
|
||||
assert.Nil(t, manager.udpHookOut.Load())
|
||||
}
|
||||
|
||||
if tt.ip.Compare(addedRule.ip) != 0 {
|
||||
t.Errorf("expected ip %s, got %s", tt.ip, addedRule.ip)
|
||||
return
|
||||
}
|
||||
if tt.dPort != addedRule.dPort.Values[0] {
|
||||
t.Errorf("expected dPort %d, got %d", tt.dPort, addedRule.dPort.Values[0])
|
||||
return
|
||||
}
|
||||
if layers.LayerTypeUDP != addedRule.protoLayer {
|
||||
t.Errorf("expected protoLayer %s, got %s", layers.LayerTypeUDP, addedRule.protoLayer)
|
||||
return
|
||||
}
|
||||
if addedRule.udpHook == nil {
|
||||
t.Errorf("expected udpHook to be set")
|
||||
return
|
||||
}
|
||||
})
|
||||
}
|
||||
func TestSetTCPPacketHook(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { require.NoError(t, manager.Close(nil)) })
|
||||
|
||||
var called bool
|
||||
manager.SetTCPPacketHook(netip.MustParseAddr("10.168.0.1"), 53, func([]byte) bool {
|
||||
called = true
|
||||
return true
|
||||
})
|
||||
|
||||
h := manager.tcpHookOut.Load()
|
||||
require.NotNil(t, h)
|
||||
assert.Equal(t, netip.MustParseAddr("10.168.0.1"), h.ip)
|
||||
assert.Equal(t, uint16(53), h.port)
|
||||
assert.True(t, h.fn(nil))
|
||||
assert.True(t, called)
|
||||
|
||||
manager.SetTCPPacketHook(netip.MustParseAddr("10.168.0.1"), 53, nil)
|
||||
assert.Nil(t, manager.tcpHookOut.Load())
|
||||
}
|
||||
|
||||
// TestPeerRuleLifecycleDenyRules verifies that deny rules are correctly added
|
||||
@@ -530,39 +502,12 @@ func TestRemovePacketHook(t *testing.T) {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
}()
|
||||
|
||||
// Add a UDP packet hook
|
||||
hookFunc := func(data []byte) bool { return true }
|
||||
hookID := manager.AddUDPPacketHook(false, netip.MustParseAddr("192.168.0.1"), 8080, hookFunc)
|
||||
manager.SetUDPPacketHook(netip.MustParseAddr("192.168.0.1"), 8080, func([]byte) bool { return true })
|
||||
|
||||
// Assert the hook is added by finding it in the manager's outgoing rules
|
||||
found := false
|
||||
for _, arr := range manager.outgoingRules {
|
||||
for _, rule := range arr {
|
||||
if rule.id == hookID {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
require.NotNil(t, manager.udpHookOut.Load(), "hook should be registered")
|
||||
|
||||
if !found {
|
||||
t.Fatalf("The hook was not added properly.")
|
||||
}
|
||||
|
||||
// Now remove the packet hook
|
||||
err = manager.RemovePacketHook(hookID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to remove hook: %s", err)
|
||||
}
|
||||
|
||||
// Assert the hook is removed by checking it in the manager's outgoing rules
|
||||
for _, arr := range manager.outgoingRules {
|
||||
for _, rule := range arr {
|
||||
if rule.id == hookID {
|
||||
t.Fatalf("The hook was not removed properly.")
|
||||
}
|
||||
}
|
||||
}
|
||||
manager.SetUDPPacketHook(netip.MustParseAddr("192.168.0.1"), 8080, nil)
|
||||
assert.Nil(t, manager.udpHookOut.Load(), "hook should be removed")
|
||||
}
|
||||
|
||||
func TestProcessOutgoingHooks(t *testing.T) {
|
||||
@@ -592,8 +537,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
||||
}
|
||||
|
||||
hookCalled := false
|
||||
hookID := manager.AddUDPPacketHook(
|
||||
false,
|
||||
manager.SetUDPPacketHook(
|
||||
netip.MustParseAddr("100.10.0.100"),
|
||||
53,
|
||||
func([]byte) bool {
|
||||
@@ -601,7 +545,6 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
||||
return true
|
||||
},
|
||||
)
|
||||
require.NotEmpty(t, hookID)
|
||||
|
||||
// Create test UDP packet
|
||||
ipv4 := &layers.IPv4{
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -21,6 +22,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||
"github.com/netbirdio/netbird/client/inspect"
|
||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
)
|
||||
@@ -46,6 +48,10 @@ type Forwarder struct {
|
||||
netstack bool
|
||||
hasRawICMPAccess bool
|
||||
pingSemaphore chan struct{}
|
||||
// proxy is the optional inspection engine.
|
||||
// When set, TCP connections are handed to the engine for protocol detection
|
||||
// and rule evaluation. Swapped atomically for lock-free hot-path access.
|
||||
proxy atomic.Pointer[inspect.Proxy]
|
||||
}
|
||||
|
||||
func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool, mtu uint16) (*Forwarder, error) {
|
||||
@@ -79,7 +85,7 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
||||
}
|
||||
|
||||
if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil {
|
||||
return nil, fmt.Errorf("failed to add protocol address: %s", err)
|
||||
return nil, fmt.Errorf("add protocol address: %s", err)
|
||||
}
|
||||
|
||||
defaultSubnet, err := tcpip.NewSubnet(
|
||||
@@ -155,6 +161,13 @@ func (f *Forwarder) InjectIncomingPacket(payload []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetProxy sets the inspection engine. When set, TCP connections are handed
|
||||
// to it for protocol detection and rule evaluation instead of direct relay.
|
||||
// Pass nil to disable inspection.
|
||||
func (f *Forwarder) SetProxy(p *inspect.Proxy) {
|
||||
f.proxy.Store(p)
|
||||
}
|
||||
|
||||
// Stop gracefully shuts down the forwarder
|
||||
func (f *Forwarder) Stop() {
|
||||
f.cancel()
|
||||
@@ -167,6 +180,25 @@ func (f *Forwarder) Stop() {
|
||||
f.stack.Wait()
|
||||
}
|
||||
|
||||
// CheckUDPPacket inspects a UDP payload against proxy rules before injection.
|
||||
// This is called by the filter for QUIC SNI-based blocking.
|
||||
// Returns true if the packet should be allowed, false if it should be dropped.
|
||||
func (f *Forwarder) CheckUDPPacket(payload []byte, srcIP, dstIP netip.Addr, srcPort, dstPort uint16, ruleID []byte) bool {
|
||||
p := f.proxy.Load()
|
||||
if p == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
dst := netip.AddrPortFrom(dstIP, dstPort)
|
||||
src := inspect.SourceInfo{
|
||||
IP: srcIP,
|
||||
PolicyID: inspect.PolicyID(ruleID),
|
||||
}
|
||||
|
||||
action := p.HandleUDPPacket(payload, dst, src)
|
||||
return action != inspect.ActionBlock
|
||||
}
|
||||
|
||||
func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP {
|
||||
if f.netstack && f.ip.Equal(addr) {
|
||||
return net.IPv4(127, 0, 0, 1)
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
||||
"gvisor.dev/gvisor/pkg/waiter"
|
||||
|
||||
"github.com/netbirdio/netbird/client/inspect"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
)
|
||||
|
||||
@@ -23,6 +24,86 @@ import (
|
||||
func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
|
||||
id := r.ID()
|
||||
|
||||
// If the inspection engine is configured, accept the connection first and hand it off.
|
||||
if p := f.proxy.Load(); p != nil {
|
||||
f.handleTCPWithInspection(r, id, p)
|
||||
return
|
||||
}
|
||||
|
||||
f.handleTCPDirect(r, id)
|
||||
}
|
||||
|
||||
// handleTCPWithInspection accepts the connection and hands it to the inspection
|
||||
// engine. For allow decisions, the forwarder does its own relay (passthrough).
|
||||
// For block/inspect, the engine handles everything internally.
|
||||
func (f *Forwarder) handleTCPWithInspection(r *tcp.ForwarderRequest, id stack.TransportEndpointID, p *inspect.Proxy) {
|
||||
flowID := uuid.New()
|
||||
f.sendTCPEvent(nftypes.TypeStart, flowID, id, 0, 0, 0, 0)
|
||||
|
||||
wq := waiter.Queue{}
|
||||
ep, epErr := r.CreateEndpoint(&wq)
|
||||
if epErr != nil {
|
||||
f.logger.Error1("forwarder: create TCP endpoint for inspection: %v", epErr)
|
||||
r.Complete(true)
|
||||
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, 0, 0, 0, 0)
|
||||
return
|
||||
}
|
||||
r.Complete(false)
|
||||
|
||||
inConn := gonet.NewTCPConn(&wq, ep)
|
||||
|
||||
srcIP := netip.AddrFrom4(id.RemoteAddress.As4())
|
||||
dstIP := netip.AddrFrom4(id.LocalAddress.As4())
|
||||
dst := netip.AddrPortFrom(dstIP, id.LocalPort)
|
||||
|
||||
var policyID []byte
|
||||
if ruleID, ok := f.getRuleID(srcIP, dstIP, id.RemotePort, id.LocalPort); ok {
|
||||
policyID = ruleID
|
||||
}
|
||||
|
||||
src := inspect.SourceInfo{
|
||||
IP: srcIP,
|
||||
PolicyID: inspect.PolicyID(policyID),
|
||||
}
|
||||
|
||||
f.logger.Trace1("forwarder: handing TCP %v to inspection engine", epID(id))
|
||||
|
||||
go func() {
|
||||
result, err := p.InspectTCP(f.ctx, inConn, dst, src)
|
||||
if err != nil && err != inspect.ErrBlocked {
|
||||
f.logger.Debug2("forwarder: inspection error for %v: %v", epID(id), err)
|
||||
}
|
||||
|
||||
// Passthrough: engine returned allow, forwarder does the relay.
|
||||
if result.PassthroughConn != nil {
|
||||
dialAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
|
||||
outConn, dialErr := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr)
|
||||
if dialErr != nil {
|
||||
f.logger.Trace2("forwarder: passthrough dial error for %v: %v", epID(id), dialErr)
|
||||
if closeErr := result.PassthroughConn.Close(); closeErr != nil {
|
||||
f.logger.Debug1("forwarder: close passthrough conn: %v", closeErr)
|
||||
}
|
||||
ep.Close()
|
||||
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, 0, 0, 0, 0)
|
||||
return
|
||||
}
|
||||
f.proxyTCPPassthrough(id, result.PassthroughConn, outConn, ep, flowID)
|
||||
return
|
||||
}
|
||||
|
||||
// Engine handled it (block/inspect/HTTP). Capture stats and clean up.
|
||||
var rxPackets, txPackets uint64
|
||||
if tcpStats, ok := ep.Stats().(*tcp.Stats); ok {
|
||||
rxPackets = tcpStats.SegmentsSent.Value()
|
||||
txPackets = tcpStats.SegmentsReceived.Value()
|
||||
}
|
||||
ep.Close()
|
||||
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, 0, 0, rxPackets, txPackets)
|
||||
}()
|
||||
}
|
||||
|
||||
// handleTCPDirect handles TCP connections with direct relay (no proxy).
|
||||
func (f *Forwarder) handleTCPDirect(r *tcp.ForwarderRequest, id stack.TransportEndpointID) {
|
||||
flowID := uuid.New()
|
||||
|
||||
f.sendTCPEvent(nftypes.TypeStart, flowID, id, 0, 0, 0, 0)
|
||||
@@ -42,7 +123,6 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
|
||||
return
|
||||
}
|
||||
|
||||
// Create wait queue for blocking syscalls
|
||||
wq := waiter.Queue{}
|
||||
|
||||
ep, epErr := r.CreateEndpoint(&wq)
|
||||
@@ -55,7 +135,6 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
|
||||
return
|
||||
}
|
||||
|
||||
// Complete the handshake
|
||||
r.Complete(false)
|
||||
|
||||
inConn := gonet.NewTCPConn(&wq, ep)
|
||||
@@ -73,7 +152,6 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
// Close connections and endpoint.
|
||||
if err := inConn.Close(); err != nil && !isClosedError(err) {
|
||||
f.logger.Debug1("forwarder: inConn close error: %v", err)
|
||||
}
|
||||
@@ -132,6 +210,66 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
|
||||
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, uint64(bytesFromOutToIn), uint64(bytesFromInToOut), rxPackets, txPackets)
|
||||
}
|
||||
|
||||
// proxyTCPPassthrough relays traffic between a peeked inbound connection
|
||||
// (from the inspection engine passthrough) and the outbound connection.
|
||||
// It accepts net.Conn for inConn since the inspection engine wraps it in a peekConn.
|
||||
func (f *Forwarder) proxyTCPPassthrough(id stack.TransportEndpointID, inConn net.Conn, outConn net.Conn, ep tcpip.Endpoint, flowID uuid.UUID) {
|
||||
ctx, cancel := context.WithCancel(f.ctx)
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
if err := inConn.Close(); err != nil && !isClosedError(err) {
|
||||
f.logger.Debug1("forwarder: passthrough inConn close: %v", err)
|
||||
}
|
||||
if err := outConn.Close(); err != nil && !isClosedError(err) {
|
||||
f.logger.Debug1("forwarder: passthrough outConn close: %v", err)
|
||||
}
|
||||
ep.Close()
|
||||
}()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
var (
|
||||
bytesIn int64
|
||||
bytesOut int64
|
||||
errIn error
|
||||
errOut error
|
||||
)
|
||||
|
||||
go func() {
|
||||
bytesIn, errIn = io.Copy(outConn, inConn)
|
||||
cancel()
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
go func() {
|
||||
bytesOut, errOut = io.Copy(inConn, outConn)
|
||||
cancel()
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if errIn != nil && !isClosedError(errIn) {
|
||||
f.logger.Error2("proxyTCPPassthrough: copy error (in→out) for %s: %v", epID(id), errIn)
|
||||
}
|
||||
if errOut != nil && !isClosedError(errOut) {
|
||||
f.logger.Error2("proxyTCPPassthrough: copy error (out→in) for %s: %v", epID(id), errOut)
|
||||
}
|
||||
|
||||
var rxPackets, txPackets uint64
|
||||
if tcpStats, ok := ep.Stats().(*tcp.Stats); ok {
|
||||
rxPackets = tcpStats.SegmentsSent.Value()
|
||||
txPackets = tcpStats.SegmentsReceived.Value()
|
||||
}
|
||||
|
||||
f.logger.Trace5("forwarder: passthrough TCP %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, bytesOut, txPackets, bytesIn)
|
||||
|
||||
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, uint64(bytesOut), uint64(bytesIn), rxPackets, txPackets)
|
||||
}
|
||||
|
||||
func (f *Forwarder) sendTCPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, rxBytes, txBytes, rxPackets, txPackets uint64) {
|
||||
srcIp := netip.AddrFrom4(id.RemoteAddress.As4())
|
||||
dstIp := netip.AddrFrom4(id.LocalAddress.As4())
|
||||
|
||||
@@ -144,6 +144,8 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
|
||||
if err != nil {
|
||||
log.Warnf("failed to get interfaces: %v", err)
|
||||
} else {
|
||||
// TODO: filter out down interfaces (net.FlagUp). Also handle the reverse
|
||||
// case where an interface comes up between refreshes.
|
||||
for _, intf := range interfaces {
|
||||
m.processInterface(intf, &newIPv4Bitmap, ipv4Set, &ipv4Addresses)
|
||||
}
|
||||
|
||||
@@ -421,6 +421,7 @@ func (m *Manager) addPortRedirection(targetIP netip.Addr, protocol gopacket.Laye
|
||||
}
|
||||
|
||||
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
|
||||
// TODO: also delegate to nativeFirewall when available for kernel WG mode
|
||||
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
var layerType gopacket.LayerType
|
||||
switch protocol {
|
||||
@@ -466,6 +467,22 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
|
||||
return m.removePortRedirection(localAddr, layerType, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
// AddOutputDNAT delegates to the native firewall if available.
|
||||
func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
if m.nativeFirewall == nil {
|
||||
return fmt.Errorf("output DNAT not supported without native firewall")
|
||||
}
|
||||
return m.nativeFirewall.AddOutputDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
// RemoveOutputDNAT delegates to the native firewall if available.
|
||||
func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
if m.nativeFirewall == nil {
|
||||
return nil
|
||||
}
|
||||
return m.nativeFirewall.RemoveOutputDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
// translateInboundPortDNAT applies port-specific DNAT translation to inbound packets.
|
||||
func (m *Manager) translateInboundPortDNAT(packetData []byte, d *decoder, srcIP, dstIP netip.Addr) bool {
|
||||
if !m.portDNATEnabled.Load() {
|
||||
|
||||
@@ -18,9 +18,7 @@ type PeerRule struct {
|
||||
protoLayer gopacket.LayerType
|
||||
sPort *firewall.Port
|
||||
dPort *firewall.Port
|
||||
drop bool
|
||||
|
||||
udpHook func([]byte) bool
|
||||
drop bool
|
||||
}
|
||||
|
||||
// ID returns the rule id
|
||||
|
||||
@@ -399,21 +399,17 @@ func TestTracePacket(t *testing.T) {
|
||||
{
|
||||
name: "UDPTraffic_WithHook",
|
||||
setup: func(m *Manager) {
|
||||
hookFunc := func([]byte) bool {
|
||||
return true
|
||||
}
|
||||
m.AddUDPPacketHook(true, netip.MustParseAddr("1.1.1.1"), 53, hookFunc)
|
||||
m.SetUDPPacketHook(netip.MustParseAddr("100.10.255.254"), 53, func([]byte) bool {
|
||||
return true // drop (intercepted by hook)
|
||||
})
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
return createPacketBuilder("1.1.1.1", "100.10.0.100", "udp", 12345, 53, fw.RuleDirectionIN)
|
||||
return createPacketBuilder("100.10.0.100", "100.10.255.254", "udp", 12345, 53, fw.RuleDirectionOUT)
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageInboundPortDNAT,
|
||||
StageInbound1to1NAT,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StagePeerACL,
|
||||
StageOutbound1to1NAT,
|
||||
StageOutboundPortReverse,
|
||||
StageCompleted,
|
||||
},
|
||||
expectedAllow: false,
|
||||
|
||||
@@ -15,14 +15,17 @@ type PacketFilter interface {
|
||||
// FilterInbound filter incoming packets from external sources to host
|
||||
FilterInbound(packetData []byte, size int) bool
|
||||
|
||||
// AddUDPPacketHook calls hook when UDP packet from given direction matched
|
||||
//
|
||||
// Hook function returns flag which indicates should be the matched package dropped or not.
|
||||
// Hook function receives raw network packet data as argument.
|
||||
AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook func(packet []byte) bool) string
|
||||
// SetUDPPacketHook registers a hook for outbound UDP packets matching the given IP and port.
|
||||
// Hook function returns true if the packet should be dropped.
|
||||
// Only one UDP hook is supported; calling again replaces the previous hook.
|
||||
// Pass nil hook to remove.
|
||||
SetUDPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool)
|
||||
|
||||
// RemovePacketHook removes hook by ID
|
||||
RemovePacketHook(hookID string) error
|
||||
// SetTCPPacketHook registers a hook for outbound TCP packets matching the given IP and port.
|
||||
// Hook function returns true if the packet should be dropped.
|
||||
// Only one TCP hook is supported; calling again replaces the previous hook.
|
||||
// Pass nil hook to remove.
|
||||
SetTCPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool)
|
||||
}
|
||||
|
||||
// FilteredDevice to override Read or Write of packets
|
||||
|
||||
@@ -34,18 +34,28 @@ func (m *MockPacketFilter) EXPECT() *MockPacketFilterMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// AddUDPPacketHook mocks base method.
|
||||
func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 netip.Addr, arg2 uint16, arg3 func([]byte) bool) string {
|
||||
// SetUDPPacketHook mocks base method.
|
||||
func (m *MockPacketFilter) SetUDPPacketHook(arg0 netip.Addr, arg1 uint16, arg2 func([]byte) bool) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AddUDPPacketHook", arg0, arg1, arg2, arg3)
|
||||
ret0, _ := ret[0].(string)
|
||||
return ret0
|
||||
m.ctrl.Call(m, "SetUDPPacketHook", arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// AddUDPPacketHook indicates an expected call of AddUDPPacketHook.
|
||||
func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
|
||||
// SetUDPPacketHook indicates an expected call of SetUDPPacketHook.
|
||||
func (mr *MockPacketFilterMockRecorder) SetUDPPacketHook(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).SetUDPPacketHook), arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// SetTCPPacketHook mocks base method.
|
||||
func (m *MockPacketFilter) SetTCPPacketHook(arg0 netip.Addr, arg1 uint16, arg2 func([]byte) bool) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "SetTCPPacketHook", arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// SetTCPPacketHook indicates an expected call of SetTCPPacketHook.
|
||||
func (mr *MockPacketFilterMockRecorder) SetTCPPacketHook(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetTCPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).SetTCPPacketHook), arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// FilterInbound mocks base method.
|
||||
@@ -75,17 +85,3 @@ func (mr *MockPacketFilterMockRecorder) FilterOutbound(arg0 interface{}, arg1 an
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterOutbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterOutbound), arg0, arg1)
|
||||
}
|
||||
|
||||
// RemovePacketHook mocks base method.
|
||||
func (m *MockPacketFilter) RemovePacketHook(arg0 string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "RemovePacketHook", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// RemovePacketHook indicates an expected call of RemovePacketHook.
|
||||
func (mr *MockPacketFilterMockRecorder) RemovePacketHook(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemovePacketHook", reflect.TypeOf((*MockPacketFilter)(nil).RemovePacketHook), arg0)
|
||||
}
|
||||
|
||||
@@ -1,87 +0,0 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/netbirdio/netbird/client/iface (interfaces: PacketFilter)
|
||||
|
||||
// Package mocks is a generated GoMock package.
|
||||
package mocks
|
||||
|
||||
import (
|
||||
net "net"
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
)
|
||||
|
||||
// MockPacketFilter is a mock of PacketFilter interface.
|
||||
type MockPacketFilter struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockPacketFilterMockRecorder
|
||||
}
|
||||
|
||||
// MockPacketFilterMockRecorder is the mock recorder for MockPacketFilter.
|
||||
type MockPacketFilterMockRecorder struct {
|
||||
mock *MockPacketFilter
|
||||
}
|
||||
|
||||
// NewMockPacketFilter creates a new mock instance.
|
||||
func NewMockPacketFilter(ctrl *gomock.Controller) *MockPacketFilter {
|
||||
mock := &MockPacketFilter{ctrl: ctrl}
|
||||
mock.recorder = &MockPacketFilterMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockPacketFilter) EXPECT() *MockPacketFilterMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// AddUDPPacketHook mocks base method.
|
||||
func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 net.IP, arg2 uint16, arg3 func(*net.UDPAddr, []byte) bool) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "AddUDPPacketHook", arg0, arg1, arg2, arg3)
|
||||
}
|
||||
|
||||
// AddUDPPacketHook indicates an expected call of AddUDPPacketHook.
|
||||
func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3)
|
||||
}
|
||||
|
||||
// FilterInbound mocks base method.
|
||||
func (m *MockPacketFilter) FilterInbound(arg0 []byte) bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "FilterInbound", arg0)
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// FilterInbound indicates an expected call of FilterInbound.
|
||||
func (mr *MockPacketFilterMockRecorder) FilterInbound(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterInbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterInbound), arg0)
|
||||
}
|
||||
|
||||
// FilterOutbound mocks base method.
|
||||
func (m *MockPacketFilter) FilterOutbound(arg0 []byte) bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "FilterOutbound", arg0)
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// FilterOutbound indicates an expected call of FilterOutbound.
|
||||
func (mr *MockPacketFilterMockRecorder) FilterOutbound(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterOutbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterOutbound), arg0)
|
||||
}
|
||||
|
||||
// SetNetwork mocks base method.
|
||||
func (m *MockPacketFilter) SetNetwork(arg0 *net.IPNet) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "SetNetwork", arg0)
|
||||
}
|
||||
|
||||
// SetNetwork indicates an expected call of SetNetwork.
|
||||
func (mr *MockPacketFilterMockRecorder) SetNetwork(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetwork", reflect.TypeOf((*MockPacketFilter)(nil).SetNetwork), arg0)
|
||||
}
|
||||
212
client/inspect/config.go
Normal file
212
client/inspect/config.go
Normal file
@@ -0,0 +1,212 @@
|
||||
package inspect
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/x509"
|
||||
"net"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/acl/id"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
// InspectResult holds the outcome of connection inspection.
|
||||
type InspectResult struct {
|
||||
// Action is the rule evaluation result.
|
||||
Action Action
|
||||
// PassthroughConn is the client connection with buffered peeked bytes.
|
||||
// Non-nil only when Action is ActionAllow and the caller should relay
|
||||
// (TLS passthrough or non-HTTP/TLS protocol). The caller takes ownership
|
||||
// and is responsible for closing this connection.
|
||||
PassthroughConn net.Conn
|
||||
}
|
||||
|
||||
const (
|
||||
// DefaultTProxyPort is the default TPROXY listener port for kernel mode.
|
||||
// Override with NB_TPROXY_PORT environment variable.
|
||||
DefaultTProxyPort = 22080
|
||||
)
|
||||
|
||||
// Action determines how the proxy handles a matched connection.
|
||||
type Action string
|
||||
|
||||
const (
|
||||
// ActionAllow passes the connection through without decryption.
|
||||
ActionAllow Action = "allow"
|
||||
// ActionBlock denies the connection.
|
||||
ActionBlock Action = "block"
|
||||
// ActionInspect decrypts (MITM) and inspects the connection.
|
||||
ActionInspect Action = "inspect"
|
||||
)
|
||||
|
||||
// ProxyMode determines the proxy operating mode.
|
||||
type ProxyMode string
|
||||
|
||||
const (
|
||||
// ModeBuiltin uses the built-in proxy with rules and optional ICAP.
|
||||
ModeBuiltin ProxyMode = "builtin"
|
||||
// ModeEnvoy runs a local envoy sidecar for L7 processing.
|
||||
// Go manages envoy lifecycle, config generation, and rule evaluation.
|
||||
// USP path forwards via PROXY protocol v2; kernel path uses nftables redirect.
|
||||
ModeEnvoy ProxyMode = "envoy"
|
||||
// ModeExternal forwards all traffic to an external proxy.
|
||||
ModeExternal ProxyMode = "external"
|
||||
)
|
||||
|
||||
// PolicyID is the management policy identifier associated with a connection.
|
||||
type PolicyID []byte
|
||||
|
||||
// MatchDomain reports whether target matches the pattern.
|
||||
// If pattern starts with "*.", it matches any subdomain (but not the base itself).
|
||||
// Otherwise it requires an exact match.
|
||||
func MatchDomain(pattern, target domain.Domain) bool {
|
||||
p := pattern.PunycodeString()
|
||||
t := target.PunycodeString()
|
||||
|
||||
if strings.HasPrefix(p, "*.") {
|
||||
base := p[2:]
|
||||
return strings.HasSuffix(t, "."+base)
|
||||
}
|
||||
|
||||
return p == t
|
||||
}
|
||||
|
||||
// SourceInfo carries source identity context for rule evaluation.
|
||||
// The source may be a direct WireGuard peer or a host behind
|
||||
// a site-to-site gateway.
|
||||
type SourceInfo struct {
|
||||
// IP is the original source address from the packet.
|
||||
IP netip.Addr
|
||||
// PolicyID is the management policy that allowed this traffic
|
||||
// through route ACLs.
|
||||
PolicyID PolicyID
|
||||
}
|
||||
|
||||
// ProtoType identifies a protocol handled by the proxy.
|
||||
type ProtoType string
|
||||
|
||||
const (
|
||||
ProtoHTTP ProtoType = "http"
|
||||
ProtoHTTPS ProtoType = "https"
|
||||
ProtoH2 ProtoType = "h2"
|
||||
ProtoH3 ProtoType = "h3"
|
||||
ProtoWebSocket ProtoType = "websocket"
|
||||
ProtoOther ProtoType = "other"
|
||||
)
|
||||
|
||||
// Rule defines a proxy inspection/filtering rule.
|
||||
type Rule struct {
|
||||
// ID uniquely identifies this rule.
|
||||
ID id.RuleID
|
||||
// Sources are the source CIDRs this rule applies to.
|
||||
// Includes both direct peer IPs and routed networks behind gateways.
|
||||
Sources []netip.Prefix
|
||||
// Domains are the destination domain patterns to match (via SNI or Host header).
|
||||
// Supports exact match ("example.com") and wildcard ("*.example.com").
|
||||
Domains []domain.Domain
|
||||
// Networks are the destination CIDRs to match.
|
||||
Networks []netip.Prefix
|
||||
// Ports are the destination ports to match. Empty means all ports.
|
||||
Ports []uint16
|
||||
// Protocols restricts which protocols this rule applies to.
|
||||
// Empty means all protocols.
|
||||
Protocols []ProtoType
|
||||
// Paths are URL path patterns to match (HTTP only, requires inspect for HTTPS).
|
||||
// Supports prefix ("/api/"), exact ("/login"), and wildcard ("/admin/*").
|
||||
// Empty means all paths.
|
||||
Paths []string
|
||||
// Action determines what to do with matched connections.
|
||||
Action Action
|
||||
// Priority controls evaluation order. Lower values are evaluated first.
|
||||
Priority int
|
||||
}
|
||||
|
||||
// ICAPConfig holds ICAP service configuration.
|
||||
type ICAPConfig struct {
|
||||
// ReqModURL is the ICAP REQMOD service URL (e.g., icap://server:1344/reqmod).
|
||||
ReqModURL *url.URL
|
||||
// RespModURL is the ICAP RESPMOD service URL (e.g., icap://server:1344/respmod).
|
||||
RespModURL *url.URL
|
||||
// MaxConnections is the connection pool size. Zero uses a default.
|
||||
MaxConnections int
|
||||
}
|
||||
|
||||
// TLSConfig holds the MITM CA configuration for TLS inspection.
|
||||
type TLSConfig struct {
|
||||
// CA is the certificate authority used to sign dynamic certificates.
|
||||
CA *x509.Certificate
|
||||
// CAKey is the CA's private key.
|
||||
CAKey crypto.PrivateKey
|
||||
}
|
||||
|
||||
// Config holds the transparent proxy configuration.
|
||||
type Config struct {
|
||||
// Enabled controls whether the proxy is active.
|
||||
Enabled bool
|
||||
// Mode selects built-in or external proxy operation.
|
||||
Mode ProxyMode
|
||||
// ExternalURL is the upstream proxy URL for ModeExternal.
|
||||
// Supports http:// and socks5:// schemes.
|
||||
ExternalURL *url.URL
|
||||
|
||||
// DefaultAction applies when no rule matches a connection.
|
||||
DefaultAction Action
|
||||
|
||||
// RedirectSources are the source CIDRs whose traffic should be intercepted.
|
||||
// Admin decides: "activate for these users/subnets."
|
||||
// Used for both kernel TPROXY rules and userspace forwarder source filtering.
|
||||
RedirectSources []netip.Prefix
|
||||
// RedirectPorts are the destination ports to intercept. Empty means all ports.
|
||||
RedirectPorts []uint16
|
||||
|
||||
// Rules are the proxy inspection/filtering rules, evaluated in Priority order.
|
||||
Rules []Rule
|
||||
|
||||
// ICAP holds ICAP service configuration. Nil disables ICAP.
|
||||
ICAP *ICAPConfig
|
||||
// TLS holds the MITM CA. Nil means no MITM capability (ActionInspect rules ignored).
|
||||
TLS *TLSConfig
|
||||
|
||||
// Envoy configuration (ModeEnvoy only)
|
||||
Envoy *EnvoyConfig
|
||||
|
||||
// ListenAddr is the TPROXY listen address for kernel mode.
|
||||
// Zero value disables the TPROXY listener.
|
||||
ListenAddr netip.AddrPort
|
||||
// WGNetwork is the WireGuard overlay network prefix.
|
||||
// The proxy blocks dialing destinations inside this network.
|
||||
WGNetwork netip.Prefix
|
||||
// LocalIPChecker reports whether an IP belongs to the routing peer.
|
||||
// Used to prevent SSRF to local services. May be nil.
|
||||
LocalIPChecker LocalIPChecker
|
||||
}
|
||||
|
||||
// EnvoyConfig holds configuration for the envoy sidecar mode.
|
||||
type EnvoyConfig struct {
|
||||
// BinaryPath is the path to the envoy binary.
|
||||
// Empty means search $PATH for "envoy".
|
||||
BinaryPath string
|
||||
// AdminPort is the port for envoy's admin API (health checks, stats).
|
||||
// Zero means auto-assign.
|
||||
AdminPort uint16
|
||||
// Snippets are user-provided config fragments merged into the generated bootstrap.
|
||||
Snippets *EnvoySnippets
|
||||
}
|
||||
|
||||
// EnvoySnippets holds user-provided YAML fragments for envoy config customization.
|
||||
// Only safe snippet types are allowed: filters (HTTP and network) and clusters
|
||||
// needed as dependencies for filter services. Listeners and bootstrap overrides
|
||||
// are not exposed since we manage the listener and bootstrap.
|
||||
type EnvoySnippets struct {
|
||||
// HTTPFilters is YAML injected into the HCM filter chain before the router filter.
|
||||
// Used for ext_authz, rate limiting, Lua, Wasm, RBAC, JWT auth, etc.
|
||||
HTTPFilters string
|
||||
// NetworkFilters is YAML injected into the TLS filter chain before tcp_proxy.
|
||||
// Used for network-level RBAC, rate limiting, ext_authz on raw TCP.
|
||||
NetworkFilters string
|
||||
// Clusters is YAML for additional upstream clusters referenced by filters.
|
||||
// Needed when filters call external services (ext_authz backend, rate limit service).
|
||||
Clusters string
|
||||
}
|
||||
93
client/inspect/config_test.go
Normal file
93
client/inspect/config_test.go
Normal file
@@ -0,0 +1,93 @@
|
||||
package inspect
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
func TestMatchDomain(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
pattern string
|
||||
target string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "exact match",
|
||||
pattern: "example.com",
|
||||
target: "example.com",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "exact no match",
|
||||
pattern: "example.com",
|
||||
target: "other.com",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "wildcard matches subdomain",
|
||||
pattern: "*.example.com",
|
||||
target: "foo.example.com",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "wildcard matches deep subdomain",
|
||||
pattern: "*.example.com",
|
||||
target: "a.b.c.example.com",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "wildcard does not match base",
|
||||
pattern: "*.example.com",
|
||||
target: "example.com",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "wildcard does not match unrelated",
|
||||
pattern: "*.example.com",
|
||||
target: "foo.other.com",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "case insensitive exact match",
|
||||
pattern: "Example.COM",
|
||||
target: "example.com",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "case insensitive wildcard match",
|
||||
pattern: "*.Example.COM",
|
||||
target: "FOO.example.com",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "wildcard does not match partial suffix",
|
||||
pattern: "*.example.com",
|
||||
target: "notexample.com",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "unicode domain punycode match",
|
||||
pattern: "*.münchen.de",
|
||||
target: "sub.xn--mnchen-3ya.de",
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
pattern, err := domain.FromString(tt.pattern)
|
||||
require.NoError(t, err)
|
||||
|
||||
target, err := domain.FromString(tt.target)
|
||||
require.NoError(t, err)
|
||||
|
||||
got := MatchDomain(pattern, target)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
25
client/inspect/dialer_linux.go
Normal file
25
client/inspect/dialer_linux.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package inspect
|
||||
|
||||
import (
|
||||
"net"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// newOutboundDialer creates a net.Dialer that clears the socket fwmark.
|
||||
// In kernel TPROXY mode, accepted connections inherit the TPROXY fwmark.
|
||||
// Without clearing it, outbound connections from the proxy would match
|
||||
// the ip rule (fwmark -> local loopback) and loop back to the proxy
|
||||
// instead of reaching the real destination.
|
||||
func newOutboundDialer() net.Dialer {
|
||||
return net.Dialer{
|
||||
Control: func(_, _ string, c syscall.RawConn) error {
|
||||
var sockErr error
|
||||
if err := c.Control(func(fd uintptr) {
|
||||
sockErr = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_MARK, 0)
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
return sockErr
|
||||
},
|
||||
}
|
||||
}
|
||||
11
client/inspect/dialer_other.go
Normal file
11
client/inspect/dialer_other.go
Normal file
@@ -0,0 +1,11 @@
|
||||
//go:build !linux
|
||||
|
||||
package inspect
|
||||
|
||||
import "net"
|
||||
|
||||
// newOutboundDialer returns a plain dialer on non-Linux platforms.
|
||||
// TPROXY is Linux-only, so no fwmark clearing is needed.
|
||||
func newOutboundDialer() net.Dialer {
|
||||
return net.Dialer{}
|
||||
}
|
||||
298
client/inspect/envoy.go
Normal file
298
client/inspect/envoy.go
Normal file
@@ -0,0 +1,298 @@
|
||||
package inspect
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
envoyStartTimeout = 15 * time.Second
|
||||
envoyHealthInterval = 500 * time.Millisecond
|
||||
envoyStopTimeout = 10 * time.Second
|
||||
envoyDrainTime = 5
|
||||
)
|
||||
|
||||
// envoyManager manages the lifecycle of an envoy sidecar process.
|
||||
type envoyManager struct {
|
||||
log *log.Entry
|
||||
cmd *exec.Cmd
|
||||
configPath string
|
||||
listenPort uint16
|
||||
adminPort uint16
|
||||
cancel context.CancelFunc
|
||||
|
||||
blockPagePath string
|
||||
|
||||
mu sync.Mutex
|
||||
running bool
|
||||
}
|
||||
|
||||
// startEnvoy finds the envoy binary, generates config, and spawns the process.
|
||||
// It blocks until envoy reports healthy or the timeout expires.
|
||||
func startEnvoy(ctx context.Context, logger *log.Entry, config Config) (*envoyManager, error) {
|
||||
envCfg := config.Envoy
|
||||
if envCfg == nil {
|
||||
return nil, fmt.Errorf("envoy config is nil")
|
||||
}
|
||||
|
||||
binaryPath, err := findEnvoyBinary(envCfg.BinaryPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("find envoy binary: %w", err)
|
||||
}
|
||||
|
||||
// Pick admin port
|
||||
adminPort := envCfg.AdminPort
|
||||
if adminPort == 0 {
|
||||
p, err := findFreePort()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("find free admin port: %w", err)
|
||||
}
|
||||
adminPort = p
|
||||
}
|
||||
|
||||
// Pick listener port
|
||||
listenPort, err := findFreePort()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("find free listener port: %w", err)
|
||||
}
|
||||
|
||||
// Use a private temp directory (0700) to prevent local attackers from
|
||||
// replacing the config file between write and envoy read.
|
||||
configDir, err := os.MkdirTemp("", "nb-envoy-*")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create envoy config directory: %w", err)
|
||||
}
|
||||
|
||||
// Write the block page HTML for envoy's direct_response to reference.
|
||||
blockPagePath := filepath.Join(configDir, "block.html")
|
||||
blockHTML := fmt.Sprintf(blockPageHTML, "blocked domain", "this domain")
|
||||
if err := os.WriteFile(blockPagePath, []byte(blockHTML), 0600); err != nil {
|
||||
return nil, fmt.Errorf("write envoy block page: %w", err)
|
||||
}
|
||||
|
||||
// Generate config with the block page path embedded.
|
||||
bootstrap, err := generateBootstrap(config, listenPort, adminPort, blockPagePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate envoy bootstrap: %w", err)
|
||||
}
|
||||
|
||||
configPath := filepath.Join(configDir, "bootstrap.yaml")
|
||||
if err := os.WriteFile(configPath, bootstrap, 0600); err != nil {
|
||||
return nil, fmt.Errorf("write envoy config: %w", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
|
||||
cmd := exec.CommandContext(ctx, binaryPath,
|
||||
"-c", configPath,
|
||||
"--drain-time-s", fmt.Sprintf("%d", envoyDrainTime),
|
||||
)
|
||||
|
||||
// Pipe envoy output to our logger.
|
||||
cmd.Stdout = &logWriter{entry: logger, level: log.DebugLevel}
|
||||
cmd.Stderr = &logWriter{entry: logger, level: log.WarnLevel}
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
cancel()
|
||||
os.Remove(configPath)
|
||||
return nil, fmt.Errorf("start envoy: %w", err)
|
||||
}
|
||||
|
||||
mgr := &envoyManager{
|
||||
log: logger,
|
||||
cmd: cmd,
|
||||
configPath: configPath,
|
||||
listenPort: listenPort,
|
||||
adminPort: adminPort,
|
||||
blockPagePath: blockPagePath,
|
||||
cancel: cancel,
|
||||
running: true,
|
||||
}
|
||||
|
||||
// Wait for envoy to become healthy.
|
||||
if err := mgr.waitHealthy(ctx); err != nil {
|
||||
mgr.Stop()
|
||||
return nil, fmt.Errorf("wait for envoy readiness: %w", err)
|
||||
}
|
||||
|
||||
logger.Infof("inspect: envoy started (pid=%d, listen=%d, admin=%d)", cmd.Process.Pid, listenPort, adminPort)
|
||||
|
||||
// Monitor process exit in background.
|
||||
go mgr.monitor()
|
||||
|
||||
return mgr, nil
|
||||
}
|
||||
|
||||
// ListenAddr returns the address envoy listens on for forwarded connections.
|
||||
func (m *envoyManager) ListenAddr() netip.AddrPort {
|
||||
return netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), m.listenPort)
|
||||
}
|
||||
|
||||
// AdminAddr returns the envoy admin API address.
|
||||
func (m *envoyManager) AdminAddr() string {
|
||||
return fmt.Sprintf("127.0.0.1:%d", m.adminPort)
|
||||
}
|
||||
|
||||
// Reload writes a new config and sends SIGHUP to envoy.
|
||||
func (m *envoyManager) Reload(config Config) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if !m.running {
|
||||
return fmt.Errorf("envoy is not running")
|
||||
}
|
||||
|
||||
bootstrap, err := generateBootstrap(config, m.listenPort, m.adminPort, m.blockPagePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("generate envoy bootstrap: %w", err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(m.configPath, bootstrap, 0600); err != nil {
|
||||
return fmt.Errorf("write envoy config: %w", err)
|
||||
}
|
||||
|
||||
if err := signalReload(m.cmd.Process); err != nil {
|
||||
return fmt.Errorf("signal envoy reload: %w", err)
|
||||
}
|
||||
|
||||
m.log.Debugf("inspect: envoy config reloaded")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Healthy checks the envoy admin API /ready endpoint.
|
||||
func (m *envoyManager) Healthy() bool {
|
||||
resp, err := http.Get(fmt.Sprintf("http://%s/ready", m.AdminAddr()))
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
return resp.StatusCode == http.StatusOK
|
||||
}
|
||||
|
||||
// Stop terminates the envoy process and cleans up.
|
||||
func (m *envoyManager) Stop() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if !m.running {
|
||||
return
|
||||
}
|
||||
m.running = false
|
||||
|
||||
m.cancel()
|
||||
|
||||
if m.cmd.Process != nil {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
m.cmd.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(envoyStopTimeout):
|
||||
m.log.Warnf("inspect: envoy did not exit in %s, killing", envoyStopTimeout)
|
||||
m.cmd.Process.Kill()
|
||||
<-done
|
||||
}
|
||||
}
|
||||
|
||||
os.RemoveAll(filepath.Dir(m.configPath))
|
||||
m.log.Infof("inspect: envoy stopped")
|
||||
}
|
||||
|
||||
// waitHealthy polls the admin API until envoy is ready or timeout.
|
||||
func (m *envoyManager) waitHealthy(ctx context.Context) error {
|
||||
deadline := time.After(envoyStartTimeout)
|
||||
ticker := time.NewTicker(envoyHealthInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-deadline:
|
||||
return fmt.Errorf("envoy not ready after %s", envoyStartTimeout)
|
||||
case <-ticker.C:
|
||||
if m.Healthy() {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// monitor watches for unexpected envoy exits.
|
||||
func (m *envoyManager) monitor() {
|
||||
err := m.cmd.Wait()
|
||||
|
||||
m.mu.Lock()
|
||||
wasRunning := m.running
|
||||
m.running = false
|
||||
m.mu.Unlock()
|
||||
|
||||
if wasRunning {
|
||||
m.log.Errorf("inspect: envoy exited unexpectedly: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// findEnvoyBinary resolves the envoy binary path.
|
||||
func findEnvoyBinary(configPath string) (string, error) {
|
||||
if configPath != "" {
|
||||
if _, err := os.Stat(configPath); err != nil {
|
||||
return "", fmt.Errorf("envoy binary not found at %s: %w", configPath, err)
|
||||
}
|
||||
return configPath, nil
|
||||
}
|
||||
|
||||
path, err := exec.LookPath("envoy")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("envoy not found in PATH: %w", err)
|
||||
}
|
||||
return path, nil
|
||||
}
|
||||
|
||||
// findFreePort asks the OS for an available TCP port.
|
||||
func findFreePort() (uint16, error) {
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
port := uint16(ln.Addr().(*net.TCPAddr).Port)
|
||||
ln.Close()
|
||||
return port, nil
|
||||
}
|
||||
|
||||
// logWriter adapts log.Entry to io.Writer for piping process output.
|
||||
type logWriter struct {
|
||||
entry *log.Entry
|
||||
level log.Level
|
||||
}
|
||||
|
||||
func (w *logWriter) Write(p []byte) (int, error) {
|
||||
msg := strings.TrimRight(string(p), "\n\r")
|
||||
if msg == "" {
|
||||
return len(p), nil
|
||||
}
|
||||
switch w.level {
|
||||
case log.WarnLevel:
|
||||
w.entry.Warn(msg)
|
||||
default:
|
||||
w.entry.Debug(msg)
|
||||
}
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
// Ensure logWriter satisfies io.Writer.
|
||||
var _ io.Writer = (*logWriter)(nil)
|
||||
382
client/inspect/envoy_config.go
Normal file
382
client/inspect/envoy_config.go
Normal file
@@ -0,0 +1,382 @@
|
||||
package inspect
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"strings"
|
||||
"text/template"
|
||||
)
|
||||
|
||||
// envoyBootstrapTmpl generates the full envoy bootstrap with rule translation.
|
||||
// TLS rules become per-SNI filter chains; HTTP rules become per-domain virtual hosts.
|
||||
var envoyBootstrapTmpl = template.Must(template.New("bootstrap").Funcs(template.FuncMap{
|
||||
"quote": func(s string) string { return fmt.Sprintf("%q", s) },
|
||||
}).Parse(`node:
|
||||
id: netbird-inspect
|
||||
cluster: netbird
|
||||
admin:
|
||||
address:
|
||||
socket_address:
|
||||
address: 127.0.0.1
|
||||
port_value: {{.AdminPort}}
|
||||
static_resources:
|
||||
listeners:
|
||||
- name: inspect_listener
|
||||
address:
|
||||
socket_address:
|
||||
address: 127.0.0.1
|
||||
port_value: {{.ListenPort}}
|
||||
listener_filters:
|
||||
- name: envoy.filters.listener.proxy_protocol
|
||||
typed_config:
|
||||
"@type": type.googleapis.com/envoy.extensions.filters.listener.proxy_protocol.v3.ProxyProtocol
|
||||
- name: envoy.filters.listener.tls_inspector
|
||||
typed_config:
|
||||
"@type": type.googleapis.com/envoy.extensions.filters.listener.tls_inspector.v3.TlsInspector
|
||||
filter_chains:
|
||||
{{- /* TLS filter chains: per-SNI block/allow + default */ -}}
|
||||
{{- range .TLSChains}}
|
||||
- filter_chain_match:
|
||||
transport_protocol: tls
|
||||
{{- if .ServerNames}}
|
||||
server_names:
|
||||
{{- range .ServerNames}}
|
||||
- {{quote .}}
|
||||
{{- end}}
|
||||
{{- end}}
|
||||
filters:
|
||||
{{$.NetworkFiltersSnippet}} - name: envoy.filters.network.tcp_proxy
|
||||
typed_config:
|
||||
"@type": type.googleapis.com/envoy.extensions.filters.network.tcp_proxy.v3.TcpProxy
|
||||
stat_prefix: {{.StatPrefix}}
|
||||
cluster: original_dst
|
||||
access_log:
|
||||
- name: envoy.access_loggers.stderr
|
||||
typed_config:
|
||||
"@type": type.googleapis.com/envoy.extensions.access_loggers.stream.v3.StderrAccessLog
|
||||
log_format:
|
||||
text_format: "[%START_TIME%] tcp %DOWNSTREAM_REMOTE_ADDRESS% -> %UPSTREAM_HOST% %RESPONSE_FLAGS% %DURATION%ms\n"
|
||||
{{- end}}
|
||||
{{- /* Plain HTTP filter chain with per-domain virtual hosts */}}
|
||||
- filters:
|
||||
- name: envoy.filters.network.http_connection_manager
|
||||
typed_config:
|
||||
"@type": type.googleapis.com/envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager
|
||||
stat_prefix: inspect_http
|
||||
access_log:
|
||||
- name: envoy.access_loggers.stderr
|
||||
typed_config:
|
||||
"@type": type.googleapis.com/envoy.extensions.access_loggers.stream.v3.StderrAccessLog
|
||||
log_format:
|
||||
text_format: "[%START_TIME%] http %DOWNSTREAM_REMOTE_ADDRESS% %REQ(:AUTHORITY)% %REQ(:METHOD)% %REQ(X-ENVOY-ORIGINAL-PATH?:PATH)% %RESPONSE_CODE% %RESPONSE_FLAGS% %DURATION%ms\n"
|
||||
http_filters:
|
||||
{{.HTTPFiltersSnippet}} - name: envoy.filters.http.router
|
||||
typed_config:
|
||||
"@type": type.googleapis.com/envoy.extensions.filters.http.router.v3.Router
|
||||
route_config:
|
||||
virtual_hosts:
|
||||
{{- range .VirtualHosts}}
|
||||
- name: {{.Name}}
|
||||
domains: [{{.DomainsStr}}]
|
||||
routes:
|
||||
{{- range .Routes}}
|
||||
- match:
|
||||
prefix: "{{if .PathPrefix}}{{.PathPrefix}}{{else}}/{{end}}"
|
||||
{{- if .Block}}
|
||||
direct_response:
|
||||
status: 403
|
||||
body:
|
||||
filename: "{{$.BlockPagePath}}"
|
||||
{{- else}}
|
||||
route:
|
||||
cluster: original_dst
|
||||
{{- end}}
|
||||
{{- end}}
|
||||
{{- end}}
|
||||
clusters:
|
||||
- name: original_dst
|
||||
type: ORIGINAL_DST
|
||||
lb_policy: CLUSTER_PROVIDED
|
||||
connect_timeout: 10s
|
||||
{{.ExtraClusters}}`))
|
||||
|
||||
// tlsChain represents a TLS filter chain entry for the template.
|
||||
// All TLS chains are passthrough (block decisions happen in Go before envoy).
|
||||
type tlsChain struct {
|
||||
// ServerNames restricts this chain to specific SNIs. Empty is catch-all.
|
||||
ServerNames []string
|
||||
StatPrefix string
|
||||
}
|
||||
|
||||
// envoyRoute represents a single route entry within a virtual host.
|
||||
type envoyRoute struct {
|
||||
// PathPrefix for envoy prefix match. Empty means catch-all "/".
|
||||
PathPrefix string
|
||||
Block bool
|
||||
}
|
||||
|
||||
// virtualHost represents an HTTP virtual host entry for the template.
|
||||
type virtualHost struct {
|
||||
Name string
|
||||
// DomainsStr is pre-formatted for the template: "a", "b".
|
||||
DomainsStr string
|
||||
Routes []envoyRoute
|
||||
}
|
||||
|
||||
type bootstrapData struct {
|
||||
AdminPort uint16
|
||||
ListenPort uint16
|
||||
BlockPagePath string
|
||||
TLSChains []tlsChain
|
||||
VirtualHosts []virtualHost
|
||||
HTTPFiltersSnippet string
|
||||
NetworkFiltersSnippet string
|
||||
ExtraClusters string
|
||||
}
|
||||
|
||||
// generateBootstrap produces the envoy bootstrap YAML from the inspect config.
|
||||
// Translates inspection rules into envoy-native per-SNI and per-domain routing.
|
||||
// blockPagePath is the path to the HTML block page file served by direct_response.
|
||||
func generateBootstrap(config Config, listenPort, adminPort uint16, blockPagePath string) ([]byte, error) {
|
||||
data := bootstrapData{
|
||||
AdminPort: adminPort,
|
||||
BlockPagePath: blockPagePath,
|
||||
ListenPort: listenPort,
|
||||
TLSChains: buildTLSChains(config),
|
||||
VirtualHosts: buildVirtualHosts(config),
|
||||
}
|
||||
|
||||
if config.Envoy != nil && config.Envoy.Snippets != nil {
|
||||
s := config.Envoy.Snippets
|
||||
data.HTTPFiltersSnippet = indentSnippet(s.HTTPFilters, 18)
|
||||
data.NetworkFiltersSnippet = indentSnippet(s.NetworkFilters, 12)
|
||||
data.ExtraClusters = indentSnippet(s.Clusters, 4)
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
if err := envoyBootstrapTmpl.Execute(&buf, data); err != nil {
|
||||
return nil, fmt.Errorf("execute bootstrap template: %w", err)
|
||||
}
|
||||
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
// buildTLSChains translates inspection rules into envoy TLS filter chains.
|
||||
// Block rules -> per-SNI chain routing to blackhole.
|
||||
// Allow rules (when default=block) -> per-SNI chain routing to original_dst.
|
||||
// Default chain follows DefaultAction.
|
||||
func buildTLSChains(config Config) []tlsChain {
|
||||
// TLS block decisions happen in Go before forwarding to envoy, so we only
|
||||
// generate allow/passthrough chains here. Envoy can't cleanly close a TLS
|
||||
// connection without completing a handshake, so blocked SNIs never reach envoy.
|
||||
var allowed []string
|
||||
|
||||
for _, rule := range config.Rules {
|
||||
if !ruleTouchesProtocol(rule, ProtoHTTPS, ProtoH2) {
|
||||
continue
|
||||
}
|
||||
for _, d := range rule.Domains {
|
||||
sni := d.PunycodeString()
|
||||
if rule.Action == ActionAllow || rule.Action == ActionInspect {
|
||||
allowed = append(allowed, sni)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var chains []tlsChain
|
||||
|
||||
if len(allowed) > 0 && config.DefaultAction == ActionBlock {
|
||||
chains = append(chains, tlsChain{
|
||||
ServerNames: allowed,
|
||||
StatPrefix: "tls_allowed",
|
||||
})
|
||||
}
|
||||
|
||||
// Default catch-all: passthrough (blocked SNIs never arrive here)
|
||||
chains = append(chains, tlsChain{
|
||||
StatPrefix: "tls_default",
|
||||
})
|
||||
|
||||
return chains
|
||||
}
|
||||
|
||||
// buildVirtualHosts translates inspection rules into envoy HTTP virtual hosts.
|
||||
// Groups rules by domain, generates per-path routes within each virtual host.
|
||||
func buildVirtualHosts(config Config) []virtualHost {
|
||||
// Group rules by domain for per-domain virtual hosts.
|
||||
type domainRules struct {
|
||||
domains []string
|
||||
routes []envoyRoute
|
||||
}
|
||||
|
||||
domainRouteMap := make(map[string][]envoyRoute)
|
||||
|
||||
for _, rule := range config.Rules {
|
||||
if !ruleTouchesProtocol(rule, ProtoHTTP, ProtoWebSocket) {
|
||||
continue
|
||||
}
|
||||
isBlock := rule.Action == ActionBlock
|
||||
|
||||
// Rules without domains or paths are handled by the default action.
|
||||
if len(rule.Domains) == 0 && len(rule.Paths) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Build routes for this rule's paths
|
||||
var routes []envoyRoute
|
||||
if len(rule.Paths) > 0 {
|
||||
for _, p := range rule.Paths {
|
||||
// Convert our path patterns to envoy prefix match.
|
||||
// Strip trailing * for envoy prefix matching.
|
||||
prefix := strings.TrimSuffix(p, "*")
|
||||
routes = append(routes, envoyRoute{PathPrefix: prefix, Block: isBlock})
|
||||
}
|
||||
} else {
|
||||
routes = append(routes, envoyRoute{Block: isBlock})
|
||||
}
|
||||
|
||||
if len(rule.Domains) > 0 {
|
||||
for _, d := range rule.Domains {
|
||||
host := d.PunycodeString()
|
||||
domainRouteMap[host] = append(domainRouteMap[host], routes...)
|
||||
}
|
||||
} else {
|
||||
// No domain: applies to all, add to default host
|
||||
domainRouteMap["*"] = append(domainRouteMap["*"], routes...)
|
||||
}
|
||||
}
|
||||
|
||||
var hosts []virtualHost
|
||||
idx := 0
|
||||
|
||||
// Per-domain virtual hosts with path routes
|
||||
for domain, routes := range domainRouteMap {
|
||||
if domain == "*" {
|
||||
continue
|
||||
}
|
||||
// Add a catch-all route after path-specific routes.
|
||||
// The catch-all follows the default action.
|
||||
routes = append(routes, envoyRoute{Block: config.DefaultAction == ActionBlock})
|
||||
|
||||
hosts = append(hosts, virtualHost{
|
||||
Name: fmt.Sprintf("domain_%d", idx),
|
||||
DomainsStr: fmt.Sprintf("%q", domain),
|
||||
Routes: routes,
|
||||
})
|
||||
idx++
|
||||
}
|
||||
|
||||
// Default virtual host (catch-all for unmatched domains)
|
||||
defaultRoutes := domainRouteMap["*"]
|
||||
defaultRoutes = append(defaultRoutes, envoyRoute{Block: config.DefaultAction == ActionBlock})
|
||||
hosts = append(hosts, virtualHost{
|
||||
Name: "default",
|
||||
DomainsStr: `"*"`,
|
||||
Routes: defaultRoutes,
|
||||
})
|
||||
|
||||
return hosts
|
||||
}
|
||||
|
||||
// ruleTouchesProtocol returns true if the rule's protocol list includes any of the given protocols,
|
||||
// or if the protocol list is empty (matches all).
|
||||
func ruleTouchesProtocol(rule Rule, protos ...ProtoType) bool {
|
||||
if len(rule.Protocols) == 0 {
|
||||
return true
|
||||
}
|
||||
for _, rp := range rule.Protocols {
|
||||
for _, p := range protos {
|
||||
if rp == p {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// indentSnippet prepends each line of the YAML snippet with the given number of spaces.
|
||||
// Returns empty string if snippet is empty.
|
||||
func indentSnippet(snippet string, spaces int) string {
|
||||
if snippet == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
prefix := make([]byte, spaces)
|
||||
for i := range prefix {
|
||||
prefix[i] = ' '
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
for i, line := range bytes.Split([]byte(snippet), []byte("\n")) {
|
||||
if i > 0 {
|
||||
buf.WriteByte('\n')
|
||||
}
|
||||
if len(line) > 0 {
|
||||
buf.Write(prefix)
|
||||
buf.Write(line)
|
||||
}
|
||||
}
|
||||
buf.WriteByte('\n')
|
||||
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
// ValidateSnippets checks that user-provided snippets are safe to inject
|
||||
// into the envoy config. Returns an error describing the first violation found.
|
||||
//
|
||||
// Validation rules:
|
||||
// - Each snippet must be valid YAML (prevents syntax-level injection)
|
||||
// - Snippets must not contain YAML document separators (--- or ...) that could
|
||||
// break out of the indentation context
|
||||
// - Snippets must only contain list items (starting with "- ") at the top level,
|
||||
// matching what envoy expects for filters and clusters
|
||||
func ValidateSnippets(snippets *EnvoySnippets) error {
|
||||
if snippets == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
fields := []struct {
|
||||
name string
|
||||
value string
|
||||
}{
|
||||
{"http_filters", snippets.HTTPFilters},
|
||||
{"network_filters", snippets.NetworkFilters},
|
||||
{"clusters", snippets.Clusters},
|
||||
}
|
||||
|
||||
for _, f := range fields {
|
||||
if f.value == "" {
|
||||
continue
|
||||
}
|
||||
if err := validateSnippetYAML(f.name, f.value); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateSnippetYAML(name, snippet string) error {
|
||||
// Check for YAML document markers that could break template structure.
|
||||
for _, line := range strings.Split(snippet, "\n") {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
if trimmed == "---" || trimmed == "..." {
|
||||
return fmt.Errorf("snippet %q: YAML document separators (--- or ...) are not allowed", name)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify it's valid YAML by checking it doesn't cause template execution issues.
|
||||
// We can't import yaml.v3 here without adding a dependency, so we do structural checks.
|
||||
|
||||
// Check for null bytes or control characters that could confuse YAML parsers.
|
||||
for i, b := range []byte(snippet) {
|
||||
if b == 0 {
|
||||
return fmt.Errorf("snippet %q: null byte at position %d", name, i)
|
||||
}
|
||||
if b < 0x09 || (b > 0x0D && b < 0x20 && b != 0x1B) {
|
||||
return fmt.Errorf("snippet %q: control character 0x%02x at position %d", name, b, i)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
88
client/inspect/envoy_forward.go
Normal file
88
client/inspect/envoy_forward.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package inspect
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
// PROXY protocol v2 constants (RFC 7239 / HAProxy spec)
|
||||
var proxyV2Signature = [12]byte{
|
||||
0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51,
|
||||
0x55, 0x49, 0x54, 0x0A,
|
||||
}
|
||||
|
||||
const (
|
||||
proxyV2VersionCommand = 0x21 // version 2, PROXY command
|
||||
proxyV2FamilyTCP4 = 0x11 // AF_INET, STREAM
|
||||
proxyV2FamilyTCP6 = 0x21 // AF_INET6, STREAM
|
||||
)
|
||||
|
||||
// forwardToEnvoy forwards a connection to the given envoy sidecar via PROXY protocol v2.
|
||||
// The caller provides the envoy manager snapshot to avoid accessing p.envoy without lock.
|
||||
func (p *Proxy) forwardToEnvoy(ctx context.Context, pconn *peekConn, dst netip.AddrPort, src SourceInfo, em *envoyManager) error {
|
||||
envoyAddr := em.ListenAddr()
|
||||
|
||||
conn, err := (&net.Dialer{}).DialContext(ctx, "tcp", envoyAddr.String())
|
||||
if err != nil {
|
||||
return fmt.Errorf("dial envoy at %s: %w", envoyAddr, err)
|
||||
}
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
p.log.Debugf("close envoy conn: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := writeProxyV2Header(conn, src.IP, dst); err != nil {
|
||||
return fmt.Errorf("write PROXY v2 header: %w", err)
|
||||
}
|
||||
|
||||
p.log.Tracef("envoy: forwarded %s -> %s via PROXY v2", src.IP, dst)
|
||||
|
||||
return relay(ctx, pconn, conn)
|
||||
}
|
||||
|
||||
// writeProxyV2Header writes a PROXY protocol v2 header to w.
|
||||
// The header encodes the original source IP and the destination address:port.
|
||||
func writeProxyV2Header(w net.Conn, srcIP netip.Addr, dst netip.AddrPort) error {
|
||||
srcIP = srcIP.Unmap()
|
||||
dstIP := dst.Addr().Unmap()
|
||||
|
||||
var (
|
||||
family byte
|
||||
addrs []byte
|
||||
)
|
||||
|
||||
if srcIP.Is4() && dstIP.Is4() {
|
||||
family = proxyV2FamilyTCP4
|
||||
s4 := srcIP.As4()
|
||||
d4 := dstIP.As4()
|
||||
addrs = make([]byte, 12) // 4+4+2+2
|
||||
copy(addrs[0:4], s4[:])
|
||||
copy(addrs[4:8], d4[:])
|
||||
binary.BigEndian.PutUint16(addrs[8:10], 0) // src port unknown
|
||||
binary.BigEndian.PutUint16(addrs[10:12], dst.Port())
|
||||
} else {
|
||||
family = proxyV2FamilyTCP6
|
||||
s16 := srcIP.As16()
|
||||
d16 := dstIP.As16()
|
||||
addrs = make([]byte, 36) // 16+16+2+2
|
||||
copy(addrs[0:16], s16[:])
|
||||
copy(addrs[16:32], d16[:])
|
||||
binary.BigEndian.PutUint16(addrs[32:34], 0) // src port unknown
|
||||
binary.BigEndian.PutUint16(addrs[34:36], dst.Port())
|
||||
}
|
||||
|
||||
// Header: signature(12) + ver_cmd(1) + family(1) + len(2) + addrs
|
||||
header := make([]byte, 16+len(addrs))
|
||||
copy(header[0:12], proxyV2Signature[:])
|
||||
header[12] = proxyV2VersionCommand
|
||||
header[13] = family
|
||||
binary.BigEndian.PutUint16(header[14:16], uint16(len(addrs)))
|
||||
copy(header[16:], addrs)
|
||||
|
||||
_, err := w.Write(header)
|
||||
return err
|
||||
}
|
||||
13
client/inspect/envoy_signal.go
Normal file
13
client/inspect/envoy_signal.go
Normal file
@@ -0,0 +1,13 @@
|
||||
//go:build !windows
|
||||
|
||||
package inspect
|
||||
|
||||
import (
|
||||
"os"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// signalReload sends SIGHUP to the envoy process to trigger config reload.
|
||||
func signalReload(p *os.Process) error {
|
||||
return p.Signal(syscall.SIGHUP)
|
||||
}
|
||||
13
client/inspect/envoy_signal_windows.go
Normal file
13
client/inspect/envoy_signal_windows.go
Normal file
@@ -0,0 +1,13 @@
|
||||
//go:build windows
|
||||
|
||||
package inspect
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
)
|
||||
|
||||
// signalReload is not supported on Windows. Envoy must be restarted.
|
||||
func signalReload(_ *os.Process) error {
|
||||
return fmt.Errorf("envoy config reload via signal not supported on Windows")
|
||||
}
|
||||
229
client/inspect/external.go
Normal file
229
client/inspect/external.go
Normal file
@@ -0,0 +1,229 @@
|
||||
package inspect
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
externalDialTimeout = 10 * time.Second
|
||||
)
|
||||
|
||||
// handleExternal forwards the connection to an external proxy.
|
||||
// For TLS connections, it uses HTTP CONNECT to tunnel through the proxy.
|
||||
// For HTTP connections, it rewrites the request to use the proxy.
|
||||
func (p *Proxy) handleExternal(ctx context.Context, pconn *peekConn, dst netip.AddrPort) error {
|
||||
p.mu.RLock()
|
||||
proxyURL := p.config.ExternalURL
|
||||
p.mu.RUnlock()
|
||||
|
||||
if proxyURL == nil {
|
||||
return fmt.Errorf("external proxy URL not configured")
|
||||
}
|
||||
|
||||
switch proxyURL.Scheme {
|
||||
case "http", "https":
|
||||
return p.externalHTTPProxy(ctx, pconn, dst, proxyURL)
|
||||
case "socks5":
|
||||
return p.externalSOCKS5(ctx, pconn, dst, proxyURL)
|
||||
default:
|
||||
return fmt.Errorf("unsupported external proxy scheme: %s", proxyURL.Scheme)
|
||||
}
|
||||
}
|
||||
|
||||
// externalHTTPProxy tunnels through an HTTP proxy using CONNECT.
|
||||
func (p *Proxy) externalHTTPProxy(ctx context.Context, pconn *peekConn, dst netip.AddrPort, proxyURL *url.URL) error {
|
||||
proxyAddr := proxyURL.Host
|
||||
if _, _, err := net.SplitHostPort(proxyAddr); err != nil {
|
||||
proxyAddr = net.JoinHostPort(proxyAddr, "8080")
|
||||
}
|
||||
|
||||
proxyConn, err := (&net.Dialer{Timeout: externalDialTimeout}).DialContext(ctx, "tcp", proxyAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dial external proxy %s: %w", proxyAddr, err)
|
||||
}
|
||||
defer func() {
|
||||
if err := proxyConn.Close(); err != nil {
|
||||
p.log.Debugf("close external proxy conn: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
connectReq := fmt.Sprintf("CONNECT %s HTTP/1.1\r\nHost: %s\r\n", dst.String(), dst.String())
|
||||
if proxyURL.User != nil {
|
||||
connectReq += "Proxy-Authorization: Basic " + basicAuth(proxyURL.User) + "\r\n"
|
||||
}
|
||||
connectReq += "\r\n"
|
||||
|
||||
if _, err := io.WriteString(proxyConn, connectReq); err != nil {
|
||||
return fmt.Errorf("send CONNECT to proxy: %w", err)
|
||||
}
|
||||
|
||||
resp, err := http.ReadResponse(bufio.NewReader(proxyConn), nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read CONNECT response: %w", err)
|
||||
}
|
||||
if err := resp.Body.Close(); err != nil {
|
||||
p.log.Debugf("close CONNECT resp body: %v", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("proxy CONNECT failed: %s", resp.Status)
|
||||
}
|
||||
|
||||
return relay(ctx, pconn, proxyConn)
|
||||
}
|
||||
|
||||
// externalSOCKS5 tunnels through a SOCKS5 proxy.
|
||||
func (p *Proxy) externalSOCKS5(ctx context.Context, pconn *peekConn, dst netip.AddrPort, proxyURL *url.URL) error {
|
||||
proxyAddr := proxyURL.Host
|
||||
if _, _, err := net.SplitHostPort(proxyAddr); err != nil {
|
||||
proxyAddr = net.JoinHostPort(proxyAddr, "1080")
|
||||
}
|
||||
|
||||
proxyConn, err := (&net.Dialer{Timeout: externalDialTimeout}).DialContext(ctx, "tcp", proxyAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dial SOCKS5 proxy %s: %w", proxyAddr, err)
|
||||
}
|
||||
defer func() {
|
||||
if err := proxyConn.Close(); err != nil {
|
||||
p.log.Debugf("close SOCKS5 proxy conn: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := socks5Handshake(proxyConn, dst, proxyURL.User); err != nil {
|
||||
return fmt.Errorf("SOCKS5 handshake: %w", err)
|
||||
}
|
||||
|
||||
return relay(ctx, pconn, proxyConn)
|
||||
}
|
||||
|
||||
// socks5Handshake performs the SOCKS5 handshake to connect through the proxy.
|
||||
func socks5Handshake(conn net.Conn, dst netip.AddrPort, userinfo *url.Userinfo) error {
|
||||
needAuth := userinfo != nil
|
||||
|
||||
// Greeting
|
||||
var methods []byte
|
||||
if needAuth {
|
||||
methods = []byte{0x00, 0x02} // no auth, username/password
|
||||
} else {
|
||||
methods = []byte{0x00} // no auth
|
||||
}
|
||||
greeting := append([]byte{0x05, byte(len(methods))}, methods...)
|
||||
if _, err := conn.Write(greeting); err != nil {
|
||||
return fmt.Errorf("send greeting: %w", err)
|
||||
}
|
||||
|
||||
// Server method selection
|
||||
var methodResp [2]byte
|
||||
if _, err := io.ReadFull(conn, methodResp[:]); err != nil {
|
||||
return fmt.Errorf("read method selection: %w", err)
|
||||
}
|
||||
if methodResp[0] != 0x05 {
|
||||
return fmt.Errorf("unexpected SOCKS version: %d", methodResp[0])
|
||||
}
|
||||
|
||||
// Handle authentication if selected
|
||||
if methodResp[1] == 0x02 {
|
||||
if err := socks5Auth(conn, userinfo); err != nil {
|
||||
return err
|
||||
}
|
||||
} else if methodResp[1] != 0x00 {
|
||||
return fmt.Errorf("unsupported SOCKS5 auth method: %d", methodResp[1])
|
||||
}
|
||||
|
||||
// Connection request
|
||||
addr := dst.Addr()
|
||||
var addrBytes []byte
|
||||
if addr.Is4() {
|
||||
a4 := addr.As4()
|
||||
addrBytes = append([]byte{0x01}, a4[:]...) // IPv4
|
||||
} else {
|
||||
a16 := addr.As16()
|
||||
addrBytes = append([]byte{0x04}, a16[:]...) // IPv6
|
||||
}
|
||||
|
||||
port := dst.Port()
|
||||
connectReq := append([]byte{0x05, 0x01, 0x00}, addrBytes...)
|
||||
connectReq = append(connectReq, byte(port>>8), byte(port))
|
||||
|
||||
if _, err := conn.Write(connectReq); err != nil {
|
||||
return fmt.Errorf("send connect request: %w", err)
|
||||
}
|
||||
|
||||
// Read response (minimum 10 bytes for IPv4)
|
||||
var respHeader [4]byte
|
||||
if _, err := io.ReadFull(conn, respHeader[:]); err != nil {
|
||||
return fmt.Errorf("read connect response: %w", err)
|
||||
}
|
||||
if respHeader[1] != 0x00 {
|
||||
return fmt.Errorf("SOCKS5 connect failed: status %d", respHeader[1])
|
||||
}
|
||||
|
||||
// Skip bound address
|
||||
switch respHeader[3] {
|
||||
case 0x01: // IPv4
|
||||
var skip [4 + 2]byte
|
||||
if _, err := io.ReadFull(conn, skip[:]); err != nil {
|
||||
return fmt.Errorf("read SOCKS5 bound IPv4 address: %w", err)
|
||||
}
|
||||
case 0x04: // IPv6
|
||||
var skip [16 + 2]byte
|
||||
if _, err := io.ReadFull(conn, skip[:]); err != nil {
|
||||
return fmt.Errorf("read SOCKS5 bound IPv6 address: %w", err)
|
||||
}
|
||||
case 0x03: // Domain
|
||||
var dLen [1]byte
|
||||
if _, err := io.ReadFull(conn, dLen[:]); err != nil {
|
||||
return fmt.Errorf("read domain length: %w", err)
|
||||
}
|
||||
skip := make([]byte, int(dLen[0])+2)
|
||||
if _, err := io.ReadFull(conn, skip); err != nil {
|
||||
return fmt.Errorf("read SOCKS5 bound domain address: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func socks5Auth(conn net.Conn, userinfo *url.Userinfo) error {
|
||||
if userinfo == nil {
|
||||
return fmt.Errorf("SOCKS5 auth required but no credentials provided")
|
||||
}
|
||||
|
||||
user := userinfo.Username()
|
||||
pass, _ := userinfo.Password()
|
||||
|
||||
// Username/password auth (RFC 1929)
|
||||
auth := []byte{0x01, byte(len(user))}
|
||||
auth = append(auth, []byte(user)...)
|
||||
auth = append(auth, byte(len(pass)))
|
||||
auth = append(auth, []byte(pass)...)
|
||||
|
||||
if _, err := conn.Write(auth); err != nil {
|
||||
return fmt.Errorf("send auth: %w", err)
|
||||
}
|
||||
|
||||
var resp [2]byte
|
||||
if _, err := io.ReadFull(conn, resp[:]); err != nil {
|
||||
return fmt.Errorf("read auth response: %w", err)
|
||||
}
|
||||
if resp[1] != 0x00 {
|
||||
return fmt.Errorf("SOCKS5 auth failed: status %d", resp[1])
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func basicAuth(userinfo *url.Userinfo) string {
|
||||
user := userinfo.Username()
|
||||
pass, _ := userinfo.Password()
|
||||
return base64.StdEncoding.EncodeToString([]byte(user + ":" + pass))
|
||||
}
|
||||
532
client/inspect/http.go
Normal file
532
client/inspect/http.go
Normal file
@@ -0,0 +1,532 @@
|
||||
package inspect
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
const (
|
||||
headerUpgrade = "Upgrade"
|
||||
valueWebSocket = "websocket"
|
||||
)
|
||||
|
||||
// inspectHTTP runs the HTTP inspection pipeline on decrypted traffic.
|
||||
// It handles HTTP/1.1 (request-response loop), HTTP/2 (via Go stdlib reverse proxy),
|
||||
// and WebSocket upgrade detection.
|
||||
func (p *Proxy) inspectHTTP(ctx context.Context, client, remote net.Conn, dst netip.AddrPort, sni domain.Domain, src SourceInfo, proto string) error {
|
||||
if proto == "h2" {
|
||||
return p.inspectH2(ctx, client, remote, dst, sni, src)
|
||||
}
|
||||
return p.inspectH1(ctx, client, remote, dst, sni, src)
|
||||
}
|
||||
|
||||
// inspectH1 handles HTTP/1.1 request-response inspection in a loop.
|
||||
func (p *Proxy) inspectH1(ctx context.Context, client, remote net.Conn, dst netip.AddrPort, sni domain.Domain, src SourceInfo) error {
|
||||
clientReader := bufio.NewReader(client)
|
||||
remoteReader := bufio.NewReader(remote)
|
||||
|
||||
for {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
// Set idle timeout between requests to prevent connection hogging.
|
||||
if err := client.SetReadDeadline(time.Now().Add(idleTimeout)); err != nil {
|
||||
return fmt.Errorf("set idle deadline: %w", err)
|
||||
}
|
||||
req, err := http.ReadRequest(clientReader)
|
||||
if err != nil {
|
||||
if isClosedErr(err) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("read HTTP request: %w", err)
|
||||
}
|
||||
if err := client.SetReadDeadline(time.Time{}); err != nil {
|
||||
return fmt.Errorf("clear read deadline: %w", err)
|
||||
}
|
||||
|
||||
// Re-evaluate rules based on Host header if SNI was empty
|
||||
host := hostFromRequest(req, sni)
|
||||
|
||||
// Domain fronting: Host header doesn't match TLS SNI
|
||||
if isDomainFronting(req, sni) {
|
||||
p.log.Debugf("domain fronting detected: SNI=%s Host=%s", sni.PunycodeString(), host.PunycodeString())
|
||||
writeBlockResponse(client, req, host)
|
||||
return ErrBlocked
|
||||
}
|
||||
|
||||
proto := ProtoHTTP
|
||||
if isWebSocketUpgrade(req) {
|
||||
proto = ProtoWebSocket
|
||||
}
|
||||
action := p.evaluateAction(src.IP, host, dst, proto, req.URL.Path)
|
||||
if action == ActionBlock {
|
||||
p.log.Debugf("block: HTTP %s %s (host=%s)", req.Method, req.URL.Path, host.PunycodeString())
|
||||
writeBlockResponse(client, req, host)
|
||||
return ErrBlocked
|
||||
}
|
||||
p.log.Tracef("allow: HTTP %s %s (host=%s, action=%s)", req.Method, req.URL.Path, host.PunycodeString(), action)
|
||||
|
||||
// ICAP REQMOD: send request for inspection.
|
||||
// Snapshot ICAP client under lock to avoid use-after-close races.
|
||||
p.mu.RLock()
|
||||
icap := p.icap
|
||||
p.mu.RUnlock()
|
||||
if icap != nil {
|
||||
modified, err := icap.ReqMod(req)
|
||||
if err != nil {
|
||||
p.log.Debugf("ICAP REQMOD error for %s: %v", host.PunycodeString(), err)
|
||||
// Fail-closed: block on ICAP error
|
||||
writeBlockResponse(client, req, host)
|
||||
return fmt.Errorf("ICAP REQMOD: %w", err)
|
||||
}
|
||||
req = modified
|
||||
}
|
||||
|
||||
if isWebSocketUpgrade(req) {
|
||||
return p.handleWebSocket(ctx, req, client, clientReader, remote, remoteReader)
|
||||
}
|
||||
|
||||
removeHopByHopHeaders(req.Header)
|
||||
|
||||
if err := req.Write(remote); err != nil {
|
||||
return fmt.Errorf("forward request: %w", err)
|
||||
}
|
||||
|
||||
resp, err := http.ReadResponse(remoteReader, req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read HTTP response: %w", err)
|
||||
}
|
||||
|
||||
// ICAP RESPMOD: send response for inspection
|
||||
if icap != nil {
|
||||
modified, err := icap.RespMod(req, resp)
|
||||
if err != nil {
|
||||
p.log.Debugf("ICAP RESPMOD error for %s: %v", host.PunycodeString(), err)
|
||||
if err := resp.Body.Close(); err != nil {
|
||||
p.log.Debugf("close resp body: %v", err)
|
||||
}
|
||||
writeBlockResponse(client, req, host)
|
||||
return fmt.Errorf("ICAP RESPMOD: %w", err)
|
||||
}
|
||||
resp = modified
|
||||
}
|
||||
|
||||
removeHopByHopHeaders(resp.Header)
|
||||
|
||||
if err := resp.Write(client); err != nil {
|
||||
if closeErr := resp.Body.Close(); closeErr != nil {
|
||||
p.log.Debugf("close resp body: %v", closeErr)
|
||||
}
|
||||
return fmt.Errorf("forward response: %w", err)
|
||||
}
|
||||
if err := resp.Body.Close(); err != nil {
|
||||
p.log.Debugf("close resp body: %v", err)
|
||||
}
|
||||
|
||||
// Connection: close means we're done
|
||||
if resp.Close || req.Close {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// inspectH2 proxies HTTP/2 traffic using Go's http stack.
|
||||
// Client and remote are already-established TLS connections with h2 negotiated.
|
||||
func (p *Proxy) inspectH2(ctx context.Context, client, remote net.Conn, dst netip.AddrPort, sni domain.Domain, src SourceInfo) error {
|
||||
// For h2 MITM inspection, we use a local http.Server reading from the client
|
||||
// connection and an http.Transport writing to the remote connection.
|
||||
//
|
||||
// The transport is configured to use the existing TLS connection to the
|
||||
// real server. The handler inspects each request/response pair.
|
||||
|
||||
transport := &http.Transport{
|
||||
DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
|
||||
return remote, nil
|
||||
},
|
||||
DialTLSContext: func(_ context.Context, _, _ string) (net.Conn, error) {
|
||||
return remote, nil
|
||||
},
|
||||
ForceAttemptHTTP2: true,
|
||||
}
|
||||
|
||||
handler := &h2InspectionHandler{
|
||||
proxy: p,
|
||||
transport: transport,
|
||||
dst: dst,
|
||||
sni: sni,
|
||||
src: src,
|
||||
}
|
||||
|
||||
server := &http.Server{
|
||||
Handler: handler,
|
||||
}
|
||||
|
||||
// Serve the single client connection.
|
||||
// ServeConn blocks until the connection is done.
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
// http.Server doesn't have a direct ServeConn for h2,
|
||||
// so we use Serve with a single-connection listener.
|
||||
ln := &singleConnListener{conn: client}
|
||||
errCh <- server.Serve(ln)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if err := server.Close(); err != nil {
|
||||
p.log.Debugf("close h2 server: %v", err)
|
||||
}
|
||||
return ctx.Err()
|
||||
case err := <-errCh:
|
||||
if err == http.ErrServerClosed {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// h2InspectionHandler inspects each HTTP/2 request/response pair.
|
||||
type h2InspectionHandler struct {
|
||||
proxy *Proxy
|
||||
transport http.RoundTripper
|
||||
dst netip.AddrPort
|
||||
sni domain.Domain
|
||||
src SourceInfo
|
||||
}
|
||||
|
||||
func (h *h2InspectionHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
host := hostFromRequest(req, h.sni)
|
||||
|
||||
if isDomainFronting(req, h.sni) {
|
||||
h.proxy.log.Debugf("domain fronting detected: SNI=%s Host=%s", h.sni.PunycodeString(), host.PunycodeString())
|
||||
writeBlockPage(w, host)
|
||||
return
|
||||
}
|
||||
|
||||
action := h.proxy.evaluateAction(h.src.IP, host, h.dst, ProtoH2, req.URL.Path)
|
||||
if action == ActionBlock {
|
||||
h.proxy.log.Debugf("block: H2 %s %s (host=%s)", req.Method, req.URL.Path, host.PunycodeString())
|
||||
writeBlockPage(w, host)
|
||||
return
|
||||
}
|
||||
|
||||
// ICAP REQMOD
|
||||
if h.proxy.icap != nil {
|
||||
modified, err := h.proxy.icap.ReqMod(req)
|
||||
if err != nil {
|
||||
h.proxy.log.Debugf("ICAP REQMOD error for %s: %v", host.PunycodeString(), err)
|
||||
writeBlockPage(w, host)
|
||||
return
|
||||
}
|
||||
req = modified
|
||||
}
|
||||
|
||||
// Forward to upstream
|
||||
req.URL.Scheme = "https"
|
||||
req.URL.Host = h.sni.PunycodeString()
|
||||
req.RequestURI = ""
|
||||
|
||||
resp, err := h.transport.RoundTrip(req)
|
||||
if err != nil {
|
||||
h.proxy.log.Debugf("h2 upstream error for %s: %v", host.PunycodeString(), err)
|
||||
http.Error(w, "Bad Gateway", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if err := resp.Body.Close(); err != nil {
|
||||
h.proxy.log.Debugf("close h2 resp body: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// ICAP RESPMOD
|
||||
if h.proxy.icap != nil {
|
||||
modified, err := h.proxy.icap.RespMod(req, resp)
|
||||
if err != nil {
|
||||
h.proxy.log.Debugf("ICAP RESPMOD error for %s: %v", host.PunycodeString(), err)
|
||||
writeBlockPage(w, host)
|
||||
return
|
||||
}
|
||||
resp = modified
|
||||
}
|
||||
|
||||
// Copy response headers and body
|
||||
for k, vals := range resp.Header {
|
||||
for _, v := range vals {
|
||||
w.Header().Add(k, v)
|
||||
}
|
||||
}
|
||||
w.WriteHeader(resp.StatusCode)
|
||||
if _, err := io.Copy(w, resp.Body); err != nil {
|
||||
h.proxy.log.Debugf("h2 response copy error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// handleWebSocket completes the WebSocket upgrade and relays frames bidirectionally.
|
||||
func (p *Proxy) handleWebSocket(ctx context.Context, req *http.Request, client io.ReadWriter, clientReader *bufio.Reader, remote io.ReadWriter, remoteReader *bufio.Reader) error {
|
||||
if err := req.Write(remote); err != nil {
|
||||
return fmt.Errorf("forward WebSocket upgrade: %w", err)
|
||||
}
|
||||
|
||||
resp, err := http.ReadResponse(remoteReader, req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read WebSocket upgrade response: %w", err)
|
||||
}
|
||||
|
||||
if err := resp.Write(client); err != nil {
|
||||
if closeErr := resp.Body.Close(); closeErr != nil {
|
||||
p.log.Debugf("close ws resp body: %v", closeErr)
|
||||
}
|
||||
return fmt.Errorf("forward WebSocket upgrade response: %w", err)
|
||||
}
|
||||
if err := resp.Body.Close(); err != nil {
|
||||
p.log.Debugf("close ws resp body: %v", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusSwitchingProtocols {
|
||||
return fmt.Errorf("WebSocket upgrade rejected: status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
p.log.Tracef("allow: WebSocket upgrade for %s", req.Host)
|
||||
|
||||
// Relay WebSocket frames bidirectionally.
|
||||
// clientReader/remoteReader may have buffered data.
|
||||
clientConn := mergeReadWriter(clientReader, client)
|
||||
remoteConn := mergeReadWriter(remoteReader, remote)
|
||||
|
||||
return relayRW(ctx, clientConn, remoteConn)
|
||||
}
|
||||
|
||||
// hostFromRequest extracts a domain.Domain from the HTTP request Host header,
|
||||
// falling back to the SNI if Host is empty or an IP.
|
||||
func hostFromRequest(req *http.Request, fallback domain.Domain) domain.Domain {
|
||||
host := req.Host
|
||||
if host == "" {
|
||||
return fallback
|
||||
}
|
||||
|
||||
// Strip port if present
|
||||
if h, _, err := net.SplitHostPort(host); err == nil {
|
||||
host = h
|
||||
}
|
||||
|
||||
// If it's an IP address, use the SNI fallback
|
||||
if _, err := netip.ParseAddr(host); err == nil {
|
||||
return fallback
|
||||
}
|
||||
|
||||
d, err := domain.FromString(host)
|
||||
if err != nil {
|
||||
return fallback
|
||||
}
|
||||
return d
|
||||
}
|
||||
|
||||
// isDomainFronting detects domain fronting: the Host header doesn't match the
|
||||
// SNI used during the TLS handshake. Only meaningful when SNI is non-empty
|
||||
// (i.e., we're in MITM mode and know the original SNI).
|
||||
func isDomainFronting(req *http.Request, sni domain.Domain) bool {
|
||||
if sni == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
host := hostFromRequest(req, "")
|
||||
if host == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
// Host should match SNI or be a subdomain of SNI
|
||||
if host == sni {
|
||||
return false
|
||||
}
|
||||
|
||||
// Allow www.example.com when SNI is example.com
|
||||
sniStr := sni.PunycodeString()
|
||||
hostStr := host.PunycodeString()
|
||||
if strings.HasSuffix(hostStr, "."+sniStr) {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func isWebSocketUpgrade(req *http.Request) bool {
|
||||
return strings.EqualFold(req.Header.Get(headerUpgrade), valueWebSocket)
|
||||
}
|
||||
|
||||
// writeBlockPage writes the styled HTML block page to an http.ResponseWriter (H2 path).
|
||||
func writeBlockPage(w http.ResponseWriter, host domain.Domain) {
|
||||
hostname := host.PunycodeString()
|
||||
body := fmt.Sprintf(blockPageHTML, hostname, hostname)
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
w.Header().Set("Cache-Control", "no-store")
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
io.WriteString(w, body)
|
||||
}
|
||||
|
||||
func writeBlockResponse(w io.Writer, _ *http.Request, host domain.Domain) {
|
||||
hostname := host.PunycodeString()
|
||||
body := fmt.Sprintf(blockPageHTML, hostname, hostname)
|
||||
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusForbidden,
|
||||
ProtoMajor: 1,
|
||||
ProtoMinor: 1,
|
||||
Header: make(http.Header),
|
||||
ContentLength: int64(len(body)),
|
||||
Body: io.NopCloser(strings.NewReader(body)),
|
||||
}
|
||||
resp.Header.Set("Content-Type", "text/html; charset=utf-8")
|
||||
resp.Header.Set("Connection", "close")
|
||||
resp.Header.Set("Cache-Control", "no-store")
|
||||
_ = resp.Write(w)
|
||||
}
|
||||
|
||||
// blockPageHTML is the self-contained HTML block page.
|
||||
// Uses NetBird dark theme with orange accent. Two format args: page title domain, displayed domain.
|
||||
const blockPageHTML = `<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<meta name="viewport" content="width=device-width,initial-scale=1">
|
||||
<title>Blocked - %s</title>
|
||||
<style>
|
||||
*{margin:0;padding:0;box-sizing:border-box}
|
||||
body{background:#181a1d;color:#d1d5db;font-family:-apple-system,BlinkMacSystemFont,"Segoe UI",Roboto,sans-serif;min-height:100vh;display:flex;align-items:center;justify-content:center}
|
||||
.c{text-align:center;max-width:460px;padding:2rem}
|
||||
.shield{width:56px;height:56px;margin:0 auto 1.5rem;border-radius:16px;background:#2b2f33;display:flex;align-items:center;justify-content:center}
|
||||
.shield svg{width:28px;height:28px;color:#f68330}
|
||||
.code{font-size:.8rem;font-weight:500;color:#f68330;font-family:ui-monospace,monospace;letter-spacing:.05em;margin-bottom:.5rem}
|
||||
h1{font-size:1.5rem;font-weight:600;color:#f4f4f5;margin-bottom:.5rem}
|
||||
p{font-size:.95rem;line-height:1.5;color:#9ca3af;margin-bottom:1.75rem}
|
||||
.domain{display:inline-block;background:#25282d;border:1px solid #32363d;border-radius:6px;padding:.15rem .5rem;font-family:ui-monospace,monospace;font-size:.85rem;color:#d1d5db}
|
||||
.footer{font-size:.7rem;color:#6b7280;margin-top:2rem;letter-spacing:.03em}
|
||||
.footer a{color:#6b7280;text-decoration:none}
|
||||
.footer a:hover{color:#9ca3af}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="c">
|
||||
<div class="shield"><svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor"><path stroke-linecap="round" stroke-linejoin="round" d="M12 9v3.75m0-10.036A11.959 11.959 0 0 1 3.598 6 11.99 11.99 0 0 0 3 9.75c0 5.592 3.824 10.29 9 11.622 5.176-1.332 9-6.03 9-11.622 0-1.31-.21-2.571-.598-3.751A11.96 11.96 0 0 0 12 3.714Z"/></svg></div>
|
||||
<div class="code">403 BLOCKED</div>
|
||||
<h1>Access Denied</h1>
|
||||
<p>This connection to <span class="domain">%s</span> has been blocked by your organization's network policy.</p>
|
||||
<div class="footer">Protected by <a href="https://netbird.io" target="_blank" rel="noopener">NetBird</a></div>
|
||||
</div>
|
||||
</body>
|
||||
</html>`
|
||||
|
||||
// singleConnListener is a net.Listener that yields a single connection.
|
||||
type singleConnListener struct {
|
||||
conn net.Conn
|
||||
once sync.Once
|
||||
ch chan struct{}
|
||||
}
|
||||
|
||||
func (l *singleConnListener) Accept() (net.Conn, error) {
|
||||
var accepted bool
|
||||
l.once.Do(func() {
|
||||
l.ch = make(chan struct{})
|
||||
accepted = true
|
||||
})
|
||||
if accepted {
|
||||
return l.conn, nil
|
||||
}
|
||||
// Block until Close
|
||||
<-l.ch
|
||||
return nil, net.ErrClosed
|
||||
}
|
||||
|
||||
func (l *singleConnListener) Close() error {
|
||||
l.once.Do(func() {
|
||||
l.ch = make(chan struct{})
|
||||
})
|
||||
select {
|
||||
case <-l.ch:
|
||||
default:
|
||||
close(l.ch)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *singleConnListener) Addr() net.Addr {
|
||||
return l.conn.LocalAddr()
|
||||
}
|
||||
|
||||
type readWriter struct {
|
||||
io.Reader
|
||||
io.Writer
|
||||
}
|
||||
|
||||
func mergeReadWriter(r io.Reader, w io.Writer) io.ReadWriter {
|
||||
return &readWriter{Reader: r, Writer: w}
|
||||
}
|
||||
|
||||
// relayRW copies data bidirectionally between two ReadWriters.
|
||||
func relayRW(ctx context.Context, a, b io.ReadWriter) error {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
errCh := make(chan error, 2)
|
||||
|
||||
go func() {
|
||||
_, err := io.Copy(b, a)
|
||||
cancel()
|
||||
errCh <- err
|
||||
}()
|
||||
|
||||
go func() {
|
||||
_, err := io.Copy(a, b)
|
||||
cancel()
|
||||
errCh <- err
|
||||
}()
|
||||
|
||||
var firstErr error
|
||||
for range 2 {
|
||||
if err := <-errCh; err != nil && firstErr == nil {
|
||||
if !isClosedErr(err) {
|
||||
firstErr = err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return firstErr
|
||||
}
|
||||
|
||||
// hopByHopHeaders are HTTP/1.1 headers that apply to a single connection
|
||||
// and must not be forwarded by a proxy (RFC 7230, Section 6.1).
|
||||
var hopByHopHeaders = []string{
|
||||
"Connection",
|
||||
"Keep-Alive",
|
||||
"Proxy-Authenticate",
|
||||
"Proxy-Authorization",
|
||||
"TE",
|
||||
"Trailers",
|
||||
"Transfer-Encoding",
|
||||
"Upgrade",
|
||||
}
|
||||
|
||||
// removeHopByHopHeaders strips hop-by-hop headers from h.
|
||||
// Also removes headers listed in the Connection header value.
|
||||
func removeHopByHopHeaders(h http.Header) {
|
||||
// First, remove any headers named in the Connection header
|
||||
for _, connHeader := range h["Connection"] {
|
||||
for _, name := range strings.Split(connHeader, ",") {
|
||||
h.Del(strings.TrimSpace(name))
|
||||
}
|
||||
}
|
||||
|
||||
for _, name := range hopByHopHeaders {
|
||||
h.Del(name)
|
||||
}
|
||||
}
|
||||
479
client/inspect/icap.go
Normal file
479
client/inspect/icap.go
Normal file
@@ -0,0 +1,479 @@
|
||||
package inspect
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
icapVersion = "ICAP/1.0"
|
||||
icapDefaultPort = "1344"
|
||||
icapConnTimeout = 30 * time.Second
|
||||
icapRWTimeout = 60 * time.Second
|
||||
icapMaxPoolSize = 8
|
||||
icapIdleTimeout = 60 * time.Second
|
||||
icapMaxRespSize = 4 * 1024 * 1024 // 4 MB
|
||||
)
|
||||
|
||||
// ICAPClient implements an ICAP (RFC 3507) client with persistent connection pooling.
|
||||
type ICAPClient struct {
|
||||
reqModURL *url.URL
|
||||
respModURL *url.URL
|
||||
pool chan *icapConn
|
||||
mu sync.Mutex
|
||||
log *log.Entry
|
||||
maxPool int
|
||||
}
|
||||
|
||||
type icapConn struct {
|
||||
conn net.Conn
|
||||
reader *bufio.Reader
|
||||
lastUse time.Time
|
||||
}
|
||||
|
||||
// NewICAPClient creates an ICAP client. Either or both URLs may be nil
|
||||
// to disable that mode.
|
||||
func NewICAPClient(logger *log.Entry, cfg *ICAPConfig) *ICAPClient {
|
||||
maxPool := cfg.MaxConnections
|
||||
if maxPool <= 0 {
|
||||
maxPool = icapMaxPoolSize
|
||||
}
|
||||
|
||||
return &ICAPClient{
|
||||
reqModURL: cfg.ReqModURL,
|
||||
respModURL: cfg.RespModURL,
|
||||
pool: make(chan *icapConn, maxPool),
|
||||
log: logger,
|
||||
maxPool: maxPool,
|
||||
}
|
||||
}
|
||||
|
||||
// ReqMod sends an HTTP request to the ICAP REQMOD service for inspection.
|
||||
// Returns the (possibly modified) request, or the original if ICAP returns 204.
|
||||
// Returns nil, nil if REQMOD is not configured.
|
||||
func (c *ICAPClient) ReqMod(req *http.Request) (*http.Request, error) {
|
||||
if c.reqModURL == nil {
|
||||
return req, nil
|
||||
}
|
||||
|
||||
var reqBuf bytes.Buffer
|
||||
if err := req.Write(&reqBuf); err != nil {
|
||||
return nil, fmt.Errorf("serialize request: %w", err)
|
||||
}
|
||||
|
||||
respBody, err := c.send("REQMOD", c.reqModURL, reqBuf.Bytes(), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if respBody == nil {
|
||||
return req, nil
|
||||
}
|
||||
|
||||
modified, err := http.ReadRequest(bufio.NewReader(bytes.NewReader(respBody)))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse ICAP modified request: %w", err)
|
||||
}
|
||||
return modified, nil
|
||||
}
|
||||
|
||||
// RespMod sends an HTTP response to the ICAP RESPMOD service for inspection.
|
||||
// Returns the (possibly modified) response, or the original if ICAP returns 204.
|
||||
// Returns nil, nil if RESPMOD is not configured.
|
||||
func (c *ICAPClient) RespMod(req *http.Request, resp *http.Response) (*http.Response, error) {
|
||||
if c.respModURL == nil {
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
var reqBuf bytes.Buffer
|
||||
if err := req.Write(&reqBuf); err != nil {
|
||||
return nil, fmt.Errorf("serialize request: %w", err)
|
||||
}
|
||||
|
||||
var respBuf bytes.Buffer
|
||||
if err := resp.Write(&respBuf); err != nil {
|
||||
return nil, fmt.Errorf("serialize response: %w", err)
|
||||
}
|
||||
|
||||
respBody, err := c.send("RESPMOD", c.respModURL, reqBuf.Bytes(), respBuf.Bytes())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if respBody == nil {
|
||||
// 204 No Content: ICAP server didn't modify the response.
|
||||
// Reconstruct from the buffered copy since resp.Body was consumed by Write.
|
||||
reconstructed, err := http.ReadResponse(bufio.NewReader(bytes.NewReader(respBuf.Bytes())), req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reconstruct response after ICAP 204: %w", err)
|
||||
}
|
||||
return reconstructed, nil
|
||||
}
|
||||
|
||||
modified, err := http.ReadResponse(bufio.NewReader(bytes.NewReader(respBody)), req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse ICAP modified response: %w", err)
|
||||
}
|
||||
return modified, nil
|
||||
}
|
||||
|
||||
// Close drains and closes all pooled connections.
|
||||
func (c *ICAPClient) Close() {
|
||||
close(c.pool)
|
||||
for ic := range c.pool {
|
||||
if err := ic.conn.Close(); err != nil {
|
||||
c.log.Debugf("close ICAP connection: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// send executes an ICAP request and returns the encapsulated body from the response.
|
||||
// Returns nil body for 204 No Content (no modification).
|
||||
// Retries once on stale pooled connection (EOF on read).
|
||||
func (c *ICAPClient) send(method string, serviceURL *url.URL, reqData, respData []byte) ([]byte, error) {
|
||||
statusCode, headers, body, err := c.trySend(method, serviceURL, reqData, respData)
|
||||
if err != nil && isStaleConnErr(err) {
|
||||
// Retry once with a fresh connection (stale pool entry).
|
||||
c.log.Debugf("ICAP %s: retrying after stale connection: %v", method, err)
|
||||
statusCode, headers, body, err = c.trySend(method, serviceURL, reqData, respData)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch statusCode {
|
||||
case 204:
|
||||
return nil, nil
|
||||
case 200:
|
||||
return body, nil
|
||||
default:
|
||||
c.log.Debugf("ICAP %s returned status %d, headers: %v", method, statusCode, headers)
|
||||
return nil, fmt.Errorf("ICAP %s: status %d", method, statusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ICAPClient) trySend(method string, serviceURL *url.URL, reqData, respData []byte) (int, textproto.MIMEHeader, []byte, error) {
|
||||
ic, err := c.getConn(serviceURL)
|
||||
if err != nil {
|
||||
return 0, nil, nil, fmt.Errorf("get ICAP connection: %w", err)
|
||||
}
|
||||
|
||||
if err := c.writeRequest(ic, method, serviceURL, reqData, respData); err != nil {
|
||||
if closeErr := ic.conn.Close(); closeErr != nil {
|
||||
c.log.Debugf("close ICAP conn after write error: %v", closeErr)
|
||||
}
|
||||
return 0, nil, nil, fmt.Errorf("write ICAP %s: %w", method, err)
|
||||
}
|
||||
|
||||
statusCode, headers, body, err := c.readResponse(ic)
|
||||
if err != nil {
|
||||
if closeErr := ic.conn.Close(); closeErr != nil {
|
||||
c.log.Debugf("close ICAP conn after read error: %v", closeErr)
|
||||
}
|
||||
return 0, nil, nil, fmt.Errorf("read ICAP response: %w", err)
|
||||
}
|
||||
|
||||
c.putConn(ic)
|
||||
return statusCode, headers, body, nil
|
||||
}
|
||||
|
||||
func isStaleConnErr(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
s := err.Error()
|
||||
return strings.Contains(s, "EOF") || strings.Contains(s, "broken pipe") || strings.Contains(s, "connection reset")
|
||||
}
|
||||
|
||||
func (c *ICAPClient) writeRequest(ic *icapConn, method string, serviceURL *url.URL, reqData, respData []byte) error {
|
||||
if err := ic.conn.SetWriteDeadline(time.Now().Add(icapRWTimeout)); err != nil {
|
||||
return fmt.Errorf("set write deadline: %w", err)
|
||||
}
|
||||
|
||||
// For RESPMOD, split the serialized HTTP response into headers and body.
|
||||
// The body must be sent chunked per RFC 3507.
|
||||
var respHdr, respBody []byte
|
||||
if respData != nil {
|
||||
if idx := bytes.Index(respData, []byte("\r\n\r\n")); idx >= 0 {
|
||||
respHdr = respData[:idx+4] // include the \r\n\r\n separator
|
||||
respBody = respData[idx+4:]
|
||||
} else {
|
||||
respHdr = respData
|
||||
}
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
|
||||
// Request line
|
||||
fmt.Fprintf(&buf, "%s %s %s\r\n", method, serviceURL.String(), icapVersion)
|
||||
|
||||
// Headers
|
||||
host := serviceURL.Host
|
||||
fmt.Fprintf(&buf, "Host: %s\r\n", host)
|
||||
fmt.Fprintf(&buf, "Connection: keep-alive\r\n")
|
||||
fmt.Fprintf(&buf, "Allow: 204\r\n")
|
||||
|
||||
// Build Encapsulated header
|
||||
offset := 0
|
||||
var encapParts []string
|
||||
if reqData != nil {
|
||||
encapParts = append(encapParts, fmt.Sprintf("req-hdr=%d", offset))
|
||||
offset += len(reqData)
|
||||
}
|
||||
if respHdr != nil {
|
||||
encapParts = append(encapParts, fmt.Sprintf("res-hdr=%d", offset))
|
||||
offset += len(respHdr)
|
||||
}
|
||||
if len(respBody) > 0 {
|
||||
encapParts = append(encapParts, fmt.Sprintf("res-body=%d", offset))
|
||||
} else {
|
||||
encapParts = append(encapParts, fmt.Sprintf("null-body=%d", offset))
|
||||
}
|
||||
fmt.Fprintf(&buf, "Encapsulated: %s\r\n", strings.Join(encapParts, ", "))
|
||||
fmt.Fprintf(&buf, "\r\n")
|
||||
|
||||
// Encapsulated sections
|
||||
if reqData != nil {
|
||||
buf.Write(reqData)
|
||||
}
|
||||
if respHdr != nil {
|
||||
buf.Write(respHdr)
|
||||
}
|
||||
// Body in chunked encoding (only when there is an actual body section).
|
||||
// Per RFC 3507 Section 4.4.1, null-body must not include any entity data.
|
||||
if len(respBody) > 0 {
|
||||
fmt.Fprintf(&buf, "%x\r\n", len(respBody))
|
||||
buf.Write(respBody)
|
||||
buf.WriteString("\r\n")
|
||||
buf.WriteString("0\r\n\r\n")
|
||||
}
|
||||
|
||||
_, err := ic.conn.Write(buf.Bytes())
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *ICAPClient) readResponse(ic *icapConn) (int, textproto.MIMEHeader, []byte, error) {
|
||||
if err := ic.conn.SetReadDeadline(time.Now().Add(icapRWTimeout)); err != nil {
|
||||
return 0, nil, nil, fmt.Errorf("set read deadline: %w", err)
|
||||
}
|
||||
|
||||
tp := textproto.NewReader(ic.reader)
|
||||
|
||||
// Status line: "ICAP/1.0 200 OK"
|
||||
statusLine, err := tp.ReadLine()
|
||||
if err != nil {
|
||||
return 0, nil, nil, fmt.Errorf("read status line: %w", err)
|
||||
}
|
||||
|
||||
statusCode, err := parseICAPStatus(statusLine)
|
||||
if err != nil {
|
||||
return 0, nil, nil, err
|
||||
}
|
||||
|
||||
// Headers
|
||||
headers, err := tp.ReadMIMEHeader()
|
||||
if err != nil {
|
||||
return statusCode, nil, nil, fmt.Errorf("read ICAP headers: %w", err)
|
||||
}
|
||||
|
||||
if statusCode == 204 {
|
||||
return statusCode, headers, nil, nil
|
||||
}
|
||||
|
||||
// Read encapsulated body based on Encapsulated header
|
||||
body, err := c.readEncapsulatedBody(ic.reader, headers)
|
||||
if err != nil {
|
||||
return statusCode, headers, nil, fmt.Errorf("read encapsulated body: %w", err)
|
||||
}
|
||||
|
||||
return statusCode, headers, body, nil
|
||||
}
|
||||
|
||||
func (c *ICAPClient) readEncapsulatedBody(r *bufio.Reader, headers textproto.MIMEHeader) ([]byte, error) {
|
||||
encap := headers.Get("Encapsulated")
|
||||
if encap == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Find the body offset from the Encapsulated header.
|
||||
// The last section with a non-zero offset is the body.
|
||||
// Read everything from the reader as the encapsulated content.
|
||||
var totalSize int
|
||||
parts := strings.Split(encap, ",")
|
||||
for _, part := range parts {
|
||||
part = strings.TrimSpace(part)
|
||||
eqIdx := strings.Index(part, "=")
|
||||
if eqIdx < 0 {
|
||||
continue
|
||||
}
|
||||
offset, err := strconv.Atoi(strings.TrimSpace(part[eqIdx+1:]))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if offset > totalSize {
|
||||
totalSize = offset
|
||||
}
|
||||
}
|
||||
|
||||
// Read all available encapsulated data (headers + body)
|
||||
// The body section uses chunked encoding per RFC 3507
|
||||
var buf bytes.Buffer
|
||||
if totalSize > 0 {
|
||||
// Read the header sections (everything before the body offset)
|
||||
headerBytes := make([]byte, totalSize)
|
||||
if _, err := io.ReadFull(r, headerBytes); err != nil {
|
||||
return nil, fmt.Errorf("read encapsulated headers: %w", err)
|
||||
}
|
||||
buf.Write(headerBytes)
|
||||
}
|
||||
|
||||
// Read chunked body
|
||||
chunked := newChunkedReader(r)
|
||||
body, err := io.ReadAll(io.LimitReader(chunked, icapMaxRespSize))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read chunked body: %w", err)
|
||||
}
|
||||
buf.Write(body)
|
||||
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
func (c *ICAPClient) getConn(serviceURL *url.URL) (*icapConn, error) {
|
||||
// Try to get a pooled connection
|
||||
for {
|
||||
select {
|
||||
case ic := <-c.pool:
|
||||
if time.Since(ic.lastUse) > icapIdleTimeout {
|
||||
if err := ic.conn.Close(); err != nil {
|
||||
c.log.Debugf("close idle ICAP connection: %v", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
return ic, nil
|
||||
default:
|
||||
return c.dialConn(serviceURL)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ICAPClient) putConn(ic *icapConn) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
ic.lastUse = time.Now()
|
||||
select {
|
||||
case c.pool <- ic:
|
||||
default:
|
||||
// Pool full, close connection.
|
||||
if err := ic.conn.Close(); err != nil {
|
||||
c.log.Debugf("close excess ICAP connection: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ICAPClient) dialConn(serviceURL *url.URL) (*icapConn, error) {
|
||||
host := serviceURL.Host
|
||||
if _, _, err := net.SplitHostPort(host); err != nil {
|
||||
host = net.JoinHostPort(host, icapDefaultPort)
|
||||
}
|
||||
|
||||
conn, err := net.DialTimeout("tcp", host, icapConnTimeout)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dial ICAP %s: %w", host, err)
|
||||
}
|
||||
|
||||
return &icapConn{
|
||||
conn: conn,
|
||||
reader: bufio.NewReader(conn),
|
||||
lastUse: time.Now(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func parseICAPStatus(line string) (int, error) {
|
||||
// "ICAP/1.0 200 OK"
|
||||
parts := strings.SplitN(line, " ", 3)
|
||||
if len(parts) < 2 {
|
||||
return 0, fmt.Errorf("malformed ICAP status line: %q", line)
|
||||
}
|
||||
code, err := strconv.Atoi(parts[1])
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parse ICAP status code %q: %w", parts[1], err)
|
||||
}
|
||||
return code, nil
|
||||
}
|
||||
|
||||
// chunkedReader reads ICAP chunked encoding (same as HTTP chunked, terminated by "0\r\n\r\n").
|
||||
type chunkedReader struct {
|
||||
r *bufio.Reader
|
||||
remaining int
|
||||
done bool
|
||||
}
|
||||
|
||||
func newChunkedReader(r *bufio.Reader) *chunkedReader {
|
||||
return &chunkedReader{r: r}
|
||||
}
|
||||
|
||||
func (cr *chunkedReader) Read(p []byte) (int, error) {
|
||||
if cr.done {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
if cr.remaining == 0 {
|
||||
// Read chunk size line
|
||||
line, err := cr.r.ReadString('\n')
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
line = strings.TrimSpace(line)
|
||||
|
||||
// Strip any chunk extensions
|
||||
if idx := strings.Index(line, ";"); idx >= 0 {
|
||||
line = line[:idx]
|
||||
}
|
||||
|
||||
size, err := strconv.ParseInt(line, 16, 64)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parse chunk size %q: %w", line, err)
|
||||
}
|
||||
|
||||
if size == 0 {
|
||||
cr.done = true
|
||||
// Consume trailing \r\n
|
||||
_, _ = cr.r.ReadString('\n')
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
if size < 0 || size > icapMaxRespSize {
|
||||
return 0, fmt.Errorf("chunk size %d out of range (max %d)", size, icapMaxRespSize)
|
||||
}
|
||||
|
||||
cr.remaining = int(size)
|
||||
}
|
||||
|
||||
toRead := len(p)
|
||||
if toRead > cr.remaining {
|
||||
toRead = cr.remaining
|
||||
}
|
||||
|
||||
n, err := cr.r.Read(p[:toRead])
|
||||
cr.remaining -= n
|
||||
|
||||
if cr.remaining == 0 {
|
||||
// Consume chunk-terminating \r\n
|
||||
_, _ = cr.r.ReadString('\n')
|
||||
}
|
||||
|
||||
return n, err
|
||||
}
|
||||
21
client/inspect/listener.go
Normal file
21
client/inspect/listener.go
Normal file
@@ -0,0 +1,21 @@
|
||||
//go:build !linux
|
||||
|
||||
package inspect
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// newTPROXYListener is not supported on non-Linux platforms.
|
||||
func newTPROXYListener(_ *log.Entry, addr netip.AddrPort, _ netip.Prefix) (net.Listener, error) {
|
||||
return nil, fmt.Errorf("TPROXY listener not supported on this platform (requested %s)", addr)
|
||||
}
|
||||
|
||||
// getOriginalDst is not supported on non-Linux platforms.
|
||||
func getOriginalDst(_ net.Conn) (netip.AddrPort, error) {
|
||||
return netip.AddrPort{}, fmt.Errorf("SO_ORIGINAL_DST not supported on this platform")
|
||||
}
|
||||
89
client/inspect/listener_linux.go
Normal file
89
client/inspect/listener_linux.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package inspect
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"unsafe"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// newTPROXYListener creates a TCP listener for the transparent proxy.
|
||||
// After nftables REDIRECT, accepted connections have LocalAddr = WG_IP:proxy_port.
|
||||
// The original destination is retrieved via getsockopt(SO_ORIGINAL_DST).
|
||||
func newTPROXYListener(logger *log.Entry, addr netip.AddrPort, _ netip.Prefix) (net.Listener, error) {
|
||||
ln, err := net.Listen("tcp", addr.String())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("listen on %s: %w", addr, err)
|
||||
}
|
||||
|
||||
logger.Infof("inspect: listener started on %s", ln.Addr())
|
||||
return ln, nil
|
||||
}
|
||||
|
||||
// getOriginalDst reads the original destination from conntrack via SO_ORIGINAL_DST.
|
||||
// This is set by the kernel when the connection was REDIRECT'd/DNAT'd.
|
||||
// Tries IPv4 first, then falls back to IPv6 (IP6T_SO_ORIGINAL_DST).
|
||||
func getOriginalDst(conn net.Conn) (netip.AddrPort, error) {
|
||||
tc, ok := conn.(*net.TCPConn)
|
||||
if !ok {
|
||||
return netip.AddrPort{}, fmt.Errorf("not a TCPConn")
|
||||
}
|
||||
|
||||
raw, err := tc.SyscallConn()
|
||||
if err != nil {
|
||||
return netip.AddrPort{}, fmt.Errorf("get syscall conn: %w", err)
|
||||
}
|
||||
|
||||
var origDst netip.AddrPort
|
||||
var sockErr error
|
||||
if err := raw.Control(func(fd uintptr) {
|
||||
// Try IPv4 first (SO_ORIGINAL_DST = 80)
|
||||
var sa4 unix.RawSockaddrInet4
|
||||
sa4Len := uint32(unsafe.Sizeof(sa4))
|
||||
_, _, errno := unix.Syscall6(
|
||||
unix.SYS_GETSOCKOPT,
|
||||
fd,
|
||||
unix.SOL_IP,
|
||||
80, // SO_ORIGINAL_DST
|
||||
uintptr(unsafe.Pointer(&sa4)),
|
||||
uintptr(unsafe.Pointer(&sa4Len)),
|
||||
0,
|
||||
)
|
||||
if errno == 0 {
|
||||
addr := netip.AddrFrom4(sa4.Addr)
|
||||
port := uint16(sa4.Port>>8) | uint16(sa4.Port<<8)
|
||||
origDst = netip.AddrPortFrom(addr.Unmap(), port)
|
||||
return
|
||||
}
|
||||
|
||||
// Fall back to IPv6 (IP6T_SO_ORIGINAL_DST = 80 on SOL_IPV6)
|
||||
var sa6 unix.RawSockaddrInet6
|
||||
sa6Len := uint32(unsafe.Sizeof(sa6))
|
||||
_, _, errno = unix.Syscall6(
|
||||
unix.SYS_GETSOCKOPT,
|
||||
fd,
|
||||
unix.SOL_IPV6,
|
||||
80, // IP6T_SO_ORIGINAL_DST
|
||||
uintptr(unsafe.Pointer(&sa6)),
|
||||
uintptr(unsafe.Pointer(&sa6Len)),
|
||||
0,
|
||||
)
|
||||
if errno != 0 {
|
||||
sockErr = fmt.Errorf("getsockopt SO_ORIGINAL_DST (v4 and v6): %w", errno)
|
||||
return
|
||||
}
|
||||
addr := netip.AddrFrom16(sa6.Addr)
|
||||
port := uint16(sa6.Port>>8) | uint16(sa6.Port<<8)
|
||||
origDst = netip.AddrPortFrom(addr.Unmap(), port)
|
||||
}); err != nil {
|
||||
return netip.AddrPort{}, fmt.Errorf("control raw conn: %w", err)
|
||||
}
|
||||
if sockErr != nil {
|
||||
return netip.AddrPort{}, sockErr
|
||||
}
|
||||
|
||||
return origDst, nil
|
||||
}
|
||||
200
client/inspect/mitm.go
Normal file
200
client/inspect/mitm.go
Normal file
@@ -0,0 +1,200 @@
|
||||
package inspect
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"fmt"
|
||||
"math/big"
|
||||
mrand "math/rand/v2"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
// certCacheSize is the maximum number of cached leaf certificates.
|
||||
certCacheSize = 1024
|
||||
// certTTL is how long generated certificates remain valid.
|
||||
certTTL = 24 * time.Hour
|
||||
)
|
||||
|
||||
// certCache is a bounded LRU cache for generated TLS certificates.
|
||||
type certCache struct {
|
||||
mu sync.Mutex
|
||||
entries map[string]*certEntry
|
||||
// order tracks LRU eviction, most recent at end.
|
||||
order []string
|
||||
maxSize int
|
||||
}
|
||||
|
||||
type certEntry struct {
|
||||
cert *tls.Certificate
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
func newCertCache(maxSize int) *certCache {
|
||||
return &certCache{
|
||||
entries: make(map[string]*certEntry, maxSize),
|
||||
order: make([]string, 0, maxSize),
|
||||
maxSize: maxSize,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *certCache) get(hostname string) (*tls.Certificate, bool) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
entry, ok := c.entries[hostname]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
if time.Now().After(entry.expiresAt) {
|
||||
c.removeLocked(hostname)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Move to end (most recently used)
|
||||
c.touchLocked(hostname)
|
||||
return entry.cert, true
|
||||
}
|
||||
|
||||
func (c *certCache) put(hostname string, cert *tls.Certificate) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
// Jitter the TTL by +/- 20% to prevent thundering herd on expiry.
|
||||
jitter := time.Duration(float64(certTTL) * (0.8 + 0.4*mrand.Float64()))
|
||||
|
||||
if _, exists := c.entries[hostname]; exists {
|
||||
c.entries[hostname] = &certEntry{
|
||||
cert: cert,
|
||||
expiresAt: time.Now().Add(jitter),
|
||||
}
|
||||
c.touchLocked(hostname)
|
||||
return
|
||||
}
|
||||
|
||||
// Evict oldest if at capacity
|
||||
for len(c.entries) >= c.maxSize && len(c.order) > 0 {
|
||||
c.removeLocked(c.order[0])
|
||||
}
|
||||
|
||||
c.entries[hostname] = &certEntry{
|
||||
cert: cert,
|
||||
expiresAt: time.Now().Add(jitter),
|
||||
}
|
||||
c.order = append(c.order, hostname)
|
||||
}
|
||||
|
||||
func (c *certCache) touchLocked(hostname string) {
|
||||
for i, h := range c.order {
|
||||
if h == hostname {
|
||||
c.order = append(c.order[:i], c.order[i+1:]...)
|
||||
c.order = append(c.order, hostname)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *certCache) removeLocked(hostname string) {
|
||||
delete(c.entries, hostname)
|
||||
for i, h := range c.order {
|
||||
if h == hostname {
|
||||
c.order = append(c.order[:i], c.order[i+1:]...)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// CertProvider generates TLS certificates on the fly, signed by a CA.
|
||||
// Generated certificates are cached in an LRU cache.
|
||||
type CertProvider struct {
|
||||
ca *x509.Certificate
|
||||
caKey crypto.PrivateKey
|
||||
cache *certCache
|
||||
}
|
||||
|
||||
// NewCertProvider creates a certificate provider using the given CA.
|
||||
func NewCertProvider(ca *x509.Certificate, caKey crypto.PrivateKey) *CertProvider {
|
||||
return &CertProvider{
|
||||
ca: ca,
|
||||
caKey: caKey,
|
||||
cache: newCertCache(certCacheSize),
|
||||
}
|
||||
}
|
||||
|
||||
// GetCertificate returns a TLS certificate for the given hostname,
|
||||
// generating and caching one if necessary.
|
||||
func (p *CertProvider) GetCertificate(hostname string) (*tls.Certificate, error) {
|
||||
if cert, ok := p.cache.get(hostname); ok {
|
||||
return cert, nil
|
||||
}
|
||||
|
||||
cert, err := p.generateCert(hostname)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate cert for %s: %w", hostname, err)
|
||||
}
|
||||
|
||||
p.cache.put(hostname, cert)
|
||||
return cert, nil
|
||||
}
|
||||
|
||||
// GetTLSConfig returns a tls.Config that dynamically provides certificates
|
||||
// for any hostname using the MITM CA.
|
||||
func (p *CertProvider) GetTLSConfig() *tls.Config {
|
||||
return &tls.Config{
|
||||
GetCertificate: func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
return p.GetCertificate(hello.ServerName)
|
||||
},
|
||||
NextProtos: []string{"h2", "http/1.1"},
|
||||
MinVersion: tls.VersionTLS12,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *CertProvider) generateCert(hostname string) (*tls.Certificate, error) {
|
||||
serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate serial number: %w", err)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
template := &x509.Certificate{
|
||||
SerialNumber: serialNumber,
|
||||
Subject: pkix.Name{
|
||||
CommonName: hostname,
|
||||
},
|
||||
NotBefore: now.Add(-5 * time.Minute),
|
||||
NotAfter: now.Add(certTTL),
|
||||
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{
|
||||
x509.ExtKeyUsageServerAuth,
|
||||
},
|
||||
DNSNames: []string{hostname},
|
||||
}
|
||||
|
||||
leafKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate leaf key: %w", err)
|
||||
}
|
||||
|
||||
certDER, err := x509.CreateCertificate(rand.Reader, template, p.ca, &leafKey.PublicKey, p.caKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sign leaf certificate: %w", err)
|
||||
}
|
||||
|
||||
leafCert, err := x509.ParseCertificate(certDER)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse generated certificate: %w", err)
|
||||
}
|
||||
|
||||
return &tls.Certificate{
|
||||
Certificate: [][]byte{certDER, p.ca.Raw},
|
||||
PrivateKey: leafKey,
|
||||
Leaf: leafCert,
|
||||
}, nil
|
||||
}
|
||||
133
client/inspect/mitm_test.go
Normal file
133
client/inspect/mitm_test.go
Normal file
@@ -0,0 +1,133 @@
|
||||
package inspect
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"math/big"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func generateTestCA(t *testing.T) (*x509.Certificate, *ecdsa.PrivateKey) {
|
||||
t.Helper()
|
||||
|
||||
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
template := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
Subject: pkix.Name{
|
||||
CommonName: "Test CA",
|
||||
},
|
||||
NotBefore: time.Now().Add(-time.Hour),
|
||||
NotAfter: time.Now().Add(24 * time.Hour),
|
||||
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
|
||||
BasicConstraintsValid: true,
|
||||
IsCA: true,
|
||||
}
|
||||
|
||||
certDER, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key)
|
||||
require.NoError(t, err)
|
||||
|
||||
cert, err := x509.ParseCertificate(certDER)
|
||||
require.NoError(t, err)
|
||||
|
||||
return cert, key
|
||||
}
|
||||
|
||||
func TestCertProvider_GetCertificate(t *testing.T) {
|
||||
ca, caKey := generateTestCA(t)
|
||||
provider := NewCertProvider(ca, caKey)
|
||||
|
||||
cert, err := provider.GetCertificate("example.com")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cert)
|
||||
|
||||
// Verify the leaf certificate
|
||||
assert.Equal(t, "example.com", cert.Leaf.Subject.CommonName)
|
||||
assert.Contains(t, cert.Leaf.DNSNames, "example.com")
|
||||
|
||||
// Verify chain: leaf + CA
|
||||
assert.Len(t, cert.Certificate, 2)
|
||||
|
||||
// Verify leaf is signed by our CA
|
||||
pool := x509.NewCertPool()
|
||||
pool.AddCert(ca)
|
||||
_, err = cert.Leaf.Verify(x509.VerifyOptions{
|
||||
Roots: pool,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestCertProvider_CachesResults(t *testing.T) {
|
||||
ca, caKey := generateTestCA(t)
|
||||
provider := NewCertProvider(ca, caKey)
|
||||
|
||||
cert1, err := provider.GetCertificate("cached.example.com")
|
||||
require.NoError(t, err)
|
||||
|
||||
cert2, err := provider.GetCertificate("cached.example.com")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Same pointer = cached
|
||||
assert.Equal(t, cert1, cert2)
|
||||
}
|
||||
|
||||
func TestCertProvider_DifferentHostsDifferentCerts(t *testing.T) {
|
||||
ca, caKey := generateTestCA(t)
|
||||
provider := NewCertProvider(ca, caKey)
|
||||
|
||||
cert1, err := provider.GetCertificate("a.example.com")
|
||||
require.NoError(t, err)
|
||||
|
||||
cert2, err := provider.GetCertificate("b.example.com")
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.NotEqual(t, cert1.Leaf.SerialNumber, cert2.Leaf.SerialNumber)
|
||||
}
|
||||
|
||||
func TestCertProvider_TLSConfigHandshake(t *testing.T) {
|
||||
ca, caKey := generateTestCA(t)
|
||||
provider := NewCertProvider(ca, caKey)
|
||||
|
||||
tlsConfig := provider.GetTLSConfig()
|
||||
require.NotNil(t, tlsConfig)
|
||||
require.NotNil(t, tlsConfig.GetCertificate)
|
||||
|
||||
// Simulate a ClientHelloInfo
|
||||
hello := &tls.ClientHelloInfo{
|
||||
ServerName: "handshake.example.com",
|
||||
}
|
||||
|
||||
cert, err := tlsConfig.GetCertificate(hello)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "handshake.example.com", cert.Leaf.Subject.CommonName)
|
||||
}
|
||||
|
||||
func TestCertCache_Eviction(t *testing.T) {
|
||||
cache := newCertCache(3)
|
||||
|
||||
for i := range 5 {
|
||||
hostname := string(rune('a'+i)) + ".example.com"
|
||||
cache.put(hostname, &tls.Certificate{})
|
||||
}
|
||||
|
||||
// Only 3 should remain (c, d, e - the most recent)
|
||||
assert.Len(t, cache.entries, 3)
|
||||
|
||||
_, ok := cache.get("a.example.com")
|
||||
assert.False(t, ok, "oldest entry should be evicted")
|
||||
|
||||
_, ok = cache.get("b.example.com")
|
||||
assert.False(t, ok, "second oldest should be evicted")
|
||||
|
||||
_, ok = cache.get("e.example.com")
|
||||
assert.True(t, ok, "newest entry should exist")
|
||||
}
|
||||
109
client/inspect/peek.go
Normal file
109
client/inspect/peek.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package inspect
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
)
|
||||
|
||||
// peekConn wraps a net.Conn with a buffer that allows reading ahead
|
||||
// without consuming data. Subsequent Read calls return the buffered
|
||||
// bytes first, then read from the underlying connection.
|
||||
type peekConn struct {
|
||||
net.Conn
|
||||
buf bytes.Buffer
|
||||
// peeked holds the raw bytes that were peeked, available for replay.
|
||||
peeked []byte
|
||||
}
|
||||
|
||||
// newPeekConn wraps conn for peek-ahead reading.
|
||||
func newPeekConn(conn net.Conn) *peekConn {
|
||||
return &peekConn{Conn: conn}
|
||||
}
|
||||
|
||||
// Peek reads exactly n bytes from the connection without consuming them.
|
||||
// The peeked bytes are replayed on subsequent Read calls.
|
||||
// Peek may only be called once; calling it again returns an error.
|
||||
func (c *peekConn) Peek(n int) ([]byte, error) {
|
||||
if c.peeked != nil {
|
||||
return nil, fmt.Errorf("peek already called")
|
||||
}
|
||||
|
||||
buf := make([]byte, n)
|
||||
if _, err := io.ReadFull(c.Conn, buf); err != nil {
|
||||
return nil, fmt.Errorf("peek %d bytes: %w", n, err)
|
||||
}
|
||||
|
||||
c.peeked = buf
|
||||
c.buf.Write(buf)
|
||||
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
// PeekAll reads up to n bytes, returning whatever is available.
|
||||
// Unlike Peek, it does not require exactly n bytes.
|
||||
func (c *peekConn) PeekAll(n int) ([]byte, error) {
|
||||
if c.peeked != nil {
|
||||
return nil, fmt.Errorf("peek already called")
|
||||
}
|
||||
|
||||
buf := make([]byte, n)
|
||||
nr, err := c.Conn.Read(buf)
|
||||
if nr > 0 {
|
||||
c.peeked = buf[:nr]
|
||||
c.buf.Write(c.peeked)
|
||||
}
|
||||
if err != nil && nr == 0 {
|
||||
return nil, fmt.Errorf("peek: %w", err)
|
||||
}
|
||||
|
||||
return c.peeked, nil
|
||||
}
|
||||
|
||||
// PeekMore extends the peeked buffer to at least n total bytes.
|
||||
// The buffer is reset and refilled with the extended data.
|
||||
// The returned slice is the internal peeked buffer; callers must not
|
||||
// retain references from prior Peek/PeekMore calls after calling this.
|
||||
func (c *peekConn) PeekMore(n int) ([]byte, error) {
|
||||
if len(c.peeked) >= n {
|
||||
return c.peeked[:n], nil
|
||||
}
|
||||
|
||||
remaining := n - len(c.peeked)
|
||||
extra := make([]byte, remaining)
|
||||
if _, err := io.ReadFull(c.Conn, extra); err != nil {
|
||||
return nil, fmt.Errorf("peek more %d bytes: %w", remaining, err)
|
||||
}
|
||||
|
||||
// Pre-allocate to avoid reallocation detaching previously returned slices.
|
||||
combined := make([]byte, 0, n)
|
||||
combined = append(combined, c.peeked...)
|
||||
combined = append(combined, extra...)
|
||||
c.peeked = combined
|
||||
c.buf.Reset()
|
||||
c.buf.Write(c.peeked)
|
||||
|
||||
return c.peeked, nil
|
||||
}
|
||||
|
||||
// Peeked returns the bytes that were peeked so far, or nil if Peek hasn't been called.
|
||||
func (c *peekConn) Peeked() []byte {
|
||||
return c.peeked
|
||||
}
|
||||
|
||||
// Read returns buffered peek data first, then reads from the underlying connection.
|
||||
func (c *peekConn) Read(p []byte) (int, error) {
|
||||
if c.buf.Len() > 0 {
|
||||
return c.buf.Read(p)
|
||||
}
|
||||
return c.Conn.Read(p)
|
||||
}
|
||||
|
||||
// reader returns an io.Reader that replays buffered bytes then reads from conn.
|
||||
func (c *peekConn) reader() io.Reader {
|
||||
if c.buf.Len() > 0 {
|
||||
return io.MultiReader(&c.buf, c.Conn)
|
||||
}
|
||||
return c.Conn
|
||||
}
|
||||
482
client/inspect/proxy.go
Normal file
482
client/inspect/proxy.go
Normal file
@@ -0,0 +1,482 @@
|
||||
package inspect
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// ErrBlocked is returned when a connection is denied by proxy policy.
|
||||
var ErrBlocked = errors.New("connection blocked by proxy policy")
|
||||
|
||||
const (
|
||||
// headerReadTimeout is the deadline for reading the initial protocol header.
|
||||
// Prevents slow loris attacks where a client opens a connection but sends data slowly.
|
||||
headerReadTimeout = 10 * time.Second
|
||||
|
||||
// idleTimeout is the deadline for idle connections between HTTP requests.
|
||||
idleTimeout = 120 * time.Second
|
||||
)
|
||||
|
||||
// Proxy is the inspection engine for traffic passing through a NetBird
|
||||
// routing peer. It handles protocol detection, rule evaluation, MITM TLS
|
||||
// decryption, ICAP delegation, and external proxy forwarding.
|
||||
type Proxy struct {
|
||||
config Config
|
||||
rules *RuleEngine
|
||||
certs *CertProvider
|
||||
icap *ICAPClient
|
||||
// envoy is nil unless mode is ModeEnvoy.
|
||||
envoy *envoyManager
|
||||
// dialer is the outbound dialer (with SO_MARK cleared on Linux).
|
||||
dialer net.Dialer
|
||||
log *log.Entry
|
||||
// wgNetwork is the WG overlay prefix; dial targets inside it are blocked.
|
||||
wgNetwork netip.Prefix
|
||||
// localIPs reports the routing peer's own IPs; dial targets are blocked.
|
||||
localIPs LocalIPChecker
|
||||
// listener is the TPROXY/REDIRECT listener for kernel mode.
|
||||
listener net.Listener
|
||||
|
||||
mu sync.RWMutex
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// LocalIPChecker reports whether an IP belongs to the local machine.
|
||||
type LocalIPChecker interface {
|
||||
IsLocalIP(netip.Addr) bool
|
||||
}
|
||||
|
||||
// New creates a transparent proxy with the given configuration.
|
||||
func New(ctx context.Context, logger *log.Entry, config Config) (*Proxy, error) {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
|
||||
p := &Proxy{
|
||||
config: config,
|
||||
rules: NewRuleEngine(logger, config.DefaultAction),
|
||||
dialer: newOutboundDialer(),
|
||||
log: logger,
|
||||
wgNetwork: config.WGNetwork,
|
||||
localIPs: config.LocalIPChecker,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
p.rules.UpdateRules(config.Rules, config.DefaultAction)
|
||||
|
||||
// Initialize MITM certificate provider
|
||||
if config.TLS != nil {
|
||||
p.certs = NewCertProvider(config.TLS.CA, config.TLS.CAKey)
|
||||
}
|
||||
|
||||
// Initialize ICAP client
|
||||
if config.ICAP != nil {
|
||||
p.icap = NewICAPClient(logger, config.ICAP)
|
||||
}
|
||||
|
||||
// Start envoy sidecar if configured
|
||||
if config.Mode == ModeEnvoy {
|
||||
envoyLog := logger.WithField("sidecar", "envoy")
|
||||
em, err := startEnvoy(ctx, envoyLog, config)
|
||||
if err != nil {
|
||||
cancel()
|
||||
return nil, fmt.Errorf("start envoy sidecar: %w", err)
|
||||
}
|
||||
p.envoy = em
|
||||
}
|
||||
|
||||
// Start TPROXY listener for kernel mode
|
||||
if config.ListenAddr.IsValid() {
|
||||
ln, err := newTPROXYListener(logger, config.ListenAddr, netip.Prefix{})
|
||||
if err != nil {
|
||||
cancel()
|
||||
return nil, fmt.Errorf("start TPROXY listener on %s: %w", config.ListenAddr, err)
|
||||
}
|
||||
p.listener = ln
|
||||
go p.acceptLoop(ln)
|
||||
}
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
// HandleTCP is the entry point for TCP connections from the userspace forwarder.
|
||||
// It determines the protocol (TLS or plaintext HTTP), evaluates rules,
|
||||
// and either blocks, passes through, inspects, or forwards to an external proxy.
|
||||
func (p *Proxy) HandleTCP(ctx context.Context, clientConn net.Conn, dst netip.AddrPort, src SourceInfo) error {
|
||||
defer func() {
|
||||
if err := clientConn.Close(); err != nil {
|
||||
p.log.Debugf("close client conn: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
p.mu.RLock()
|
||||
mode := p.config.Mode
|
||||
p.mu.RUnlock()
|
||||
|
||||
if mode == ModeExternal {
|
||||
pconn := newPeekConn(clientConn)
|
||||
return p.handleExternal(ctx, pconn, dst)
|
||||
}
|
||||
|
||||
// Envoy and builtin modes both peek the protocol header for rule evaluation.
|
||||
// Envoy mode forwards non-blocked traffic to envoy; builtin mode handles all locally.
|
||||
// TLS blocks are handled by Go (instant close) since envoy can't cleanly RST a TLS connection.
|
||||
|
||||
// Built-in and envoy mode: peek 5 bytes (TLS record header size) to determine protocol.
|
||||
// Set a read deadline to prevent slow loris attacks.
|
||||
if err := clientConn.SetReadDeadline(time.Now().Add(headerReadTimeout)); err != nil {
|
||||
return fmt.Errorf("set read deadline: %w", err)
|
||||
}
|
||||
pconn := newPeekConn(clientConn)
|
||||
header, err := pconn.Peek(5)
|
||||
if err != nil {
|
||||
return fmt.Errorf("peek protocol header: %w", err)
|
||||
}
|
||||
if err := clientConn.SetReadDeadline(time.Time{}); err != nil {
|
||||
return fmt.Errorf("clear read deadline: %w", err)
|
||||
}
|
||||
|
||||
if isTLSHandshake(header[0]) {
|
||||
return p.handleTLS(ctx, pconn, dst, src)
|
||||
}
|
||||
|
||||
if isHTTPMethod(header) {
|
||||
return p.handlePlainHTTP(ctx, pconn, dst, src)
|
||||
}
|
||||
|
||||
// Not TLS and not HTTP: evaluate rules with ProtoOther.
|
||||
// If no rule explicitly allows "other", this falls through to the default action.
|
||||
action := p.rules.Evaluate(src.IP, "", dst.Addr(), dst.Port(), ProtoOther, "")
|
||||
if action == ActionAllow {
|
||||
remote, err := p.dialTCP(ctx, dst)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dial for passthrough: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := remote.Close(); err != nil {
|
||||
p.log.Debugf("close remote conn: %v", err)
|
||||
}
|
||||
}()
|
||||
return relay(ctx, pconn, remote)
|
||||
}
|
||||
|
||||
p.log.Debugf("block: non-HTTP/TLS to %s (action=%s, first bytes: %x)", dst, action, header)
|
||||
return ErrBlocked
|
||||
}
|
||||
|
||||
// InspectTCP evaluates rules for a TCP connection and returns the result.
|
||||
// Unlike HandleTCP, it can return early for allow decisions, letting the caller
|
||||
// handle the relay (USP forwarder passthrough optimization).
|
||||
//
|
||||
// When InspectResult.PassthroughConn is non-nil, ownership transfers to the caller:
|
||||
// the caller must close the connection and relay traffic. The engine does not close it.
|
||||
//
|
||||
// When PassthroughConn is nil, the engine handled everything internally
|
||||
// (block, inspect/MITM, or plain HTTP inspection) and closed the connection.
|
||||
func (p *Proxy) InspectTCP(ctx context.Context, clientConn net.Conn, dst netip.AddrPort, src SourceInfo) (InspectResult, error) {
|
||||
p.mu.RLock()
|
||||
mode := p.config.Mode
|
||||
envoy := p.envoy
|
||||
p.mu.RUnlock()
|
||||
|
||||
// External mode: handle internally, engine owns the connection.
|
||||
if mode == ModeExternal {
|
||||
defer func() {
|
||||
if err := clientConn.Close(); err != nil {
|
||||
p.log.Debugf("close client conn: %v", err)
|
||||
}
|
||||
}()
|
||||
pconn := newPeekConn(clientConn)
|
||||
err := p.handleExternal(ctx, pconn, dst)
|
||||
return InspectResult{Action: ActionAllow}, err
|
||||
}
|
||||
|
||||
// Peek protocol header.
|
||||
if err := clientConn.SetReadDeadline(time.Now().Add(headerReadTimeout)); err != nil {
|
||||
clientConn.Close()
|
||||
return InspectResult{}, fmt.Errorf("set read deadline: %w", err)
|
||||
}
|
||||
pconn := newPeekConn(clientConn)
|
||||
header, err := pconn.Peek(5)
|
||||
if err != nil {
|
||||
clientConn.Close()
|
||||
return InspectResult{}, fmt.Errorf("peek protocol header: %w", err)
|
||||
}
|
||||
if err := clientConn.SetReadDeadline(time.Time{}); err != nil {
|
||||
clientConn.Close()
|
||||
return InspectResult{}, fmt.Errorf("clear read deadline: %w", err)
|
||||
}
|
||||
|
||||
// TLS: may return passthrough for allow.
|
||||
if isTLSHandshake(header[0]) {
|
||||
result, err := p.inspectTLS(ctx, pconn, dst, src)
|
||||
if err != nil && result.PassthroughConn == nil {
|
||||
clientConn.Close()
|
||||
return result, err
|
||||
}
|
||||
// Envoy mode: forward allowed TLS to envoy instead of returning passthrough.
|
||||
if result.PassthroughConn != nil && envoy != nil {
|
||||
defer clientConn.Close()
|
||||
envoyErr := p.forwardToEnvoy(ctx, pconn, dst, src, envoy)
|
||||
return InspectResult{Action: ActionAllow}, envoyErr
|
||||
}
|
||||
return result, err
|
||||
}
|
||||
|
||||
// Plain HTTP: in envoy mode, forward to envoy for L7 processing.
|
||||
// In builtin mode, inspect per-request locally.
|
||||
if isHTTPMethod(header) {
|
||||
defer func() {
|
||||
if err := clientConn.Close(); err != nil {
|
||||
p.log.Debugf("close client conn: %v", err)
|
||||
}
|
||||
}()
|
||||
if envoy != nil {
|
||||
err := p.forwardToEnvoy(ctx, pconn, dst, src, envoy)
|
||||
return InspectResult{Action: ActionAllow}, err
|
||||
}
|
||||
err := p.handlePlainHTTP(ctx, pconn, dst, src)
|
||||
return InspectResult{Action: ActionInspect}, err
|
||||
}
|
||||
|
||||
// Other protocol: evaluate rules.
|
||||
action := p.rules.Evaluate(src.IP, "", dst.Addr(), dst.Port(), ProtoOther, "")
|
||||
if action == ActionAllow {
|
||||
// Envoy mode: forward to envoy.
|
||||
if envoy != nil {
|
||||
defer clientConn.Close()
|
||||
err := p.forwardToEnvoy(ctx, pconn, dst, src, envoy)
|
||||
return InspectResult{Action: ActionAllow}, err
|
||||
}
|
||||
return InspectResult{Action: ActionAllow, PassthroughConn: pconn}, nil
|
||||
}
|
||||
|
||||
p.log.Debugf("block: non-HTTP/TLS to %s (action=%s, first bytes: %x)", dst, action, header)
|
||||
clientConn.Close()
|
||||
return InspectResult{Action: ActionBlock}, ErrBlocked
|
||||
}
|
||||
|
||||
// HandleUDPPacket inspects a UDP packet for QUIC Initial packets.
|
||||
// Returns the action to take: ActionAllow to continue normal forwarding,
|
||||
// ActionBlock to drop the packet.
|
||||
// Non-QUIC packets always return ActionAllow.
|
||||
func (p *Proxy) HandleUDPPacket(data []byte, dst netip.AddrPort, src SourceInfo) Action {
|
||||
if len(data) < 5 {
|
||||
return ActionAllow
|
||||
}
|
||||
|
||||
// Check for QUIC Long Header
|
||||
if data[0]&0x80 == 0 {
|
||||
return ActionAllow
|
||||
}
|
||||
|
||||
sni, err := ExtractQUICSNI(data)
|
||||
if err != nil {
|
||||
// Can't parse QUIC, allow through (could be non-QUIC UDP)
|
||||
p.log.Tracef("QUIC SNI extraction failed for %s: %v", dst, err)
|
||||
return ActionAllow
|
||||
}
|
||||
|
||||
if sni == "" {
|
||||
return ActionAllow
|
||||
}
|
||||
|
||||
action := p.rules.Evaluate(src.IP, sni, dst.Addr(), dst.Port(), ProtoH3, "")
|
||||
|
||||
if action == ActionBlock {
|
||||
p.log.Debugf("block: QUIC to %s (SNI=%s)", dst, sni.PunycodeString())
|
||||
return ActionBlock
|
||||
}
|
||||
|
||||
// QUIC can't be MITMed, treat Inspect as Allow
|
||||
if action == ActionInspect {
|
||||
p.log.Debugf("allow: QUIC to %s (SNI=%s), MITM not supported for QUIC", dst, sni.PunycodeString())
|
||||
} else {
|
||||
p.log.Tracef("allow: QUIC to %s (SNI=%s)", dst, sni.PunycodeString())
|
||||
}
|
||||
|
||||
return ActionAllow
|
||||
}
|
||||
|
||||
// handlePlainHTTP handles plaintext HTTP connections.
|
||||
func (p *Proxy) handlePlainHTTP(ctx context.Context, pconn *peekConn, dst netip.AddrPort, src SourceInfo) error {
|
||||
remote, err := p.dialTCP(ctx, dst)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dial %s: %w", dst, err)
|
||||
}
|
||||
defer func() {
|
||||
if err := remote.Close(); err != nil {
|
||||
p.log.Debugf("close remote for %s: %v", dst, err)
|
||||
}
|
||||
}()
|
||||
|
||||
// For plaintext HTTP, always inspect (we can see the traffic)
|
||||
return p.inspectHTTP(ctx, pconn, remote, dst, "", src, "http/1.1")
|
||||
}
|
||||
|
||||
// UpdateConfig replaces the inspection engine configuration at runtime.
|
||||
func (p *Proxy) UpdateConfig(config Config) {
|
||||
p.log.Debugf("config update: mode=%s rules=%d default=%s has_tls=%v has_icap=%v",
|
||||
config.Mode, len(config.Rules), config.DefaultAction, config.TLS != nil, config.ICAP != nil)
|
||||
|
||||
p.mu.Lock()
|
||||
|
||||
p.config = config
|
||||
p.rules.UpdateRules(config.Rules, config.DefaultAction)
|
||||
|
||||
// Update MITM provider
|
||||
if config.TLS != nil {
|
||||
p.certs = NewCertProvider(config.TLS.CA, config.TLS.CAKey)
|
||||
} else {
|
||||
p.certs = nil
|
||||
}
|
||||
|
||||
// Swap ICAP client under lock, close the old one outside to avoid blocking.
|
||||
var oldICAP *ICAPClient
|
||||
if config.ICAP != nil {
|
||||
oldICAP = p.icap
|
||||
p.icap = NewICAPClient(p.log, config.ICAP)
|
||||
} else {
|
||||
oldICAP = p.icap
|
||||
p.icap = nil
|
||||
}
|
||||
|
||||
// If switching away from envoy mode, clear and stop the old envoy.
|
||||
var oldEnvoy *envoyManager
|
||||
if config.Mode != ModeEnvoy && p.envoy != nil {
|
||||
oldEnvoy = p.envoy
|
||||
p.envoy = nil
|
||||
}
|
||||
|
||||
envoy := p.envoy
|
||||
|
||||
p.mu.Unlock()
|
||||
|
||||
if oldICAP != nil {
|
||||
oldICAP.Close()
|
||||
}
|
||||
|
||||
if oldEnvoy != nil {
|
||||
oldEnvoy.Stop()
|
||||
}
|
||||
|
||||
// Reload envoy config if still in envoy mode.
|
||||
if envoy != nil && config.Mode == ModeEnvoy {
|
||||
if err := envoy.Reload(config); err != nil {
|
||||
p.log.Errorf("inspect: envoy config reload: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Mode returns the current proxy operating mode.
|
||||
func (p *Proxy) Mode() ProxyMode {
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
return p.config.Mode
|
||||
}
|
||||
|
||||
// ListenPort returns the port to use for kernel-mode nftables REDIRECT.
|
||||
// For builtin mode: the TPROXY listener port.
|
||||
// For envoy mode: the envoy listener port (nftables redirects directly to envoy).
|
||||
// Returns 0 if no listener is active.
|
||||
func (p *Proxy) ListenPort() uint16 {
|
||||
p.mu.RLock()
|
||||
envoy := p.envoy
|
||||
p.mu.RUnlock()
|
||||
|
||||
if envoy != nil {
|
||||
return envoy.listenPort
|
||||
}
|
||||
if p.listener == nil {
|
||||
return 0
|
||||
}
|
||||
tcpAddr, ok := p.listener.Addr().(*net.TCPAddr)
|
||||
if !ok {
|
||||
return 0
|
||||
}
|
||||
return uint16(tcpAddr.Port)
|
||||
}
|
||||
|
||||
// Close shuts down the proxy and releases resources.
|
||||
func (p *Proxy) Close() error {
|
||||
p.cancel()
|
||||
|
||||
p.mu.Lock()
|
||||
envoy := p.envoy
|
||||
p.envoy = nil
|
||||
icap := p.icap
|
||||
p.icap = nil
|
||||
p.mu.Unlock()
|
||||
|
||||
if envoy != nil {
|
||||
envoy.Stop()
|
||||
}
|
||||
|
||||
if p.listener != nil {
|
||||
if err := p.listener.Close(); err != nil {
|
||||
p.log.Debugf("close TPROXY listener: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if icap != nil {
|
||||
icap.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// acceptLoop accepts connections from the redirected listener (kernel mode).
|
||||
// Connections arrive via nftables REDIRECT; original destination is read from conntrack.
|
||||
func (p *Proxy) acceptLoop(ln net.Listener) {
|
||||
for {
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
if p.ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
p.log.Debugf("accept error: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
go func() {
|
||||
// Read original destination from conntrack (SO_ORIGINAL_DST).
|
||||
// nftables REDIRECT changes dst to the local WG IP:proxy_port,
|
||||
// but conntrack preserves the real destination.
|
||||
dstAddr, err := getOriginalDst(conn)
|
||||
if err != nil {
|
||||
p.log.Debugf("get original dst: %v", err)
|
||||
if closeErr := conn.Close(); closeErr != nil {
|
||||
p.log.Debugf("close conn: %v", closeErr)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
p.log.Tracef("accepted: %s -> %s (original dst %s)",
|
||||
conn.RemoteAddr(), conn.LocalAddr(), dstAddr)
|
||||
|
||||
srcAddr, err := netip.ParseAddrPort(conn.RemoteAddr().String())
|
||||
if err != nil {
|
||||
p.log.Debugf("parse source: %v", err)
|
||||
if closeErr := conn.Close(); closeErr != nil {
|
||||
p.log.Debugf("close conn: %v", closeErr)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
src := SourceInfo{
|
||||
IP: srcAddr.Addr().Unmap(),
|
||||
}
|
||||
|
||||
if err := p.HandleTCP(p.ctx, conn, dstAddr, src); err != nil && !errors.Is(err, ErrBlocked) {
|
||||
p.log.Debugf("connection to %s: %v", dstAddr, err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
388
client/inspect/quic.go
Normal file
388
client/inspect/quic.go
Normal file
@@ -0,0 +1,388 @@
|
||||
package inspect
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"golang.org/x/crypto/hkdf"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
// QUIC version constants
|
||||
const (
|
||||
quicV1Version uint32 = 0x00000001
|
||||
quicV2Version uint32 = 0x6b3343cf
|
||||
)
|
||||
|
||||
// quicV1Salt is the initial salt for QUIC v1 (RFC 9001 Section 5.2).
|
||||
var quicV1Salt = []byte{
|
||||
0x38, 0x76, 0x2c, 0xf7, 0xf5, 0x59, 0x34, 0xb3,
|
||||
0x4d, 0x17, 0x9a, 0xe6, 0xa4, 0xc8, 0x0c, 0xad,
|
||||
0xcc, 0xbb, 0x7f, 0x0a,
|
||||
}
|
||||
|
||||
// quicV2Salt is the initial salt for QUIC v2 (RFC 9369).
|
||||
var quicV2Salt = []byte{
|
||||
0x0d, 0xed, 0xe3, 0xde, 0xf7, 0x00, 0xa6, 0xdb,
|
||||
0x81, 0x93, 0x81, 0xbe, 0x6e, 0x26, 0x9d, 0xcb,
|
||||
0xf9, 0xbd, 0x2e, 0xd9,
|
||||
}
|
||||
|
||||
// ExtractQUICSNI extracts the SNI from a QUIC Initial packet.
|
||||
// The Initial packet's encryption uses well-known keys derived from the
|
||||
// Destination Connection ID, so any observer can decrypt it (by design).
|
||||
func ExtractQUICSNI(data []byte) (domain.Domain, error) {
|
||||
if len(data) < 5 {
|
||||
return "", fmt.Errorf("packet too short")
|
||||
}
|
||||
|
||||
// Check for QUIC Long Header (form bit set)
|
||||
if data[0]&0x80 == 0 {
|
||||
return "", fmt.Errorf("not a QUIC long header packet")
|
||||
}
|
||||
|
||||
// Version
|
||||
version := binary.BigEndian.Uint32(data[1:5])
|
||||
|
||||
var salt []byte
|
||||
var initialLabel, keyLabel, ivLabel, hpLabel string
|
||||
|
||||
switch version {
|
||||
case quicV1Version:
|
||||
salt = quicV1Salt
|
||||
initialLabel = "client in"
|
||||
keyLabel = "quic key"
|
||||
ivLabel = "quic iv"
|
||||
hpLabel = "quic hp"
|
||||
case quicV2Version:
|
||||
salt = quicV2Salt
|
||||
initialLabel = "client in"
|
||||
keyLabel = "quicv2 key"
|
||||
ivLabel = "quicv2 iv"
|
||||
hpLabel = "quicv2 hp"
|
||||
default:
|
||||
return "", fmt.Errorf("unsupported QUIC version: 0x%08x", version)
|
||||
}
|
||||
|
||||
// Parse Long Header
|
||||
if len(data) < 6 {
|
||||
return "", fmt.Errorf("packet too short for DCID length")
|
||||
}
|
||||
dcidLen := int(data[5])
|
||||
if len(data) < 6+dcidLen+1 {
|
||||
return "", fmt.Errorf("packet too short for DCID")
|
||||
}
|
||||
dcid := data[6 : 6+dcidLen]
|
||||
|
||||
scidLenOff := 6 + dcidLen
|
||||
scidLen := int(data[scidLenOff])
|
||||
tokenLenOff := scidLenOff + 1 + scidLen
|
||||
|
||||
if tokenLenOff >= len(data) {
|
||||
return "", fmt.Errorf("packet too short for token length")
|
||||
}
|
||||
|
||||
// Token length is a variable-length integer
|
||||
tokenLen, tokenLenSize, err := readVarInt(data[tokenLenOff:])
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("read token length: %w", err)
|
||||
}
|
||||
|
||||
payloadLenOff := tokenLenOff + tokenLenSize + int(tokenLen)
|
||||
if payloadLenOff >= len(data) {
|
||||
return "", fmt.Errorf("packet too short for payload length")
|
||||
}
|
||||
|
||||
// Payload length is a variable-length integer
|
||||
payloadLen, payloadLenSize, err := readVarInt(data[payloadLenOff:])
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("read payload length: %w", err)
|
||||
}
|
||||
|
||||
pnOffset := payloadLenOff + payloadLenSize
|
||||
if pnOffset+4 > len(data) {
|
||||
return "", fmt.Errorf("packet too short for packet number")
|
||||
}
|
||||
|
||||
// Derive initial keys
|
||||
clientKey, clientIV, clientHP, err := deriveInitialKeys(dcid, salt, initialLabel, keyLabel, ivLabel, hpLabel)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("derive initial keys: %w", err)
|
||||
}
|
||||
|
||||
// Remove header protection
|
||||
sampleOffset := pnOffset + 4 // sample starts 4 bytes after pn offset
|
||||
if sampleOffset+16 > len(data) {
|
||||
return "", fmt.Errorf("packet too short for HP sample")
|
||||
}
|
||||
sample := data[sampleOffset : sampleOffset+16]
|
||||
|
||||
hpBlock, err := aes.NewCipher(clientHP)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create HP cipher: %w", err)
|
||||
}
|
||||
|
||||
mask := make([]byte, 16)
|
||||
hpBlock.Encrypt(mask, sample)
|
||||
|
||||
// Unmask header byte
|
||||
header := make([]byte, len(data))
|
||||
copy(header, data)
|
||||
header[0] ^= mask[0] & 0x0f // Long header: low 4 bits
|
||||
|
||||
// Determine packet number length
|
||||
pnLen := int(header[0]&0x03) + 1
|
||||
|
||||
// Unmask packet number
|
||||
for i := 0; i < pnLen; i++ {
|
||||
header[pnOffset+i] ^= mask[1+i]
|
||||
}
|
||||
|
||||
// Reconstruct packet number
|
||||
var pn uint32
|
||||
for i := 0; i < pnLen; i++ {
|
||||
pn = (pn << 8) | uint32(header[pnOffset+i])
|
||||
}
|
||||
|
||||
// Build nonce
|
||||
nonce := make([]byte, len(clientIV))
|
||||
copy(nonce, clientIV)
|
||||
for i := 0; i < 4; i++ {
|
||||
nonce[len(nonce)-1-i] ^= byte(pn >> (8 * i))
|
||||
}
|
||||
|
||||
// Decrypt payload
|
||||
block, err := aes.NewCipher(clientKey)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create AES cipher: %w", err)
|
||||
}
|
||||
|
||||
aead, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create AEAD: %w", err)
|
||||
}
|
||||
|
||||
encryptedPayload := header[pnOffset+pnLen : pnOffset+int(payloadLen)]
|
||||
aad := header[:pnOffset+pnLen]
|
||||
|
||||
plaintext, err := aead.Open(nil, nonce, encryptedPayload, aad)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("decrypt QUIC payload: %w", err)
|
||||
}
|
||||
|
||||
// Parse CRYPTO frames to extract ClientHello
|
||||
clientHello, err := extractCryptoFrames(plaintext)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("extract CRYPTO frames: %w", err)
|
||||
}
|
||||
|
||||
info, err := parseHelloBody(clientHello)
|
||||
return info.SNI, err
|
||||
}
|
||||
|
||||
// deriveInitialKeys derives the client's initial encryption keys from the DCID.
|
||||
func deriveInitialKeys(dcid, salt []byte, initialLabel, keyLabel, ivLabel, hpLabel string) (key, iv, hp []byte, err error) {
|
||||
// initial_secret = HKDF-Extract(salt, DCID)
|
||||
initialSecret := hkdf.Extract(sha256.New, dcid, salt)
|
||||
|
||||
// client_initial_secret = HKDF-Expand-Label(initial_secret, initialLabel, "", 32)
|
||||
clientSecret, err := hkdfExpandLabel(initialSecret, initialLabel, nil, 32)
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("derive client secret: %w", err)
|
||||
}
|
||||
|
||||
// client_key = HKDF-Expand-Label(client_secret, keyLabel, "", 16)
|
||||
key, err = hkdfExpandLabel(clientSecret, keyLabel, nil, 16)
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("derive key: %w", err)
|
||||
}
|
||||
|
||||
// client_iv = HKDF-Expand-Label(client_secret, ivLabel, "", 12)
|
||||
iv, err = hkdfExpandLabel(clientSecret, ivLabel, nil, 12)
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("derive IV: %w", err)
|
||||
}
|
||||
|
||||
// client_hp = HKDF-Expand-Label(client_secret, hpLabel, "", 16)
|
||||
hp, err = hkdfExpandLabel(clientSecret, hpLabel, nil, 16)
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("derive HP key: %w", err)
|
||||
}
|
||||
|
||||
return key, iv, hp, nil
|
||||
}
|
||||
|
||||
// hkdfExpandLabel implements TLS 1.3 HKDF-Expand-Label.
|
||||
func hkdfExpandLabel(secret []byte, label string, context []byte, length int) ([]byte, error) {
|
||||
// HkdfLabel = struct {
|
||||
// uint16 length;
|
||||
// opaque label<7..255> = "tls13 " + Label;
|
||||
// opaque context<0..255> = Context;
|
||||
// }
|
||||
fullLabel := "tls13 " + label
|
||||
|
||||
hkdfLabel := make([]byte, 2+1+len(fullLabel)+1+len(context))
|
||||
binary.BigEndian.PutUint16(hkdfLabel[0:2], uint16(length))
|
||||
hkdfLabel[2] = byte(len(fullLabel))
|
||||
copy(hkdfLabel[3:], fullLabel)
|
||||
hkdfLabel[3+len(fullLabel)] = byte(len(context))
|
||||
if len(context) > 0 {
|
||||
copy(hkdfLabel[4+len(fullLabel):], context)
|
||||
}
|
||||
|
||||
expander := hkdf.Expand(sha256.New, secret, hkdfLabel)
|
||||
out := make([]byte, length)
|
||||
if _, err := io.ReadFull(expander, out); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// maxCryptoFrameSize limits total CRYPTO frame data to prevent memory exhaustion.
|
||||
const maxCryptoFrameSize = 64 * 1024
|
||||
|
||||
// extractCryptoFrames reassembles CRYPTO frame data from QUIC frames.
|
||||
func extractCryptoFrames(frames []byte) ([]byte, error) {
|
||||
var result []byte
|
||||
pos := 0
|
||||
|
||||
for pos < len(frames) {
|
||||
frameType := frames[pos]
|
||||
|
||||
switch {
|
||||
case frameType == 0x00:
|
||||
// PADDING frame
|
||||
pos++
|
||||
|
||||
case frameType == 0x06:
|
||||
// CRYPTO frame
|
||||
pos++
|
||||
|
||||
offset, n, err := readVarInt(frames[pos:])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read crypto offset: %w", err)
|
||||
}
|
||||
pos += n
|
||||
_ = offset // We assume ordered, offset 0 for Initial
|
||||
|
||||
dataLen, n, err := readVarInt(frames[pos:])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read crypto data length: %w", err)
|
||||
}
|
||||
pos += n
|
||||
|
||||
end := pos + int(dataLen)
|
||||
if end > len(frames) {
|
||||
return nil, fmt.Errorf("CRYPTO frame data truncated")
|
||||
}
|
||||
|
||||
result = append(result, frames[pos:end]...)
|
||||
if len(result) > maxCryptoFrameSize {
|
||||
return nil, fmt.Errorf("CRYPTO frame data exceeds %d bytes", maxCryptoFrameSize)
|
||||
}
|
||||
pos = end
|
||||
|
||||
case frameType == 0x01:
|
||||
// PING frame
|
||||
pos++
|
||||
|
||||
case frameType == 0x02 || frameType == 0x03:
|
||||
// ACK frame - skip
|
||||
pos++
|
||||
// Largest Acknowledged
|
||||
_, n, err := readVarInt(frames[pos:])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read ACK: %w", err)
|
||||
}
|
||||
pos += n
|
||||
// ACK Delay
|
||||
_, n, err = readVarInt(frames[pos:])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read ACK delay: %w", err)
|
||||
}
|
||||
pos += n
|
||||
// ACK Range Count
|
||||
rangeCount, n, err := readVarInt(frames[pos:])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read ACK range count: %w", err)
|
||||
}
|
||||
pos += n
|
||||
// First ACK Range
|
||||
_, n, err = readVarInt(frames[pos:])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read first ACK range: %w", err)
|
||||
}
|
||||
pos += n
|
||||
// Additional ranges
|
||||
for i := uint64(0); i < rangeCount; i++ {
|
||||
_, n, err = readVarInt(frames[pos:])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read ACK gap: %w", err)
|
||||
}
|
||||
pos += n
|
||||
_, n, err = readVarInt(frames[pos:])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read ACK range: %w", err)
|
||||
}
|
||||
pos += n
|
||||
}
|
||||
// ECN counts for type 0x03
|
||||
if frameType == 0x03 {
|
||||
for range 3 {
|
||||
_, n, err = readVarInt(frames[pos:])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read ECN count: %w", err)
|
||||
}
|
||||
pos += n
|
||||
}
|
||||
}
|
||||
|
||||
default:
|
||||
// Unknown frame type, stop parsing
|
||||
if len(result) > 0 {
|
||||
return result, nil
|
||||
}
|
||||
return nil, fmt.Errorf("unknown QUIC frame type: 0x%02x at offset %d", frameType, pos)
|
||||
}
|
||||
}
|
||||
|
||||
if len(result) == 0 {
|
||||
return nil, fmt.Errorf("no CRYPTO frames found")
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// readVarInt reads a QUIC variable-length integer.
|
||||
// Returns (value, bytes consumed, error).
|
||||
func readVarInt(data []byte) (uint64, int, error) {
|
||||
if len(data) == 0 {
|
||||
return 0, 0, fmt.Errorf("empty data for varint")
|
||||
}
|
||||
|
||||
prefix := data[0] >> 6
|
||||
length := 1 << prefix
|
||||
|
||||
if len(data) < length {
|
||||
return 0, 0, fmt.Errorf("varint truncated: need %d, have %d", length, len(data))
|
||||
}
|
||||
|
||||
var val uint64
|
||||
switch length {
|
||||
case 1:
|
||||
val = uint64(data[0] & 0x3f)
|
||||
case 2:
|
||||
val = uint64(binary.BigEndian.Uint16(data[:2])) & 0x3fff
|
||||
case 4:
|
||||
val = uint64(binary.BigEndian.Uint32(data[:4])) & 0x3fffffff
|
||||
case 8:
|
||||
val = binary.BigEndian.Uint64(data[:8]) & 0x3fffffffffffffff
|
||||
}
|
||||
|
||||
return val, length, nil
|
||||
}
|
||||
99
client/inspect/quic_test.go
Normal file
99
client/inspect/quic_test.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package inspect
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestReadVarInt(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
data []byte
|
||||
want uint64
|
||||
n int
|
||||
}{
|
||||
{
|
||||
name: "1 byte value",
|
||||
data: []byte{0x25},
|
||||
want: 37,
|
||||
n: 1,
|
||||
},
|
||||
{
|
||||
name: "2 byte value",
|
||||
data: []byte{0x7b, 0xbd},
|
||||
want: 15293,
|
||||
n: 2,
|
||||
},
|
||||
{
|
||||
name: "4 byte value",
|
||||
data: []byte{0x9d, 0x7f, 0x3e, 0x7d},
|
||||
want: 494878333,
|
||||
n: 4,
|
||||
},
|
||||
{
|
||||
name: "zero",
|
||||
data: []byte{0x00},
|
||||
want: 0,
|
||||
n: 1,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
val, n, err := readVarInt(tt.data)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.want, val)
|
||||
assert.Equal(t, tt.n, n)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadVarInt_Empty(t *testing.T) {
|
||||
_, _, err := readVarInt(nil)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestReadVarInt_Truncated(t *testing.T) {
|
||||
// 2-byte prefix but only 1 byte
|
||||
_, _, err := readVarInt([]byte{0x40})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestExtractQUICSNI_NotLongHeader(t *testing.T) {
|
||||
// Short header packet (form bit not set)
|
||||
data := make([]byte, 100)
|
||||
data[0] = 0x40 // short header
|
||||
|
||||
_, err := ExtractQUICSNI(data)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "not a QUIC long header")
|
||||
}
|
||||
|
||||
func TestExtractQUICSNI_UnsupportedVersion(t *testing.T) {
|
||||
data := make([]byte, 100)
|
||||
data[0] = 0xC0 // long header
|
||||
// Version 0xdeadbeef
|
||||
data[1] = 0xde
|
||||
data[2] = 0xad
|
||||
data[3] = 0xbe
|
||||
data[4] = 0xef
|
||||
|
||||
_, err := ExtractQUICSNI(data)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "unsupported QUIC version")
|
||||
}
|
||||
|
||||
func TestExtractQUICSNI_TooShort(t *testing.T) {
|
||||
_, err := ExtractQUICSNI([]byte{0xC0, 0x00})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestHkdfExpandLabel(t *testing.T) {
|
||||
// Smoke test: ensure it returns the right length and doesn't error
|
||||
secret := make([]byte, 32)
|
||||
result, err := hkdfExpandLabel(secret, "quic key", nil, 16)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, result, 16)
|
||||
}
|
||||
253
client/inspect/rules.go
Normal file
253
client/inspect/rules.go
Normal file
@@ -0,0 +1,253 @@
|
||||
package inspect
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"slices"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/acl/id"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
// RuleEngine evaluates proxy rules against connection metadata.
|
||||
// It is safe for concurrent use.
|
||||
type RuleEngine struct {
|
||||
mu sync.RWMutex
|
||||
rules []Rule
|
||||
// defaultAction applies when no rule matches.
|
||||
defaultAction Action
|
||||
log *log.Entry
|
||||
}
|
||||
|
||||
// NewRuleEngine creates a rule engine with the given default action.
|
||||
func NewRuleEngine(logger *log.Entry, defaultAction Action) *RuleEngine {
|
||||
return &RuleEngine{
|
||||
defaultAction: defaultAction,
|
||||
log: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateRules replaces the rule set and default action. Rules are sorted by priority.
|
||||
func (e *RuleEngine) UpdateRules(rules []Rule, defaultAction Action) {
|
||||
sorted := make([]Rule, len(rules))
|
||||
copy(sorted, rules)
|
||||
sort.Slice(sorted, func(i, j int) bool {
|
||||
return sorted[i].Priority < sorted[j].Priority
|
||||
})
|
||||
|
||||
e.mu.Lock()
|
||||
e.rules = sorted
|
||||
e.defaultAction = defaultAction
|
||||
e.mu.Unlock()
|
||||
}
|
||||
|
||||
// EvalResult holds the outcome of a rule evaluation.
|
||||
type EvalResult struct {
|
||||
Action Action
|
||||
RuleID id.RuleID
|
||||
}
|
||||
|
||||
// Evaluate determines the action for a connection based on the rule set.
|
||||
// Pass empty path for connection-level evaluation (TLS/SNI), non-empty for request-level (HTTP).
|
||||
func (e *RuleEngine) Evaluate(src netip.Addr, dstDomain domain.Domain, dstAddr netip.Addr, dstPort uint16, proto ProtoType, path string) Action {
|
||||
r := e.EvaluateWithResult(src, dstDomain, dstAddr, dstPort, proto, path)
|
||||
return r.Action
|
||||
}
|
||||
|
||||
// EvaluateWithResult is like Evaluate but also returns the matched rule ID.
|
||||
func (e *RuleEngine) EvaluateWithResult(src netip.Addr, dstDomain domain.Domain, dstAddr netip.Addr, dstPort uint16, proto ProtoType, path string) EvalResult {
|
||||
e.mu.RLock()
|
||||
defer e.mu.RUnlock()
|
||||
|
||||
for i := range e.rules {
|
||||
rule := &e.rules[i]
|
||||
if e.ruleMatches(rule, src, dstDomain, dstAddr, dstPort, proto, path) {
|
||||
e.log.Tracef("rule %s matched: action=%s src=%s domain=%s dst=%s:%d proto=%s path=%s",
|
||||
rule.ID, rule.Action, src, dstDomain.SafeString(), dstAddr, dstPort, proto, path)
|
||||
return EvalResult{Action: rule.Action, RuleID: rule.ID}
|
||||
}
|
||||
}
|
||||
|
||||
e.log.Tracef("no rule matched, default=%s: src=%s domain=%s dst=%s:%d proto=%s path=%s",
|
||||
e.defaultAction, src, dstDomain.SafeString(), dstAddr, dstPort, proto, path)
|
||||
return EvalResult{Action: e.defaultAction}
|
||||
}
|
||||
|
||||
// HasPathRulesForDomain returns true if any rule matching the domain has non-empty Paths.
|
||||
// Used to force MITM inspection when path-level rules exist (paths are only visible after decryption).
|
||||
func (e *RuleEngine) HasPathRulesForDomain(dstDomain domain.Domain) bool {
|
||||
e.mu.RLock()
|
||||
defer e.mu.RUnlock()
|
||||
|
||||
for i := range e.rules {
|
||||
if len(e.rules[i].Paths) > 0 && e.matchDomain(&e.rules[i], dstDomain) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ruleMatches checks whether all non-empty fields of a rule match.
|
||||
// Empty fields are treated as "match any".
|
||||
// All specified fields must match (AND logic).
|
||||
func (e *RuleEngine) ruleMatches(rule *Rule, src netip.Addr, dstDomain domain.Domain, dstAddr netip.Addr, dstPort uint16, proto ProtoType, path string) bool {
|
||||
if !e.matchSource(rule, src) {
|
||||
return false
|
||||
}
|
||||
|
||||
if !e.matchDomain(rule, dstDomain) {
|
||||
return false
|
||||
}
|
||||
|
||||
if !e.matchNetwork(rule, dstAddr) {
|
||||
return false
|
||||
}
|
||||
|
||||
if !e.matchPort(rule, dstPort) {
|
||||
return false
|
||||
}
|
||||
|
||||
if !e.matchProtocol(rule, proto) {
|
||||
return false
|
||||
}
|
||||
|
||||
if !e.matchPaths(rule, path) {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// matchSource returns true if src matches any of the rule's source CIDRs,
|
||||
// or if no source CIDRs are specified (match any).
|
||||
func (e *RuleEngine) matchSource(rule *Rule, src netip.Addr) bool {
|
||||
if len(rule.Sources) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
for _, prefix := range rule.Sources {
|
||||
if prefix.Contains(src) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// matchDomain returns true if dstDomain matches any of the rule's domain patterns,
|
||||
// or if no domain patterns are specified (match any).
|
||||
func (e *RuleEngine) matchDomain(rule *Rule, dstDomain domain.Domain) bool {
|
||||
if len(rule.Domains) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
// If we have domain rules but no domain to match against (e.g., raw IP connection),
|
||||
// the domain condition does not match.
|
||||
if dstDomain == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, pattern := range rule.Domains {
|
||||
if MatchDomain(pattern, dstDomain) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// matchNetwork returns true if dstAddr is within any of the rule's destination CIDRs,
|
||||
// or if no destination CIDRs are specified (match any).
|
||||
func (e *RuleEngine) matchNetwork(rule *Rule, dstAddr netip.Addr) bool {
|
||||
if len(rule.Networks) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
for _, prefix := range rule.Networks {
|
||||
if prefix.Contains(dstAddr) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// matchProtocol returns true if proto matches any of the rule's protocols,
|
||||
// or if no protocols are specified (match any).
|
||||
func (e *RuleEngine) matchProtocol(rule *Rule, proto ProtoType) bool {
|
||||
if len(rule.Protocols) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
for _, p := range rule.Protocols {
|
||||
if p == proto {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// matchPort returns true if dstPort matches any of the rule's destination ports,
|
||||
// or if no ports are specified (match any).
|
||||
func (e *RuleEngine) matchPort(rule *Rule, dstPort uint16) bool {
|
||||
if len(rule.Ports) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
return slices.Contains(rule.Ports, dstPort)
|
||||
}
|
||||
|
||||
// matchPaths returns true if path matches any of the rule's path patterns,
|
||||
// or if no paths are specified (match any). Empty path (connection-level eval) matches all.
|
||||
func (e *RuleEngine) matchPaths(rule *Rule, path string) bool {
|
||||
if len(rule.Paths) == 0 {
|
||||
return true
|
||||
}
|
||||
// Connection-level (path=""): rules with paths don't match at connection level.
|
||||
// HasPathRulesForDomain forces the connection to inspect, so paths are
|
||||
// checked per-request once the HTTP request is visible.
|
||||
if path == "" {
|
||||
return false
|
||||
}
|
||||
for _, pattern := range rule.Paths {
|
||||
if matchPath(pattern, path) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// matchPath checks if a URL path matches a pattern.
|
||||
// Supports: exact ("/login"), prefix with wildcard ("/api/*"),
|
||||
// and contains ("*/admin/*"). A bare "*" matches everything.
|
||||
func matchPath(pattern, path string) bool {
|
||||
if pattern == "*" {
|
||||
return true
|
||||
}
|
||||
|
||||
hasLeadingStar := strings.HasPrefix(pattern, "*")
|
||||
hasTrailingStar := strings.HasSuffix(pattern, "*")
|
||||
|
||||
switch {
|
||||
case hasLeadingStar && hasTrailingStar:
|
||||
// */admin/* = contains
|
||||
middle := strings.Trim(pattern, "*")
|
||||
return strings.Contains(path, middle)
|
||||
case hasTrailingStar:
|
||||
// /api/* = prefix
|
||||
prefix := strings.TrimSuffix(pattern, "*")
|
||||
return strings.HasPrefix(path, prefix)
|
||||
case hasLeadingStar:
|
||||
// *.json = suffix
|
||||
suffix := strings.TrimPrefix(pattern, "*")
|
||||
return strings.HasSuffix(path, suffix)
|
||||
default:
|
||||
// exact
|
||||
return path == pattern
|
||||
}
|
||||
}
|
||||
338
client/inspect/rules_test.go
Normal file
338
client/inspect/rules_test.go
Normal file
@@ -0,0 +1,338 @@
|
||||
package inspect
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/acl/id"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
func testLogger() *log.Entry {
|
||||
return log.WithField("test", true)
|
||||
}
|
||||
|
||||
func mustDomain(t *testing.T, s string) domain.Domain {
|
||||
t.Helper()
|
||||
d, err := domain.FromString(s)
|
||||
require.NoError(t, err)
|
||||
return d
|
||||
}
|
||||
|
||||
func TestRuleEngine_Evaluate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
rules []Rule
|
||||
defaultAction Action
|
||||
src netip.Addr
|
||||
dstDomain domain.Domain
|
||||
dstAddr netip.Addr
|
||||
dstPort uint16
|
||||
want Action
|
||||
}{
|
||||
{
|
||||
name: "no rules returns default allow",
|
||||
defaultAction: ActionAllow,
|
||||
src: netip.MustParseAddr("10.0.0.1"),
|
||||
dstAddr: netip.MustParseAddr("1.2.3.4"),
|
||||
dstPort: 443,
|
||||
want: ActionAllow,
|
||||
},
|
||||
{
|
||||
name: "no rules returns default block",
|
||||
defaultAction: ActionBlock,
|
||||
src: netip.MustParseAddr("10.0.0.1"),
|
||||
dstAddr: netip.MustParseAddr("1.2.3.4"),
|
||||
dstPort: 443,
|
||||
want: ActionBlock,
|
||||
},
|
||||
{
|
||||
name: "domain exact match blocks",
|
||||
defaultAction: ActionAllow,
|
||||
rules: []Rule{
|
||||
{
|
||||
ID: id.RuleID("r1"),
|
||||
Domains: []domain.Domain{mustDomain(t, "malware.example.com")},
|
||||
Action: ActionBlock,
|
||||
},
|
||||
},
|
||||
src: netip.MustParseAddr("10.0.0.1"),
|
||||
dstDomain: mustDomain(t, "malware.example.com"),
|
||||
dstAddr: netip.MustParseAddr("1.2.3.4"),
|
||||
dstPort: 443,
|
||||
want: ActionBlock,
|
||||
},
|
||||
{
|
||||
name: "domain wildcard match blocks",
|
||||
defaultAction: ActionAllow,
|
||||
rules: []Rule{
|
||||
{
|
||||
ID: id.RuleID("r1"),
|
||||
Domains: []domain.Domain{mustDomain(t, "*.evil.com")},
|
||||
Action: ActionBlock,
|
||||
},
|
||||
},
|
||||
src: netip.MustParseAddr("10.0.0.1"),
|
||||
dstDomain: mustDomain(t, "phishing.evil.com"),
|
||||
dstAddr: netip.MustParseAddr("1.2.3.4"),
|
||||
dstPort: 443,
|
||||
want: ActionBlock,
|
||||
},
|
||||
{
|
||||
name: "domain wildcard does not match base",
|
||||
defaultAction: ActionAllow,
|
||||
rules: []Rule{
|
||||
{
|
||||
ID: id.RuleID("r1"),
|
||||
Domains: []domain.Domain{mustDomain(t, "*.evil.com")},
|
||||
Action: ActionBlock,
|
||||
},
|
||||
},
|
||||
src: netip.MustParseAddr("10.0.0.1"),
|
||||
dstDomain: mustDomain(t, "evil.com"),
|
||||
dstAddr: netip.MustParseAddr("1.2.3.4"),
|
||||
dstPort: 443,
|
||||
want: ActionAllow,
|
||||
},
|
||||
{
|
||||
name: "case insensitive domain match",
|
||||
defaultAction: ActionAllow,
|
||||
rules: []Rule{
|
||||
{
|
||||
ID: id.RuleID("r1"),
|
||||
Domains: []domain.Domain{mustDomain(t, "Example.COM")},
|
||||
Action: ActionBlock,
|
||||
},
|
||||
},
|
||||
src: netip.MustParseAddr("10.0.0.1"),
|
||||
dstDomain: mustDomain(t, "EXAMPLE.com"),
|
||||
dstAddr: netip.MustParseAddr("1.2.3.4"),
|
||||
dstPort: 443,
|
||||
want: ActionBlock,
|
||||
},
|
||||
{
|
||||
name: "source CIDR match",
|
||||
defaultAction: ActionAllow,
|
||||
rules: []Rule{
|
||||
{
|
||||
ID: id.RuleID("r1"),
|
||||
Sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
|
||||
Action: ActionInspect,
|
||||
},
|
||||
},
|
||||
src: netip.MustParseAddr("192.168.1.50"),
|
||||
dstAddr: netip.MustParseAddr("1.2.3.4"),
|
||||
dstPort: 443,
|
||||
want: ActionInspect,
|
||||
},
|
||||
{
|
||||
name: "source CIDR no match",
|
||||
defaultAction: ActionAllow,
|
||||
rules: []Rule{
|
||||
{
|
||||
ID: id.RuleID("r1"),
|
||||
Sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
|
||||
Action: ActionBlock,
|
||||
},
|
||||
},
|
||||
src: netip.MustParseAddr("10.0.0.5"),
|
||||
dstAddr: netip.MustParseAddr("1.2.3.4"),
|
||||
dstPort: 443,
|
||||
want: ActionAllow,
|
||||
},
|
||||
{
|
||||
name: "destination network match",
|
||||
defaultAction: ActionAllow,
|
||||
rules: []Rule{
|
||||
{
|
||||
ID: id.RuleID("r1"),
|
||||
Networks: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
|
||||
Action: ActionInspect,
|
||||
},
|
||||
},
|
||||
src: netip.MustParseAddr("192.168.1.1"),
|
||||
dstAddr: netip.MustParseAddr("10.50.0.1"),
|
||||
dstPort: 80,
|
||||
want: ActionInspect,
|
||||
},
|
||||
{
|
||||
name: "port match",
|
||||
defaultAction: ActionAllow,
|
||||
rules: []Rule{
|
||||
{
|
||||
ID: id.RuleID("r1"),
|
||||
Ports: []uint16{443, 8443},
|
||||
Action: ActionInspect,
|
||||
},
|
||||
},
|
||||
src: netip.MustParseAddr("10.0.0.1"),
|
||||
dstAddr: netip.MustParseAddr("1.2.3.4"),
|
||||
dstPort: 443,
|
||||
want: ActionInspect,
|
||||
},
|
||||
{
|
||||
name: "port no match",
|
||||
defaultAction: ActionAllow,
|
||||
rules: []Rule{
|
||||
{
|
||||
ID: id.RuleID("r1"),
|
||||
Ports: []uint16{443, 8443},
|
||||
Action: ActionBlock,
|
||||
},
|
||||
},
|
||||
src: netip.MustParseAddr("10.0.0.1"),
|
||||
dstAddr: netip.MustParseAddr("1.2.3.4"),
|
||||
dstPort: 22,
|
||||
want: ActionAllow,
|
||||
},
|
||||
{
|
||||
name: "priority ordering first match wins",
|
||||
defaultAction: ActionAllow,
|
||||
rules: []Rule{
|
||||
{
|
||||
ID: id.RuleID("allow-internal"),
|
||||
Domains: []domain.Domain{mustDomain(t, "*.internal.corp")},
|
||||
Action: ActionAllow,
|
||||
Priority: 1,
|
||||
},
|
||||
{
|
||||
ID: id.RuleID("inspect-all"),
|
||||
Action: ActionInspect,
|
||||
Priority: 10,
|
||||
},
|
||||
},
|
||||
src: netip.MustParseAddr("10.0.0.1"),
|
||||
dstDomain: mustDomain(t, "api.internal.corp"),
|
||||
dstAddr: netip.MustParseAddr("10.1.0.5"),
|
||||
dstPort: 443,
|
||||
want: ActionAllow,
|
||||
},
|
||||
{
|
||||
name: "all fields must match (AND logic)",
|
||||
defaultAction: ActionAllow,
|
||||
rules: []Rule{
|
||||
{
|
||||
ID: id.RuleID("r1"),
|
||||
Sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
|
||||
Domains: []domain.Domain{mustDomain(t, "*.evil.com")},
|
||||
Ports: []uint16{443},
|
||||
Action: ActionBlock,
|
||||
},
|
||||
},
|
||||
// Source matches, domain matches, but port doesn't
|
||||
src: netip.MustParseAddr("192.168.1.10"),
|
||||
dstDomain: mustDomain(t, "phish.evil.com"),
|
||||
dstAddr: netip.MustParseAddr("1.2.3.4"),
|
||||
dstPort: 8080,
|
||||
want: ActionAllow,
|
||||
},
|
||||
{
|
||||
name: "empty domain with domain rule does not match",
|
||||
defaultAction: ActionAllow,
|
||||
rules: []Rule{
|
||||
{
|
||||
ID: id.RuleID("r1"),
|
||||
Domains: []domain.Domain{mustDomain(t, "example.com")},
|
||||
Action: ActionBlock,
|
||||
},
|
||||
},
|
||||
src: netip.MustParseAddr("10.0.0.1"),
|
||||
dstDomain: "", // raw IP connection, no SNI
|
||||
dstAddr: netip.MustParseAddr("1.2.3.4"),
|
||||
dstPort: 443,
|
||||
want: ActionAllow,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
engine := NewRuleEngine(testLogger(), tt.defaultAction)
|
||||
engine.UpdateRules(tt.rules, tt.defaultAction)
|
||||
|
||||
got := engine.Evaluate(tt.src, tt.dstDomain, tt.dstAddr, tt.dstPort, "", "")
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuleEngine_ProtocolMatching(t *testing.T) {
|
||||
engine := NewRuleEngine(testLogger(), ActionAllow)
|
||||
engine.UpdateRules([]Rule{
|
||||
{
|
||||
ID: "block-websocket",
|
||||
Protocols: []ProtoType{ProtoWebSocket},
|
||||
Action: ActionBlock,
|
||||
Priority: 1,
|
||||
},
|
||||
{
|
||||
ID: "inspect-h2",
|
||||
Protocols: []ProtoType{ProtoH2},
|
||||
Action: ActionInspect,
|
||||
Priority: 2,
|
||||
},
|
||||
}, ActionAllow)
|
||||
|
||||
src := netip.MustParseAddr("10.0.0.1")
|
||||
dst := netip.MustParseAddr("1.2.3.4")
|
||||
|
||||
// WebSocket: blocked by rule
|
||||
assert.Equal(t, ActionBlock, engine.Evaluate(src, "", dst, 443, ProtoWebSocket, ""))
|
||||
|
||||
// HTTP/2: inspected by rule
|
||||
assert.Equal(t, ActionInspect, engine.Evaluate(src, "", dst, 443, ProtoH2, ""))
|
||||
|
||||
// Plain HTTP: no protocol rule matches, default allow
|
||||
assert.Equal(t, ActionAllow, engine.Evaluate(src, "", dst, 80, ProtoHTTP, ""))
|
||||
|
||||
// HTTPS: no protocol rule matches, default allow
|
||||
assert.Equal(t, ActionAllow, engine.Evaluate(src, "", dst, 443, ProtoHTTPS, ""))
|
||||
|
||||
// QUIC/H3: no protocol rule matches, default allow
|
||||
assert.Equal(t, ActionAllow, engine.Evaluate(src, "", dst, 443, ProtoH3, ""))
|
||||
|
||||
// Empty protocol (unknown): no protocol rule matches, default allow
|
||||
assert.Equal(t, ActionAllow, engine.Evaluate(src, "", dst, 443, "", ""))
|
||||
}
|
||||
|
||||
func TestRuleEngine_EmptyProtocolsMatchAll(t *testing.T) {
|
||||
engine := NewRuleEngine(testLogger(), ActionAllow)
|
||||
engine.UpdateRules([]Rule{
|
||||
{
|
||||
ID: "block-all-protos",
|
||||
Action: ActionBlock,
|
||||
// No Protocols field = match all protocols
|
||||
Priority: 1,
|
||||
},
|
||||
}, ActionAllow)
|
||||
|
||||
src := netip.MustParseAddr("10.0.0.1")
|
||||
dst := netip.MustParseAddr("1.2.3.4")
|
||||
|
||||
assert.Equal(t, ActionBlock, engine.Evaluate(src, "", dst, 443, ProtoHTTP, ""))
|
||||
assert.Equal(t, ActionBlock, engine.Evaluate(src, "", dst, 443, ProtoHTTPS, ""))
|
||||
assert.Equal(t, ActionBlock, engine.Evaluate(src, "", dst, 443, ProtoWebSocket, ""))
|
||||
assert.Equal(t, ActionBlock, engine.Evaluate(src, "", dst, 443, ProtoH2, ""))
|
||||
assert.Equal(t, ActionBlock, engine.Evaluate(src, "", dst, 443, "", ""))
|
||||
}
|
||||
|
||||
func TestRuleEngine_UpdateRulesSortsByPriority(t *testing.T) {
|
||||
engine := NewRuleEngine(testLogger(), ActionAllow)
|
||||
|
||||
engine.UpdateRules([]Rule{
|
||||
{ID: "c", Priority: 30, Action: ActionBlock},
|
||||
{ID: "a", Priority: 10, Action: ActionInspect},
|
||||
{ID: "b", Priority: 20, Action: ActionAllow},
|
||||
}, ActionAllow)
|
||||
|
||||
engine.mu.RLock()
|
||||
defer engine.mu.RUnlock()
|
||||
|
||||
require.Len(t, engine.rules, 3)
|
||||
assert.Equal(t, id.RuleID("a"), engine.rules[0].ID)
|
||||
assert.Equal(t, id.RuleID("b"), engine.rules[1].ID)
|
||||
assert.Equal(t, id.RuleID("c"), engine.rules[2].ID)
|
||||
}
|
||||
287
client/inspect/sni.go
Normal file
287
client/inspect/sni.go
Normal file
@@ -0,0 +1,287 @@
|
||||
package inspect
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
const (
|
||||
recordTypeHandshake = 0x16
|
||||
handshakeTypeClientHello = 0x01
|
||||
extensionTypeSNI = 0x0000
|
||||
extensionTypeALPN = 0x0010
|
||||
sniTypeHostName = 0x00
|
||||
|
||||
// maxClientHelloSize is the maximum ClientHello size we'll read.
|
||||
// Real-world ClientHellos are typically under 1KB but can reach ~16KB with
|
||||
// many extensions (post-quantum key shares, etc.).
|
||||
maxClientHelloSize = 16384
|
||||
)
|
||||
|
||||
// ClientHelloInfo holds data extracted from a TLS ClientHello.
|
||||
type ClientHelloInfo struct {
|
||||
SNI domain.Domain
|
||||
ALPN []string
|
||||
}
|
||||
|
||||
// isTLSHandshake reports whether the first byte indicates a TLS handshake record.
|
||||
func isTLSHandshake(b byte) bool {
|
||||
return b == recordTypeHandshake
|
||||
}
|
||||
|
||||
// httpMethods lists the first bytes of valid HTTP method tokens.
|
||||
var httpMethods = [][]byte{
|
||||
[]byte("GET "),
|
||||
[]byte("POST"),
|
||||
[]byte("PUT "),
|
||||
[]byte("DELE"),
|
||||
[]byte("HEAD"),
|
||||
[]byte("OPTI"),
|
||||
[]byte("PATC"),
|
||||
[]byte("CONN"),
|
||||
[]byte("TRAC"),
|
||||
}
|
||||
|
||||
// isHTTPMethod reports whether the peeked bytes look like the start of an HTTP request.
|
||||
func isHTTPMethod(b []byte) bool {
|
||||
if len(b) < 4 {
|
||||
return false
|
||||
}
|
||||
for _, m := range httpMethods {
|
||||
if b[0] == m[0] && b[1] == m[1] && b[2] == m[2] && b[3] == m[3] {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// parseClientHello reads a TLS ClientHello from r and returns SNI and ALPN.
|
||||
func parseClientHello(r io.Reader) (ClientHelloInfo, error) {
|
||||
// TLS record header: type(1) + version(2) + length(2)
|
||||
var recordHeader [5]byte
|
||||
if _, err := io.ReadFull(r, recordHeader[:]); err != nil {
|
||||
return ClientHelloInfo{}, fmt.Errorf("read TLS record header: %w", err)
|
||||
}
|
||||
|
||||
if recordHeader[0] != recordTypeHandshake {
|
||||
return ClientHelloInfo{}, fmt.Errorf("not a TLS handshake record (type=%d)", recordHeader[0])
|
||||
}
|
||||
|
||||
recordLen := int(binary.BigEndian.Uint16(recordHeader[3:5]))
|
||||
if recordLen < 4 || recordLen > maxClientHelloSize {
|
||||
return ClientHelloInfo{}, fmt.Errorf("invalid TLS record length: %d", recordLen)
|
||||
}
|
||||
|
||||
// Read the full handshake message
|
||||
msg := make([]byte, recordLen)
|
||||
if _, err := io.ReadFull(r, msg); err != nil {
|
||||
return ClientHelloInfo{}, fmt.Errorf("read handshake message: %w", err)
|
||||
}
|
||||
|
||||
return parseClientHelloMsg(msg)
|
||||
}
|
||||
|
||||
// extractSNI reads a TLS ClientHello from r and returns the SNI hostname.
|
||||
// Returns empty domain if no SNI extension is present.
|
||||
func extractSNI(r io.Reader) (domain.Domain, error) {
|
||||
info, err := parseClientHello(r)
|
||||
return info.SNI, err
|
||||
}
|
||||
|
||||
// extractSNIFromBytes parses SNI from raw bytes that start with the TLS record header.
|
||||
func extractSNIFromBytes(data []byte) (domain.Domain, error) {
|
||||
info, err := parseClientHelloFromBytes(data)
|
||||
return info.SNI, err
|
||||
}
|
||||
|
||||
// parseClientHelloFromBytes parses a ClientHello from raw bytes starting with the TLS record header.
|
||||
func parseClientHelloFromBytes(data []byte) (ClientHelloInfo, error) {
|
||||
if len(data) < 5 {
|
||||
return ClientHelloInfo{}, fmt.Errorf("data too short for TLS record header")
|
||||
}
|
||||
|
||||
if data[0] != recordTypeHandshake {
|
||||
return ClientHelloInfo{}, fmt.Errorf("not a TLS handshake record (type=%d)", data[0])
|
||||
}
|
||||
|
||||
recordLen := int(binary.BigEndian.Uint16(data[3:5]))
|
||||
if recordLen < 4 {
|
||||
return ClientHelloInfo{}, fmt.Errorf("invalid TLS record length: %d", recordLen)
|
||||
}
|
||||
|
||||
end := 5 + recordLen
|
||||
if end > len(data) {
|
||||
return ClientHelloInfo{}, fmt.Errorf("TLS record truncated: need %d, have %d", end, len(data))
|
||||
}
|
||||
|
||||
return parseClientHelloMsg(data[5:end])
|
||||
}
|
||||
|
||||
// parseClientHelloMsg extracts SNI and ALPN from a raw ClientHello handshake message.
|
||||
// msg starts at the handshake type byte.
|
||||
func parseClientHelloMsg(msg []byte) (ClientHelloInfo, error) {
|
||||
if len(msg) < 4 {
|
||||
return ClientHelloInfo{}, fmt.Errorf("handshake message too short")
|
||||
}
|
||||
|
||||
if msg[0] != handshakeTypeClientHello {
|
||||
return ClientHelloInfo{}, fmt.Errorf("not a ClientHello (type=%d)", msg[0])
|
||||
}
|
||||
|
||||
// Handshake header: type(1) + length(3)
|
||||
helloLen := int(msg[1])<<16 | int(msg[2])<<8 | int(msg[3])
|
||||
if helloLen+4 > len(msg) {
|
||||
return ClientHelloInfo{}, fmt.Errorf("ClientHello truncated")
|
||||
}
|
||||
|
||||
hello := msg[4 : 4+helloLen]
|
||||
return parseHelloBody(hello)
|
||||
}
|
||||
|
||||
// parseHelloBody parses the ClientHello body (after handshake header)
|
||||
// and extracts SNI and ALPN.
|
||||
func parseHelloBody(hello []byte) (ClientHelloInfo, error) {
|
||||
// ClientHello structure:
|
||||
// version(2) + random(32) + session_id_len(1) + session_id(var)
|
||||
// + cipher_suites_len(2) + cipher_suites(var)
|
||||
// + compression_len(1) + compression(var)
|
||||
// + extensions_len(2) + extensions(var)
|
||||
|
||||
var info ClientHelloInfo
|
||||
|
||||
if len(hello) < 35 {
|
||||
return info, fmt.Errorf("ClientHello body too short")
|
||||
}
|
||||
|
||||
pos := 2 + 32 // skip version + random
|
||||
|
||||
// Skip session ID
|
||||
if pos >= len(hello) {
|
||||
return info, fmt.Errorf("ClientHello truncated at session ID")
|
||||
}
|
||||
sessionIDLen := int(hello[pos])
|
||||
pos += 1 + sessionIDLen
|
||||
|
||||
// Skip cipher suites
|
||||
if pos+2 > len(hello) {
|
||||
return info, fmt.Errorf("ClientHello truncated at cipher suites")
|
||||
}
|
||||
cipherLen := int(binary.BigEndian.Uint16(hello[pos : pos+2]))
|
||||
pos += 2 + cipherLen
|
||||
|
||||
// Skip compression methods
|
||||
if pos >= len(hello) {
|
||||
return info, fmt.Errorf("ClientHello truncated at compression")
|
||||
}
|
||||
compLen := int(hello[pos])
|
||||
pos += 1 + compLen
|
||||
|
||||
// Extensions
|
||||
if pos+2 > len(hello) {
|
||||
return info, nil
|
||||
}
|
||||
|
||||
extLen := int(binary.BigEndian.Uint16(hello[pos : pos+2]))
|
||||
pos += 2
|
||||
|
||||
extEnd := pos + extLen
|
||||
if extEnd > len(hello) {
|
||||
return info, fmt.Errorf("extensions block truncated")
|
||||
}
|
||||
|
||||
// Walk extensions looking for SNI and ALPN
|
||||
for pos+4 <= extEnd {
|
||||
extType := binary.BigEndian.Uint16(hello[pos : pos+2])
|
||||
extDataLen := int(binary.BigEndian.Uint16(hello[pos+2 : pos+4]))
|
||||
pos += 4
|
||||
|
||||
if pos+extDataLen > extEnd {
|
||||
return info, fmt.Errorf("extension data truncated")
|
||||
}
|
||||
|
||||
switch extType {
|
||||
case extensionTypeSNI:
|
||||
sni, err := parseSNIExtension(hello[pos : pos+extDataLen])
|
||||
if err != nil {
|
||||
return info, err
|
||||
}
|
||||
info.SNI = sni
|
||||
case extensionTypeALPN:
|
||||
info.ALPN = parseALPNExtension(hello[pos : pos+extDataLen])
|
||||
}
|
||||
|
||||
pos += extDataLen
|
||||
}
|
||||
|
||||
return info, nil
|
||||
}
|
||||
|
||||
// parseALPNExtension parses the ALPN extension data and returns protocol names.
|
||||
// ALPN extension: list_length(2) + entries (each: len(1) + protocol_name(var))
|
||||
func parseALPNExtension(data []byte) []string {
|
||||
if len(data) < 2 {
|
||||
return nil
|
||||
}
|
||||
|
||||
listLen := int(binary.BigEndian.Uint16(data[0:2]))
|
||||
if listLen+2 > len(data) {
|
||||
return nil
|
||||
}
|
||||
|
||||
var protocols []string
|
||||
pos := 2
|
||||
end := 2 + listLen
|
||||
|
||||
for pos < end {
|
||||
if pos >= len(data) {
|
||||
break
|
||||
}
|
||||
nameLen := int(data[pos])
|
||||
pos++
|
||||
if pos+nameLen > end {
|
||||
break
|
||||
}
|
||||
protocols = append(protocols, string(data[pos:pos+nameLen]))
|
||||
pos += nameLen
|
||||
}
|
||||
|
||||
return protocols
|
||||
}
|
||||
|
||||
// parseSNIExtension parses the SNI extension data and returns the hostname.
|
||||
func parseSNIExtension(data []byte) (domain.Domain, error) {
|
||||
// SNI extension: list_length(2) + entries
|
||||
if len(data) < 2 {
|
||||
return "", fmt.Errorf("SNI extension too short")
|
||||
}
|
||||
|
||||
listLen := int(binary.BigEndian.Uint16(data[0:2]))
|
||||
if listLen+2 > len(data) {
|
||||
return "", fmt.Errorf("SNI list truncated")
|
||||
}
|
||||
|
||||
pos := 2
|
||||
end := 2 + listLen
|
||||
|
||||
for pos+3 <= end {
|
||||
nameType := data[pos]
|
||||
nameLen := int(binary.BigEndian.Uint16(data[pos+1 : pos+3]))
|
||||
pos += 3
|
||||
|
||||
if pos+nameLen > end {
|
||||
return "", fmt.Errorf("SNI name truncated")
|
||||
}
|
||||
|
||||
if nameType == sniTypeHostName {
|
||||
hostname := string(data[pos : pos+nameLen])
|
||||
return domain.FromString(hostname)
|
||||
}
|
||||
|
||||
pos += nameLen
|
||||
}
|
||||
|
||||
return "", nil
|
||||
}
|
||||
109
client/inspect/sni_test.go
Normal file
109
client/inspect/sni_test.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package inspect
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestExtractSNI(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sni string
|
||||
wantSNI string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "standard domain",
|
||||
sni: "example.com",
|
||||
wantSNI: "example.com",
|
||||
},
|
||||
{
|
||||
name: "subdomain",
|
||||
sni: "api.staging.example.com",
|
||||
wantSNI: "api.staging.example.com",
|
||||
},
|
||||
{
|
||||
name: "mixed case normalized to lowercase",
|
||||
sni: "Example.COM",
|
||||
wantSNI: "example.com",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
clientHello := buildClientHello(t, tt.sni)
|
||||
|
||||
sni, err := extractSNI(bytes.NewReader(clientHello))
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.wantSNI, sni.PunycodeString())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractSNI_NotTLS(t *testing.T) {
|
||||
// HTTP request instead of TLS
|
||||
data := []byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
||||
_, err := extractSNI(bytes.NewReader(data))
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "not a TLS handshake")
|
||||
}
|
||||
|
||||
func TestExtractSNI_Truncated(t *testing.T) {
|
||||
// Just the record header, no body
|
||||
data := []byte{0x16, 0x03, 0x01, 0x00, 0x05}
|
||||
_, err := extractSNI(bytes.NewReader(data))
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestExtractSNIFromBytes(t *testing.T) {
|
||||
clientHello := buildClientHello(t, "test.example.com")
|
||||
|
||||
sni, err := extractSNIFromBytes(clientHello)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "test.example.com", sni.PunycodeString())
|
||||
}
|
||||
|
||||
// buildClientHello generates a real TLS ClientHello with the given SNI.
|
||||
func buildClientHello(t *testing.T, serverName string) []byte {
|
||||
t.Helper()
|
||||
|
||||
// Use a pipe to capture the ClientHello bytes
|
||||
clientConn, serverConn := net.Pipe()
|
||||
|
||||
done := make(chan []byte, 1)
|
||||
go func() {
|
||||
buf := make([]byte, 4096)
|
||||
n, _ := serverConn.Read(buf)
|
||||
done <- buf[:n]
|
||||
serverConn.Close()
|
||||
}()
|
||||
|
||||
tlsConn := tls.Client(clientConn, &tls.Config{
|
||||
ServerName: serverName,
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
|
||||
// Trigger the handshake (will fail since server isn't TLS, but we capture the ClientHello)
|
||||
go func() {
|
||||
_ = tlsConn.Handshake()
|
||||
tlsConn.Close()
|
||||
}()
|
||||
|
||||
clientHello := <-done
|
||||
clientConn.Close()
|
||||
|
||||
require.True(t, len(clientHello) > 5, "ClientHello too short")
|
||||
require.Equal(t, byte(0x16), clientHello[0], "not a TLS handshake record")
|
||||
|
||||
return clientHello
|
||||
}
|
||||
287
client/inspect/tls.go
Normal file
287
client/inspect/tls.go
Normal file
@@ -0,0 +1,287 @@
|
||||
package inspect
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
// handleTLS processes a TLS connection for the kernel-mode path: extracts SNI,
|
||||
// evaluates rules, and handles the connection internally.
|
||||
// In envoy mode, allowed connections are forwarded to envoy instead of direct relay.
|
||||
func (p *Proxy) handleTLS(ctx context.Context, pconn *peekConn, dst netip.AddrPort, src SourceInfo) error {
|
||||
result, err := p.inspectTLS(ctx, pconn, dst, src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if result.PassthroughConn != nil {
|
||||
p.mu.RLock()
|
||||
envoy := p.envoy
|
||||
p.mu.RUnlock()
|
||||
|
||||
if envoy != nil {
|
||||
return p.forwardToEnvoy(ctx, pconn, dst, src, envoy)
|
||||
}
|
||||
return p.tlsPassthrough(ctx, pconn, dst, "")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// inspectTLS extracts SNI, evaluates rules, and returns the result.
|
||||
// For ActionAllow: returns the peekConn as PassthroughConn (caller relays).
|
||||
// For ActionBlock/ActionInspect: handles internally and returns nil PassthroughConn.
|
||||
func (p *Proxy) inspectTLS(ctx context.Context, pconn *peekConn, dst netip.AddrPort, src SourceInfo) (InspectResult, error) {
|
||||
// The first 5 bytes (TLS record header) are already peeked.
|
||||
// Extend to read the full TLS record so bytes remain in the buffer for passthrough.
|
||||
peeked := pconn.Peeked()
|
||||
recordLen := int(peeked[3])<<8 | int(peeked[4])
|
||||
if _, err := pconn.PeekMore(5 + recordLen); err != nil {
|
||||
return InspectResult{}, fmt.Errorf("read TLS record: %w", err)
|
||||
}
|
||||
|
||||
hello, err := parseClientHelloFromBytes(pconn.Peeked())
|
||||
if err != nil {
|
||||
return InspectResult{}, fmt.Errorf("parse ClientHello: %w", err)
|
||||
}
|
||||
|
||||
sni := hello.SNI
|
||||
proto := protoFromALPN(hello.ALPN)
|
||||
// Connection-level evaluation: pass empty path.
|
||||
action := p.evaluateAction(src.IP, sni, dst, proto, "")
|
||||
|
||||
// If any rule for this domain has path patterns, force inspect so paths can
|
||||
// be checked per-request after MITM decryption.
|
||||
if action == ActionAllow && p.rules.HasPathRulesForDomain(sni) {
|
||||
p.log.Debugf("upgrading to inspect for %s (path rules exist)", sni.PunycodeString())
|
||||
action = ActionInspect
|
||||
}
|
||||
|
||||
// Snapshot cert provider under lock for use in this connection.
|
||||
p.mu.RLock()
|
||||
certs := p.certs
|
||||
p.mu.RUnlock()
|
||||
|
||||
switch action {
|
||||
case ActionBlock:
|
||||
p.log.Debugf("block: TLS to %s (SNI=%s)", dst, sni.PunycodeString())
|
||||
if certs != nil {
|
||||
return InspectResult{Action: ActionBlock}, p.tlsBlockPage(ctx, pconn, sni, certs)
|
||||
}
|
||||
return InspectResult{Action: ActionBlock}, ErrBlocked
|
||||
|
||||
case ActionAllow:
|
||||
p.log.Tracef("allow: TLS passthrough to %s (SNI=%s)", dst, sni.PunycodeString())
|
||||
return InspectResult{Action: ActionAllow, PassthroughConn: pconn}, nil
|
||||
|
||||
case ActionInspect:
|
||||
if certs == nil {
|
||||
p.log.Warnf("allow: %s (inspect requested but no MITM CA configured)", sni.PunycodeString())
|
||||
return InspectResult{Action: ActionAllow, PassthroughConn: pconn}, nil
|
||||
}
|
||||
err := p.tlsMITM(ctx, pconn, dst, sni, src, certs)
|
||||
return InspectResult{Action: ActionInspect}, err
|
||||
|
||||
default:
|
||||
p.log.Warnf("block: unknown action %q for %s", action, sni.PunycodeString())
|
||||
return InspectResult{Action: ActionBlock}, ErrBlocked
|
||||
}
|
||||
}
|
||||
|
||||
// tlsBlockPage completes a MITM TLS handshake with the client using a dynamic
|
||||
// certificate, then serves an HTTP 403 block page so the user sees a clear
|
||||
// message instead of a cryptic SSL error.
|
||||
func (p *Proxy) tlsBlockPage(ctx context.Context, pconn *peekConn, sni domain.Domain, certs *CertProvider) error {
|
||||
hostname := sni.PunycodeString()
|
||||
|
||||
// Force HTTP/1.1 only: block pages are simple responses, no need for h2
|
||||
tlsCfg := certs.GetTLSConfig()
|
||||
tlsCfg.NextProtos = []string{"http/1.1"}
|
||||
clientTLS := tls.Server(pconn, tlsCfg)
|
||||
if err := clientTLS.HandshakeContext(ctx); err != nil {
|
||||
// Client may not trust our CA, handshake fails. That's expected.
|
||||
return fmt.Errorf("block page TLS handshake for %s: %w", hostname, err)
|
||||
}
|
||||
defer func() {
|
||||
if err := clientTLS.Close(); err != nil {
|
||||
p.log.Debugf("close block page TLS for %s: %v", hostname, err)
|
||||
}
|
||||
}()
|
||||
|
||||
writeBlockResponse(clientTLS, nil, sni)
|
||||
return ErrBlocked
|
||||
}
|
||||
|
||||
// tlsPassthrough connects to the destination and relays encrypted traffic
|
||||
// without decryption. The peeked ClientHello bytes are replayed.
|
||||
func (p *Proxy) tlsPassthrough(ctx context.Context, pconn *peekConn, dst netip.AddrPort, sni domain.Domain) error {
|
||||
remote, err := p.dialTCP(ctx, dst)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dial %s: %w", dst, err)
|
||||
}
|
||||
defer func() {
|
||||
if err := remote.Close(); err != nil {
|
||||
p.log.Debugf("close remote for %s: %v", dst, err)
|
||||
}
|
||||
}()
|
||||
|
||||
p.log.Tracef("allow: TLS passthrough to %s (SNI=%s)", dst, sni.PunycodeString())
|
||||
|
||||
return relay(ctx, pconn, remote)
|
||||
}
|
||||
|
||||
// tlsMITM terminates the client TLS connection with a dynamic certificate,
|
||||
// establishes a new TLS connection to the real destination, and runs the
|
||||
// HTTP inspection pipeline on the decrypted traffic.
|
||||
func (p *Proxy) tlsMITM(ctx context.Context, pconn *peekConn, dst netip.AddrPort, sni domain.Domain, src SourceInfo, certs *CertProvider) error {
|
||||
hostname := sni.PunycodeString()
|
||||
|
||||
// TLS handshake with client using dynamic cert
|
||||
clientTLS := tls.Server(pconn, certs.GetTLSConfig())
|
||||
if err := clientTLS.HandshakeContext(ctx); err != nil {
|
||||
return fmt.Errorf("client TLS handshake for %s: %w", hostname, err)
|
||||
}
|
||||
defer func() {
|
||||
if err := clientTLS.Close(); err != nil {
|
||||
p.log.Debugf("close client TLS for %s: %v", hostname, err)
|
||||
}
|
||||
}()
|
||||
|
||||
// TLS connection to real destination
|
||||
remoteTLS, err := p.dialTLS(ctx, dst, hostname)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dial TLS %s (%s): %w", dst, hostname, err)
|
||||
}
|
||||
defer func() {
|
||||
if err := remoteTLS.Close(); err != nil {
|
||||
p.log.Debugf("close remote TLS for %s: %v", hostname, err)
|
||||
}
|
||||
}()
|
||||
|
||||
negotiatedProto := clientTLS.ConnectionState().NegotiatedProtocol
|
||||
p.log.Tracef("inspect: MITM established for %s (proto=%s)", hostname, negotiatedProto)
|
||||
|
||||
return p.inspectHTTP(ctx, clientTLS, remoteTLS, dst, sni, src, negotiatedProto)
|
||||
}
|
||||
|
||||
// dialTLS connects to the destination with TLS, verifying the real server certificate.
|
||||
func (p *Proxy) dialTLS(ctx context.Context, dst netip.AddrPort, serverName string) (net.Conn, error) {
|
||||
rawConn, err := p.dialTCP(ctx, dst)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tlsConn := tls.Client(rawConn, &tls.Config{
|
||||
ServerName: serverName,
|
||||
NextProtos: []string{"h2", "http/1.1"},
|
||||
MinVersion: tls.VersionTLS12,
|
||||
})
|
||||
|
||||
if err := tlsConn.HandshakeContext(ctx); err != nil {
|
||||
if closeErr := rawConn.Close(); closeErr != nil {
|
||||
p.log.Debugf("close raw conn after TLS handshake failure: %v", closeErr)
|
||||
}
|
||||
return nil, fmt.Errorf("TLS handshake with %s: %w", serverName, err)
|
||||
}
|
||||
|
||||
return tlsConn, nil
|
||||
}
|
||||
|
||||
// protoFromALPN maps TLS ALPN protocol names to proxy ProtoType.
|
||||
// Falls back to ProtoHTTPS when no recognized ALPN is present.
|
||||
func protoFromALPN(alpn []string) ProtoType {
|
||||
for _, p := range alpn {
|
||||
switch p {
|
||||
case "h2":
|
||||
return ProtoH2
|
||||
case "h3": // unlikely in TLS, but handle anyway
|
||||
return ProtoH3
|
||||
}
|
||||
}
|
||||
// No ALPN or only "http/1.1": treat as HTTPS
|
||||
return ProtoHTTPS
|
||||
}
|
||||
|
||||
// relay copies data bidirectionally between client and remote until one
|
||||
// side closes or the context is cancelled.
|
||||
func relay(ctx context.Context, client, remote net.Conn) error {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
errCh := make(chan error, 2)
|
||||
|
||||
go func() {
|
||||
_, err := io.Copy(remote, client)
|
||||
cancel()
|
||||
errCh <- err
|
||||
}()
|
||||
|
||||
go func() {
|
||||
_, err := io.Copy(client, remote)
|
||||
cancel()
|
||||
errCh <- err
|
||||
}()
|
||||
|
||||
var firstErr error
|
||||
for range 2 {
|
||||
if err := <-errCh; err != nil && firstErr == nil {
|
||||
if !isClosedErr(err) {
|
||||
firstErr = err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return firstErr
|
||||
}
|
||||
|
||||
// evaluateAction runs rule evaluation and resolves the effective action.
|
||||
// Pass empty path for connection-level (TLS), non-empty for request-level (HTTP).
|
||||
func (p *Proxy) evaluateAction(src netip.Addr, sni domain.Domain, dst netip.AddrPort, proto ProtoType, path string) Action {
|
||||
return p.rules.Evaluate(src, sni, dst.Addr(), dst.Port(), proto, path)
|
||||
}
|
||||
|
||||
// dialTCP dials the destination, blocking connections to loopback, link-local,
|
||||
// multicast, and WG overlay network addresses.
|
||||
func (p *Proxy) dialTCP(ctx context.Context, dst netip.AddrPort) (net.Conn, error) {
|
||||
ip := dst.Addr().Unmap()
|
||||
if err := p.validateDialTarget(ip); err != nil {
|
||||
return nil, fmt.Errorf("dial %s: %w", dst, err)
|
||||
}
|
||||
return p.dialer.DialContext(ctx, "tcp", dst.String())
|
||||
}
|
||||
|
||||
// validateDialTarget blocks destinations that should never be dialed by the proxy.
|
||||
// Mirrors the route validation in systemops.validateRoute.
|
||||
func (p *Proxy) validateDialTarget(addr netip.Addr) error {
|
||||
switch {
|
||||
case !addr.IsValid():
|
||||
return fmt.Errorf("invalid address")
|
||||
case addr.IsLoopback():
|
||||
return fmt.Errorf("loopback address not allowed")
|
||||
case addr.IsLinkLocalUnicast(), addr.IsLinkLocalMulticast(), addr.IsInterfaceLocalMulticast():
|
||||
return fmt.Errorf("link-local address not allowed")
|
||||
case addr.IsMulticast():
|
||||
return fmt.Errorf("multicast address not allowed")
|
||||
case p.wgNetwork.IsValid() && p.wgNetwork.Contains(addr):
|
||||
return fmt.Errorf("overlay network address not allowed")
|
||||
case p.localIPs != nil && p.localIPs.IsLocalIP(addr):
|
||||
return fmt.Errorf("local address not allowed")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func isClosedErr(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
return err == io.EOF ||
|
||||
err == io.ErrClosedPipe ||
|
||||
err == net.ErrClosed ||
|
||||
err == context.Canceled
|
||||
}
|
||||
@@ -19,6 +19,9 @@ import (
|
||||
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
|
||||
|
||||
func TestDefaultManager(t *testing.T) {
|
||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||
t.Setenv(firewall.EnvForceUserspaceFirewall, "true")
|
||||
|
||||
networkMap := &mgmProto.NetworkMap{
|
||||
FirewallRules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
@@ -135,6 +138,7 @@ func TestDefaultManager(t *testing.T) {
|
||||
func TestDefaultManagerStateless(t *testing.T) {
|
||||
// stateless currently only in userspace, so we have to disable kernel
|
||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||
t.Setenv(firewall.EnvForceUserspaceFirewall, "true")
|
||||
t.Setenv("NB_DISABLE_CONNTRACK", "true")
|
||||
|
||||
networkMap := &mgmProto.NetworkMap{
|
||||
@@ -194,6 +198,7 @@ func TestDefaultManagerStateless(t *testing.T) {
|
||||
// This tests the full ACL manager -> uspfilter integration.
|
||||
func TestDenyRulesNotAccumulatedOnRepeatedApply(t *testing.T) {
|
||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||
t.Setenv(firewall.EnvForceUserspaceFirewall, "true")
|
||||
|
||||
networkMap := &mgmProto.NetworkMap{
|
||||
FirewallRules: []*mgmProto.FirewallRule{
|
||||
@@ -258,6 +263,7 @@ func TestDenyRulesNotAccumulatedOnRepeatedApply(t *testing.T) {
|
||||
// up when they're removed from the network map in a subsequent update.
|
||||
func TestDenyRulesCleanedUpOnRemoval(t *testing.T) {
|
||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||
t.Setenv(firewall.EnvForceUserspaceFirewall, "true")
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
@@ -339,6 +345,7 @@ func TestDenyRulesCleanedUpOnRemoval(t *testing.T) {
|
||||
// one added without leaking.
|
||||
func TestRuleUpdateChangingAction(t *testing.T) {
|
||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||
t.Setenv(firewall.EnvForceUserspaceFirewall, "true")
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
@@ -155,7 +155,7 @@ func (a *Auth) IsLoginRequired(ctx context.Context) (bool, error) {
|
||||
var needsLogin bool
|
||||
|
||||
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
|
||||
_, _, err := a.doMgmLogin(client, ctx, pubSSHKey)
|
||||
err := a.doMgmLogin(client, ctx, pubSSHKey)
|
||||
if isLoginNeeded(err) {
|
||||
needsLogin = true
|
||||
return nil
|
||||
@@ -179,8 +179,8 @@ func (a *Auth) Login(ctx context.Context, setupKey string, jwtToken string) (err
|
||||
var isAuthError bool
|
||||
|
||||
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
|
||||
serverKey, _, err := a.doMgmLogin(client, ctx, pubSSHKey)
|
||||
if serverKey != nil && isRegistrationNeeded(err) {
|
||||
err := a.doMgmLogin(client, ctx, pubSSHKey)
|
||||
if isRegistrationNeeded(err) {
|
||||
log.Debugf("peer registration required")
|
||||
_, err = a.registerPeer(client, ctx, setupKey, jwtToken, pubSSHKey)
|
||||
if err != nil {
|
||||
@@ -201,13 +201,7 @@ func (a *Auth) Login(ctx context.Context, setupKey string, jwtToken string) (err
|
||||
|
||||
// getPKCEFlow retrieves PKCE authorization flow configuration and creates a flow instance
|
||||
func (a *Auth) getPKCEFlow(client *mgm.GrpcClient) (*PKCEAuthorizationFlow, error) {
|
||||
serverKey, err := client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
protoFlow, err := client.GetPKCEAuthorizationFlow(*serverKey)
|
||||
protoFlow, err := client.GetPKCEAuthorizationFlow()
|
||||
if err != nil {
|
||||
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
||||
log.Warnf("server couldn't find pkce flow, contact admin: %v", err)
|
||||
@@ -221,7 +215,7 @@ func (a *Auth) getPKCEFlow(client *mgm.GrpcClient) (*PKCEAuthorizationFlow, erro
|
||||
config := &PKCEAuthProviderConfig{
|
||||
Audience: protoConfig.GetAudience(),
|
||||
ClientID: protoConfig.GetClientID(),
|
||||
ClientSecret: protoConfig.GetClientSecret(),
|
||||
ClientSecret: protoConfig.GetClientSecret(), //nolint:staticcheck
|
||||
TokenEndpoint: protoConfig.GetTokenEndpoint(),
|
||||
AuthorizationEndpoint: protoConfig.GetAuthorizationEndpoint(),
|
||||
Scope: protoConfig.GetScope(),
|
||||
@@ -246,13 +240,7 @@ func (a *Auth) getPKCEFlow(client *mgm.GrpcClient) (*PKCEAuthorizationFlow, erro
|
||||
|
||||
// getDeviceFlow retrieves device authorization flow configuration and creates a flow instance
|
||||
func (a *Auth) getDeviceFlow(client *mgm.GrpcClient) (*DeviceAuthorizationFlow, error) {
|
||||
serverKey, err := client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
protoFlow, err := client.GetDeviceAuthorizationFlow(*serverKey)
|
||||
protoFlow, err := client.GetDeviceAuthorizationFlow()
|
||||
if err != nil {
|
||||
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
||||
log.Warnf("server couldn't find device flow, contact admin: %v", err)
|
||||
@@ -266,7 +254,7 @@ func (a *Auth) getDeviceFlow(client *mgm.GrpcClient) (*DeviceAuthorizationFlow,
|
||||
config := &DeviceAuthProviderConfig{
|
||||
Audience: protoConfig.GetAudience(),
|
||||
ClientID: protoConfig.GetClientID(),
|
||||
ClientSecret: protoConfig.GetClientSecret(),
|
||||
ClientSecret: protoConfig.GetClientSecret(), //nolint:staticcheck
|
||||
Domain: protoConfig.Domain,
|
||||
TokenEndpoint: protoConfig.GetTokenEndpoint(),
|
||||
DeviceAuthEndpoint: protoConfig.GetDeviceAuthEndpoint(),
|
||||
@@ -292,28 +280,16 @@ func (a *Auth) getDeviceFlow(client *mgm.GrpcClient) (*DeviceAuthorizationFlow,
|
||||
}
|
||||
|
||||
// doMgmLogin performs the actual login operation with the management service
|
||||
func (a *Auth) doMgmLogin(client *mgm.GrpcClient, ctx context.Context, pubSSHKey []byte) (*wgtypes.Key, *mgmProto.LoginResponse, error) {
|
||||
serverKey, err := client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
func (a *Auth) doMgmLogin(client *mgm.GrpcClient, ctx context.Context, pubSSHKey []byte) error {
|
||||
sysInfo := system.GetInfo(ctx)
|
||||
a.setSystemInfoFlags(sysInfo)
|
||||
loginResp, err := client.Login(*serverKey, sysInfo, pubSSHKey, a.config.DNSLabels)
|
||||
return serverKey, loginResp, err
|
||||
_, err := client.Login(sysInfo, pubSSHKey, a.config.DNSLabels)
|
||||
return err
|
||||
}
|
||||
|
||||
// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key.
|
||||
// Otherwise tries to register with the provided setupKey via command line.
|
||||
func (a *Auth) registerPeer(client *mgm.GrpcClient, ctx context.Context, setupKey string, jwtToken string, pubSSHKey []byte) (*mgmProto.LoginResponse, error) {
|
||||
serverPublicKey, err := client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
validSetupKey, err := uuid.Parse(setupKey)
|
||||
if err != nil && jwtToken == "" {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid setup-key or no sso information provided, err: %v", err)
|
||||
@@ -322,7 +298,7 @@ func (a *Auth) registerPeer(client *mgm.GrpcClient, ctx context.Context, setupKe
|
||||
log.Debugf("sending peer registration request to Management Service")
|
||||
info := system.GetInfo(ctx)
|
||||
a.setSystemInfoFlags(info)
|
||||
loginResp, err := client.Register(*serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, a.config.DNSLabels)
|
||||
loginResp, err := client.Register(validSetupKey.String(), jwtToken, info, pubSSHKey, a.config.DNSLabels)
|
||||
if err != nil {
|
||||
log.Errorf("failed registering peer %v", err)
|
||||
return nil, err
|
||||
|
||||
@@ -44,6 +44,10 @@ import (
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
// androidRunOverride is set on Android to inject mobile dependencies
|
||||
// when using embed.Client (which calls Run() with empty MobileDependency).
|
||||
var androidRunOverride func(c *ConnectClient, runningChan chan struct{}, logPath string) error
|
||||
|
||||
type ConnectClient struct {
|
||||
ctx context.Context
|
||||
config *profilemanager.Config
|
||||
@@ -76,6 +80,9 @@ func (c *ConnectClient) SetUpdateManager(um *updater.Manager) {
|
||||
|
||||
// Run with main logic.
|
||||
func (c *ConnectClient) Run(runningChan chan struct{}, logPath string) error {
|
||||
if androidRunOverride != nil {
|
||||
return androidRunOverride(c, runningChan, logPath)
|
||||
}
|
||||
return c.run(MobileDependency{}, runningChan, logPath)
|
||||
}
|
||||
|
||||
@@ -104,6 +111,7 @@ func (c *ConnectClient) RunOniOS(
|
||||
fileDescriptor int32,
|
||||
networkChangeListener listener.NetworkChangeListener,
|
||||
dnsManager dns.IosDnsManager,
|
||||
dnsAddresses []netip.AddrPort,
|
||||
stateFilePath string,
|
||||
) error {
|
||||
// Set GC percent to 5% to reduce memory usage as iOS only allows 50MB of memory for the extension.
|
||||
@@ -113,6 +121,7 @@ func (c *ConnectClient) RunOniOS(
|
||||
FileDescriptor: fileDescriptor,
|
||||
NetworkChangeListener: networkChangeListener,
|
||||
DnsManager: dnsManager,
|
||||
HostDNSAddresses: dnsAddresses,
|
||||
StateFilePath: stateFilePath,
|
||||
}
|
||||
return c.run(mobileDependency, nil, "")
|
||||
@@ -553,6 +562,9 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf
|
||||
MTU: selectMTU(config.MTU, peerConfig.Mtu),
|
||||
LogPath: logPath,
|
||||
|
||||
InspectionCACertPath: config.InspectionCACertPath,
|
||||
InspectionCAKeyPath: config.InspectionCAKeyPath,
|
||||
|
||||
ProfileConfig: config,
|
||||
}
|
||||
|
||||
@@ -610,12 +622,6 @@ func connectToSignal(ctx context.Context, wtConfig *mgmProto.NetbirdConfig, ourP
|
||||
|
||||
// loginToManagement creates Management ServiceDependencies client, establishes a connection, logs-in and gets a global Netbird config (signal, turn, stun hosts, etc)
|
||||
func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte, config *profilemanager.Config) (*mgmProto.LoginResponse, error) {
|
||||
|
||||
serverPublicKey, err := client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
return nil, gstatus.Errorf(codes.FailedPrecondition, "failed while getting Management Service public key: %s", err)
|
||||
}
|
||||
|
||||
sysInfo := system.GetInfo(ctx)
|
||||
sysInfo.SetFlags(
|
||||
config.RosenpassEnabled,
|
||||
@@ -634,12 +640,7 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte,
|
||||
config.EnableSSHRemotePortForwarding,
|
||||
config.DisableSSHAuth,
|
||||
)
|
||||
loginResp, err := client.Login(*serverPublicKey, sysInfo, pubSSHKey, config.DNSLabels)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return loginResp, nil
|
||||
return client.Login(sysInfo, pubSSHKey, config.DNSLabels)
|
||||
}
|
||||
|
||||
func statusRecorderToMgmConnStateNotifier(statusRecorder *peer.Status) mgm.ConnStateNotifier {
|
||||
|
||||
73
client/internal/connect_android_default.go
Normal file
73
client/internal/connect_android_default.go
Normal file
@@ -0,0 +1,73 @@
|
||||
//go:build android
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
)
|
||||
|
||||
// noopIFaceDiscover is a stub ExternalIFaceDiscover for embed.Client on Android.
|
||||
// It returns an empty interface list, which means ICE P2P candidates won't be
|
||||
// discovered — connections will fall back to relay. Applications that need P2P
|
||||
// should provide a real implementation via runOnAndroidEmbed that uses
|
||||
// Android's ConnectivityManager to enumerate network interfaces.
|
||||
type noopIFaceDiscover struct{}
|
||||
|
||||
func (noopIFaceDiscover) IFaces() (string, error) {
|
||||
// Return empty JSON array — no local interfaces advertised for ICE.
|
||||
// This is intentional: without Android's ConnectivityManager, we cannot
|
||||
// reliably enumerate interfaces (netlink is restricted on Android 11+).
|
||||
// Relay connections still work; only P2P hole-punching is disabled.
|
||||
return "[]", nil
|
||||
}
|
||||
|
||||
// noopNetworkChangeListener is a stub for embed.Client on Android.
|
||||
// Network change events are ignored since the embed client manages its own
|
||||
// reconnection logic via the engine's built-in retry mechanism.
|
||||
type noopNetworkChangeListener struct{}
|
||||
|
||||
func (noopNetworkChangeListener) OnNetworkChanged(string) {
|
||||
// No-op: embed.Client relies on the engine's internal reconnection
|
||||
// logic rather than OS-level network change notifications.
|
||||
}
|
||||
|
||||
func (noopNetworkChangeListener) SetInterfaceIP(string) {
|
||||
// No-op: in netstack mode, the overlay IP is managed by the userspace
|
||||
// network stack, not by OS-level interface configuration.
|
||||
}
|
||||
|
||||
// noopDnsReadyListener is a stub for embed.Client on Android.
|
||||
// DNS readiness notifications are not needed in netstack/embed mode
|
||||
// since system DNS is disabled and DNS resolution happens externally.
|
||||
type noopDnsReadyListener struct{}
|
||||
|
||||
func (noopDnsReadyListener) OnReady() {
|
||||
// No-op: embed.Client does not need DNS readiness notifications.
|
||||
// System DNS is disabled in netstack mode.
|
||||
}
|
||||
|
||||
var _ stdnet.ExternalIFaceDiscover = noopIFaceDiscover{}
|
||||
var _ listener.NetworkChangeListener = noopNetworkChangeListener{}
|
||||
var _ dns.ReadyListener = noopDnsReadyListener{}
|
||||
|
||||
func init() {
|
||||
// Wire up the default override so embed.Client.Start() works on Android
|
||||
// with netstack mode. Provides complete no-op stubs for all mobile
|
||||
// dependencies so the engine's existing Android code paths work unchanged.
|
||||
// Applications that need P2P ICE or real DNS should replace this by
|
||||
// setting androidRunOverride before calling Start().
|
||||
androidRunOverride = func(c *ConnectClient, runningChan chan struct{}, logPath string) error {
|
||||
return c.runOnAndroidEmbed(
|
||||
noopIFaceDiscover{},
|
||||
noopNetworkChangeListener{},
|
||||
[]netip.AddrPort{},
|
||||
noopDnsReadyListener{},
|
||||
runningChan,
|
||||
logPath,
|
||||
)
|
||||
}
|
||||
}
|
||||
32
client/internal/connect_android_embed.go
Normal file
32
client/internal/connect_android_embed.go
Normal file
@@ -0,0 +1,32 @@
|
||||
//go:build android
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
)
|
||||
|
||||
// runOnAndroidEmbed is like RunOnAndroid but accepts a runningChan
|
||||
// so embed.Client.Start() can detect when the engine is ready.
|
||||
// It provides complete MobileDependency so the engine's existing
|
||||
// Android code paths work unchanged.
|
||||
func (c *ConnectClient) runOnAndroidEmbed(
|
||||
iFaceDiscover stdnet.ExternalIFaceDiscover,
|
||||
networkChangeListener listener.NetworkChangeListener,
|
||||
dnsAddresses []netip.AddrPort,
|
||||
dnsReadyListener dns.ReadyListener,
|
||||
runningChan chan struct{},
|
||||
logPath string,
|
||||
) error {
|
||||
mobileDependency := MobileDependency{
|
||||
IFaceDiscover: iFaceDiscover,
|
||||
NetworkChangeListener: networkChangeListener,
|
||||
HostDNSAddresses: dnsAddresses,
|
||||
DnsReadyListener: dnsReadyListener,
|
||||
}
|
||||
return c.run(mobileDependency, runningChan, logPath)
|
||||
}
|
||||
@@ -25,6 +25,7 @@ import (
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
|
||||
"github.com/netbirdio/netbird/client/anonymize"
|
||||
"github.com/netbirdio/netbird/client/configs"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/updater/installer"
|
||||
@@ -52,6 +53,7 @@ resolved_domains.txt: Anonymized resolved domain IP addresses from the status re
|
||||
config.txt: Anonymized configuration information of the NetBird client.
|
||||
network_map.json: Anonymized sync response containing peer configurations, routes, DNS settings, and firewall rules.
|
||||
state.json: Anonymized client state dump containing netbird states for the active profile.
|
||||
service_params.json: Sanitized service install parameters (service.json). Sensitive environment variable values are masked. Only present when service.json exists.
|
||||
metrics.txt: Buffered client metrics in InfluxDB line protocol format. Only present when metrics collection is enabled. Peer identifiers are anonymized.
|
||||
mutex.prof: Mutex profiling information.
|
||||
goroutine.prof: Goroutine profiling information.
|
||||
@@ -359,6 +361,10 @@ func (g *BundleGenerator) createArchive() error {
|
||||
log.Errorf("failed to add corrupted state files to debug bundle: %v", err)
|
||||
}
|
||||
|
||||
if err := g.addServiceParams(); err != nil {
|
||||
log.Errorf("failed to add service params to debug bundle: %v", err)
|
||||
}
|
||||
|
||||
if err := g.addMetrics(); err != nil {
|
||||
log.Errorf("failed to add metrics to debug bundle: %v", err)
|
||||
}
|
||||
@@ -488,6 +494,90 @@ func (g *BundleGenerator) addConfig() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
const (
|
||||
serviceParamsFile = "service.json"
|
||||
serviceParamsBundle = "service_params.json"
|
||||
maskedValue = "***"
|
||||
envVarPrefix = "NB_"
|
||||
jsonKeyManagementURL = "management_url"
|
||||
jsonKeyServiceEnv = "service_env_vars"
|
||||
)
|
||||
|
||||
var sensitiveEnvSubstrings = []string{"key", "token", "secret", "password", "credential"}
|
||||
|
||||
// addServiceParams reads the service.json file and adds a sanitized version to the bundle.
|
||||
// Non-NB_ env vars and vars with sensitive names are masked. Other NB_ values are anonymized.
|
||||
func (g *BundleGenerator) addServiceParams() error {
|
||||
path := filepath.Join(configs.StateDir, serviceParamsFile)
|
||||
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("read service params: %w", err)
|
||||
}
|
||||
|
||||
var params map[string]any
|
||||
if err := json.Unmarshal(data, ¶ms); err != nil {
|
||||
return fmt.Errorf("parse service params: %w", err)
|
||||
}
|
||||
|
||||
if g.anonymize {
|
||||
if mgmtURL, ok := params[jsonKeyManagementURL].(string); ok && mgmtURL != "" {
|
||||
params[jsonKeyManagementURL] = g.anonymizer.AnonymizeURI(mgmtURL)
|
||||
}
|
||||
}
|
||||
|
||||
g.sanitizeServiceEnvVars(params)
|
||||
|
||||
sanitizedData, err := json.MarshalIndent(params, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal sanitized service params: %w", err)
|
||||
}
|
||||
|
||||
if err := g.addFileToZip(bytes.NewReader(sanitizedData), serviceParamsBundle); err != nil {
|
||||
return fmt.Errorf("add service params to zip: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// sanitizeServiceEnvVars masks or anonymizes env var values in service params.
|
||||
// Non-NB_ vars and vars with sensitive names (key, token, etc.) are fully masked.
|
||||
// Other NB_ var values are passed through the anonymizer when anonymization is enabled.
|
||||
func (g *BundleGenerator) sanitizeServiceEnvVars(params map[string]any) {
|
||||
envVars, ok := params[jsonKeyServiceEnv].(map[string]any)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
sanitized := make(map[string]any, len(envVars))
|
||||
for k, v := range envVars {
|
||||
val, _ := v.(string)
|
||||
switch {
|
||||
case !strings.HasPrefix(k, envVarPrefix) || isSensitiveEnvVar(k):
|
||||
sanitized[k] = maskedValue
|
||||
case g.anonymize:
|
||||
sanitized[k] = g.anonymizer.AnonymizeString(val)
|
||||
default:
|
||||
sanitized[k] = val
|
||||
}
|
||||
}
|
||||
params[jsonKeyServiceEnv] = sanitized
|
||||
}
|
||||
|
||||
// isSensitiveEnvVar returns true for env var names that may contain secrets.
|
||||
func isSensitiveEnvVar(key string) bool {
|
||||
lower := strings.ToLower(key)
|
||||
for _, s := range sensitiveEnvSubstrings {
|
||||
if strings.Contains(lower, s) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder) {
|
||||
configContent.WriteString("NetBird Client Configuration:\n\n")
|
||||
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
package debug
|
||||
|
||||
import (
|
||||
"archive/zip"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
@@ -10,6 +14,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/anonymize"
|
||||
"github.com/netbirdio/netbird/client/configs"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
@@ -420,6 +425,226 @@ func TestAnonymizeNetworkMap(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsSensitiveEnvVar(t *testing.T) {
|
||||
tests := []struct {
|
||||
key string
|
||||
sensitive bool
|
||||
}{
|
||||
{"NB_SETUP_KEY", true},
|
||||
{"NB_API_TOKEN", true},
|
||||
{"NB_CLIENT_SECRET", true},
|
||||
{"NB_PASSWORD", true},
|
||||
{"NB_CREDENTIAL", true},
|
||||
{"NB_LOG_LEVEL", false},
|
||||
{"NB_MANAGEMENT_URL", false},
|
||||
{"NB_HOSTNAME", false},
|
||||
{"HOME", false},
|
||||
{"PATH", false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.key, func(t *testing.T) {
|
||||
assert.Equal(t, tt.sensitive, isSensitiveEnvVar(tt.key))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeServiceEnvVars(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
anonymize bool
|
||||
input map[string]any
|
||||
check func(t *testing.T, params map[string]any)
|
||||
}{
|
||||
{
|
||||
name: "no env vars key",
|
||||
anonymize: false,
|
||||
input: map[string]any{"management_url": "https://mgmt.example.com"},
|
||||
check: func(t *testing.T, params map[string]any) {
|
||||
t.Helper()
|
||||
assert.Equal(t, "https://mgmt.example.com", params["management_url"], "non-env fields should be untouched")
|
||||
_, ok := params[jsonKeyServiceEnv]
|
||||
assert.False(t, ok, "service_env_vars should not be added")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "non-NB vars are masked",
|
||||
anonymize: false,
|
||||
input: map[string]any{
|
||||
jsonKeyServiceEnv: map[string]any{
|
||||
"HOME": "/root",
|
||||
"PATH": "/usr/bin",
|
||||
"NB_LOG_LEVEL": "debug",
|
||||
},
|
||||
},
|
||||
check: func(t *testing.T, params map[string]any) {
|
||||
t.Helper()
|
||||
env := params[jsonKeyServiceEnv].(map[string]any)
|
||||
assert.Equal(t, maskedValue, env["HOME"], "non-NB_ var should be masked")
|
||||
assert.Equal(t, maskedValue, env["PATH"], "non-NB_ var should be masked")
|
||||
assert.Equal(t, "debug", env["NB_LOG_LEVEL"], "safe NB_ var should pass through")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "sensitive NB vars are masked",
|
||||
anonymize: false,
|
||||
input: map[string]any{
|
||||
jsonKeyServiceEnv: map[string]any{
|
||||
"NB_SETUP_KEY": "abc123",
|
||||
"NB_API_TOKEN": "tok_xyz",
|
||||
"NB_LOG_LEVEL": "info",
|
||||
},
|
||||
},
|
||||
check: func(t *testing.T, params map[string]any) {
|
||||
t.Helper()
|
||||
env := params[jsonKeyServiceEnv].(map[string]any)
|
||||
assert.Equal(t, maskedValue, env["NB_SETUP_KEY"], "sensitive NB_ var should be masked")
|
||||
assert.Equal(t, maskedValue, env["NB_API_TOKEN"], "sensitive NB_ var should be masked")
|
||||
assert.Equal(t, "info", env["NB_LOG_LEVEL"], "safe NB_ var should pass through")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "safe NB vars anonymized when anonymize is true",
|
||||
anonymize: true,
|
||||
input: map[string]any{
|
||||
jsonKeyServiceEnv: map[string]any{
|
||||
"NB_MANAGEMENT_URL": "https://mgmt.example.com:443",
|
||||
"NB_LOG_LEVEL": "debug",
|
||||
"NB_SETUP_KEY": "secret",
|
||||
"SOME_OTHER": "val",
|
||||
},
|
||||
},
|
||||
check: func(t *testing.T, params map[string]any) {
|
||||
t.Helper()
|
||||
env := params[jsonKeyServiceEnv].(map[string]any)
|
||||
// Safe NB_ values should be anonymized (not the original, not masked)
|
||||
mgmtVal := env["NB_MANAGEMENT_URL"].(string)
|
||||
assert.NotEqual(t, "https://mgmt.example.com:443", mgmtVal, "should be anonymized")
|
||||
assert.NotEqual(t, maskedValue, mgmtVal, "should not be masked")
|
||||
|
||||
logVal := env["NB_LOG_LEVEL"].(string)
|
||||
assert.NotEqual(t, maskedValue, logVal, "safe NB_ var should not be masked")
|
||||
|
||||
// Sensitive and non-NB_ still masked
|
||||
assert.Equal(t, maskedValue, env["NB_SETUP_KEY"])
|
||||
assert.Equal(t, maskedValue, env["SOME_OTHER"])
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
|
||||
g := &BundleGenerator{
|
||||
anonymize: tt.anonymize,
|
||||
anonymizer: anonymizer,
|
||||
}
|
||||
g.sanitizeServiceEnvVars(tt.input)
|
||||
tt.check(t, tt.input)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddServiceParams(t *testing.T) {
|
||||
t.Run("missing service.json returns nil", func(t *testing.T) {
|
||||
g := &BundleGenerator{
|
||||
anonymizer: anonymize.NewAnonymizer(anonymize.DefaultAddresses()),
|
||||
}
|
||||
|
||||
origStateDir := configs.StateDir
|
||||
configs.StateDir = t.TempDir()
|
||||
t.Cleanup(func() { configs.StateDir = origStateDir })
|
||||
|
||||
err := g.addServiceParams()
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("management_url anonymized when anonymize is true", func(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
origStateDir := configs.StateDir
|
||||
configs.StateDir = dir
|
||||
t.Cleanup(func() { configs.StateDir = origStateDir })
|
||||
|
||||
input := map[string]any{
|
||||
jsonKeyManagementURL: "https://api.example.com:443",
|
||||
jsonKeyServiceEnv: map[string]any{
|
||||
"NB_LOG_LEVEL": "trace",
|
||||
},
|
||||
}
|
||||
data, err := json.Marshal(input)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, os.WriteFile(filepath.Join(dir, serviceParamsFile), data, 0600))
|
||||
|
||||
var buf bytes.Buffer
|
||||
zw := zip.NewWriter(&buf)
|
||||
|
||||
g := &BundleGenerator{
|
||||
anonymize: true,
|
||||
anonymizer: anonymize.NewAnonymizer(anonymize.DefaultAddresses()),
|
||||
archive: zw,
|
||||
}
|
||||
|
||||
require.NoError(t, g.addServiceParams())
|
||||
require.NoError(t, zw.Close())
|
||||
|
||||
zr, err := zip.NewReader(bytes.NewReader(buf.Bytes()), int64(buf.Len()))
|
||||
require.NoError(t, err)
|
||||
require.Len(t, zr.File, 1)
|
||||
assert.Equal(t, serviceParamsBundle, zr.File[0].Name)
|
||||
|
||||
rc, err := zr.File[0].Open()
|
||||
require.NoError(t, err)
|
||||
defer rc.Close()
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.NewDecoder(rc).Decode(&result))
|
||||
|
||||
mgmt := result[jsonKeyManagementURL].(string)
|
||||
assert.NotEqual(t, "https://api.example.com:443", mgmt, "management_url should be anonymized")
|
||||
assert.NotEmpty(t, mgmt)
|
||||
|
||||
env := result[jsonKeyServiceEnv].(map[string]any)
|
||||
assert.NotEqual(t, maskedValue, env["NB_LOG_LEVEL"], "safe NB_ var should not be masked")
|
||||
})
|
||||
|
||||
t.Run("management_url preserved when anonymize is false", func(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
origStateDir := configs.StateDir
|
||||
configs.StateDir = dir
|
||||
t.Cleanup(func() { configs.StateDir = origStateDir })
|
||||
|
||||
input := map[string]any{
|
||||
jsonKeyManagementURL: "https://api.example.com:443",
|
||||
}
|
||||
data, err := json.Marshal(input)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, os.WriteFile(filepath.Join(dir, serviceParamsFile), data, 0600))
|
||||
|
||||
var buf bytes.Buffer
|
||||
zw := zip.NewWriter(&buf)
|
||||
|
||||
g := &BundleGenerator{
|
||||
anonymize: false,
|
||||
anonymizer: anonymize.NewAnonymizer(anonymize.DefaultAddresses()),
|
||||
archive: zw,
|
||||
}
|
||||
|
||||
require.NoError(t, g.addServiceParams())
|
||||
require.NoError(t, zw.Close())
|
||||
|
||||
zr, err := zip.NewReader(bytes.NewReader(buf.Bytes()), int64(buf.Len()))
|
||||
require.NoError(t, err)
|
||||
|
||||
rc, err := zr.File[0].Open()
|
||||
require.NoError(t, err)
|
||||
defer rc.Close()
|
||||
|
||||
var result map[string]any
|
||||
require.NoError(t, json.NewDecoder(rc).Decode(&result))
|
||||
|
||||
assert.Equal(t, "https://api.example.com:443", result[jsonKeyManagementURL], "management_url should be preserved")
|
||||
})
|
||||
}
|
||||
|
||||
// Helper function to check if IP is in CGNAT range
|
||||
func isInCGNATRange(ip net.IP) bool {
|
||||
cgnat := net.IPNet{
|
||||
|
||||
@@ -73,6 +73,9 @@ func (w *ResponseWriterChain) WriteMsg(m *dns.Msg) error {
|
||||
return nil
|
||||
}
|
||||
w.response = m
|
||||
if m.MsgHdr.Truncated {
|
||||
w.SetMeta("truncated", "true")
|
||||
}
|
||||
return w.ResponseWriter.WriteMsg(m)
|
||||
}
|
||||
|
||||
@@ -195,10 +198,14 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
|
||||
startTime := time.Now()
|
||||
requestID := resutil.GenerateRequestID()
|
||||
logger := log.WithFields(log.Fields{
|
||||
fields := log.Fields{
|
||||
"request_id": requestID,
|
||||
"dns_id": fmt.Sprintf("%04x", r.Id),
|
||||
})
|
||||
}
|
||||
if addr := w.RemoteAddr(); addr != nil {
|
||||
fields["client"] = addr.String()
|
||||
}
|
||||
logger := log.WithFields(fields)
|
||||
|
||||
question := r.Question[0]
|
||||
qname := strings.ToLower(question.Name)
|
||||
@@ -261,9 +268,9 @@ func (c *HandlerChain) logResponse(logger *log.Entry, cw *ResponseWriterChain, q
|
||||
meta += " " + k + "=" + v
|
||||
}
|
||||
|
||||
logger.Tracef("response: domain=%s rcode=%s answers=%s%s took=%s",
|
||||
logger.Tracef("response: domain=%s rcode=%s answers=%s size=%dB%s took=%s",
|
||||
qname, dns.RcodeToString[cw.response.Rcode], resutil.FormatAnswers(cw.response.Answer),
|
||||
meta, time.Since(startTime))
|
||||
cw.response.Len(), meta, time.Since(startTime))
|
||||
}
|
||||
|
||||
func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
|
||||
|
||||
@@ -1263,9 +1263,9 @@ func TestLocalResolver_AuthoritativeFlag(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
// TestLocalResolver_Stop tests cleanup on Stop
|
||||
// TestLocalResolver_Stop tests cleanup on GracefullyStop
|
||||
func TestLocalResolver_Stop(t *testing.T) {
|
||||
t.Run("Stop clears all state", func(t *testing.T) {
|
||||
t.Run("GracefullyStop clears all state", func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
resolver.Update([]nbdns.CustomZone{{
|
||||
Domain: "example.com.",
|
||||
@@ -1285,7 +1285,7 @@ func TestLocalResolver_Stop(t *testing.T) {
|
||||
assert.False(t, resolver.isInManagedZone("host.example.com."))
|
||||
})
|
||||
|
||||
t.Run("Stop is safe to call multiple times", func(t *testing.T) {
|
||||
t.Run("GracefullyStop is safe to call multiple times", func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
resolver.Update([]nbdns.CustomZone{{
|
||||
Domain: "example.com.",
|
||||
@@ -1299,7 +1299,7 @@ func TestLocalResolver_Stop(t *testing.T) {
|
||||
resolver.Stop()
|
||||
})
|
||||
|
||||
t.Run("Stop cancels in-flight external resolution", func(t *testing.T) {
|
||||
t.Run("GracefullyStop cancels in-flight external resolution", func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
|
||||
lookupStarted := make(chan struct{})
|
||||
|
||||
@@ -90,6 +90,11 @@ func (m *MockServer) SetRouteChecker(func(netip.Addr) bool) {
|
||||
// Mock implementation - no-op
|
||||
}
|
||||
|
||||
// SetFirewall mock implementation of SetFirewall from Server interface
|
||||
func (m *MockServer) SetFirewall(Firewall) {
|
||||
// Mock implementation - no-op
|
||||
}
|
||||
|
||||
// BeginBatch mock implementation of BeginBatch from Server interface
|
||||
func (m *MockServer) BeginBatch() {
|
||||
// Mock implementation - no-op
|
||||
|
||||
@@ -104,3 +104,23 @@ func (r *responseWriter) TsigTimersOnly(bool) {
|
||||
// After a call to Hijack(), the DNS package will not do anything with the connection.
|
||||
func (r *responseWriter) Hijack() {
|
||||
}
|
||||
|
||||
// remoteAddrFromPacket extracts the source IP:port from a decoded packet for logging.
|
||||
func remoteAddrFromPacket(packet gopacket.Packet) *net.UDPAddr {
|
||||
var srcIP net.IP
|
||||
if ipv4 := packet.Layer(layers.LayerTypeIPv4); ipv4 != nil {
|
||||
srcIP = ipv4.(*layers.IPv4).SrcIP
|
||||
} else if ipv6 := packet.Layer(layers.LayerTypeIPv6); ipv6 != nil {
|
||||
srcIP = ipv6.(*layers.IPv6).SrcIP
|
||||
}
|
||||
|
||||
var srcPort int
|
||||
if udp := packet.Layer(layers.LayerTypeUDP); udp != nil {
|
||||
srcPort = int(udp.(*layers.UDP).SrcPort)
|
||||
}
|
||||
|
||||
if srcIP == nil {
|
||||
return nil
|
||||
}
|
||||
return &net.UDPAddr{IP: srcIP, Port: srcPort}
|
||||
}
|
||||
|
||||
@@ -58,6 +58,7 @@ type Server interface {
|
||||
UpdateServerConfig(domains dnsconfig.ServerDomains) error
|
||||
PopulateManagementDomain(mgmtURL *url.URL) error
|
||||
SetRouteChecker(func(netip.Addr) bool)
|
||||
SetFirewall(Firewall)
|
||||
}
|
||||
|
||||
type nsGroupsByDomain struct {
|
||||
@@ -151,7 +152,7 @@ func NewDefaultServer(ctx context.Context, config DefaultServerConfig) (*Default
|
||||
if config.WgInterface.IsUserspaceBind() {
|
||||
dnsService = NewServiceViaMemory(config.WgInterface)
|
||||
} else {
|
||||
dnsService = newServiceViaListener(config.WgInterface, addrPort)
|
||||
dnsService = newServiceViaListener(config.WgInterface, addrPort, nil)
|
||||
}
|
||||
|
||||
server := newDefaultServer(ctx, config.WgInterface, dnsService, config.StatusRecorder, config.StateManager, config.DisableSys)
|
||||
@@ -186,11 +187,16 @@ func NewDefaultServerIos(
|
||||
ctx context.Context,
|
||||
wgInterface WGIface,
|
||||
iosDnsManager IosDnsManager,
|
||||
hostsDnsList []netip.AddrPort,
|
||||
statusRecorder *peer.Status,
|
||||
disableSys bool,
|
||||
) *DefaultServer {
|
||||
log.Debugf("iOS host dns address list is: %v", hostsDnsList)
|
||||
ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil, disableSys)
|
||||
ds.iosDnsManager = iosDnsManager
|
||||
ds.hostsDNSHolder.set(hostsDnsList)
|
||||
ds.permanent = true
|
||||
ds.addHostRootZone()
|
||||
return ds
|
||||
}
|
||||
|
||||
@@ -374,6 +380,17 @@ func (s *DefaultServer) DnsIP() netip.Addr {
|
||||
return s.service.RuntimeIP()
|
||||
}
|
||||
|
||||
// SetFirewall sets the firewall used for DNS port DNAT rules.
|
||||
// This must be called before Initialize when using the listener-based service,
|
||||
// because the firewall is typically not available at construction time.
|
||||
func (s *DefaultServer) SetFirewall(fw Firewall) {
|
||||
if svc, ok := s.service.(*serviceViaListener); ok {
|
||||
svc.listenerFlagLock.Lock()
|
||||
svc.firewall = fw
|
||||
svc.listenerFlagLock.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// Stop stops the server
|
||||
func (s *DefaultServer) Stop() {
|
||||
s.probeMu.Lock()
|
||||
@@ -395,8 +412,12 @@ func (s *DefaultServer) Stop() {
|
||||
maps.Clear(s.extraDomains)
|
||||
}
|
||||
|
||||
func (s *DefaultServer) disableDNS() error {
|
||||
defer s.service.Stop()
|
||||
func (s *DefaultServer) disableDNS() (retErr error) {
|
||||
defer func() {
|
||||
if err := s.service.Stop(); err != nil {
|
||||
retErr = errors.Join(retErr, fmt.Errorf("stop DNS service: %w", err))
|
||||
}
|
||||
}()
|
||||
|
||||
if s.isUsingNoopHostManager() {
|
||||
return nil
|
||||
|
||||
@@ -476,8 +476,8 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
||||
|
||||
packetfilter := pfmock.NewMockPacketFilter(ctrl)
|
||||
packetfilter.EXPECT().FilterOutbound(gomock.Any(), gomock.Any()).AnyTimes()
|
||||
packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
||||
packetfilter.EXPECT().RemovePacketHook(gomock.Any())
|
||||
packetfilter.EXPECT().SetUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
||||
packetfilter.EXPECT().SetTCPPacketHook(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
||||
|
||||
if err := wgIface.SetFilter(packetfilter); err != nil {
|
||||
t.Errorf("set packet filter: %v", err)
|
||||
@@ -1071,7 +1071,7 @@ func (m *mockHandler) ID() types.HandlerID { return types.Hand
|
||||
type mockService struct{}
|
||||
|
||||
func (m *mockService) Listen() error { return nil }
|
||||
func (m *mockService) Stop() {}
|
||||
func (m *mockService) Stop() error { return nil }
|
||||
func (m *mockService) RuntimeIP() netip.Addr { return netip.MustParseAddr("127.0.0.1") }
|
||||
func (m *mockService) RuntimePort() int { return 53 }
|
||||
func (m *mockService) RegisterMux(string, dns.Handler) {}
|
||||
|
||||
@@ -4,15 +4,25 @@ import (
|
||||
"net/netip"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultPort = 53
|
||||
)
|
||||
|
||||
// Firewall provides DNAT capabilities for DNS port redirection.
|
||||
// This is used when the DNS server cannot bind port 53 directly
|
||||
// and needs firewall rules to redirect traffic.
|
||||
type Firewall interface {
|
||||
AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error
|
||||
RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error
|
||||
}
|
||||
|
||||
type service interface {
|
||||
Listen() error
|
||||
Stop()
|
||||
Stop() error
|
||||
RegisterMux(domain string, handler dns.Handler)
|
||||
DeregisterMux(key string)
|
||||
RuntimePort() int
|
||||
|
||||
@@ -10,9 +10,13 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/ebpf"
|
||||
ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager"
|
||||
)
|
||||
@@ -31,25 +35,33 @@ type serviceViaListener struct {
|
||||
dnsMux *dns.ServeMux
|
||||
customAddr *netip.AddrPort
|
||||
server *dns.Server
|
||||
tcpServer *dns.Server
|
||||
listenIP netip.Addr
|
||||
listenPort uint16
|
||||
listenerIsRunning bool
|
||||
listenerFlagLock sync.Mutex
|
||||
ebpfService ebpfMgr.Manager
|
||||
firewall Firewall
|
||||
tcpDNATConfigured bool
|
||||
}
|
||||
|
||||
func newServiceViaListener(wgIface WGIface, customAddr *netip.AddrPort) *serviceViaListener {
|
||||
func newServiceViaListener(wgIface WGIface, customAddr *netip.AddrPort, fw Firewall) *serviceViaListener {
|
||||
mux := dns.NewServeMux()
|
||||
|
||||
s := &serviceViaListener{
|
||||
wgInterface: wgIface,
|
||||
dnsMux: mux,
|
||||
customAddr: customAddr,
|
||||
firewall: fw,
|
||||
server: &dns.Server{
|
||||
Net: "udp",
|
||||
Handler: mux,
|
||||
UDPSize: 65535,
|
||||
},
|
||||
tcpServer: &dns.Server{
|
||||
Net: "tcp",
|
||||
Handler: mux,
|
||||
},
|
||||
}
|
||||
|
||||
return s
|
||||
@@ -70,43 +82,86 @@ func (s *serviceViaListener) Listen() error {
|
||||
return fmt.Errorf("eval listen address: %w", err)
|
||||
}
|
||||
s.listenIP = s.listenIP.Unmap()
|
||||
s.server.Addr = net.JoinHostPort(s.listenIP.String(), strconv.Itoa(int(s.listenPort)))
|
||||
log.Debugf("starting dns on %s", s.server.Addr)
|
||||
go func() {
|
||||
s.setListenerStatus(true)
|
||||
defer s.setListenerStatus(false)
|
||||
addr := net.JoinHostPort(s.listenIP.String(), strconv.Itoa(int(s.listenPort)))
|
||||
s.server.Addr = addr
|
||||
s.tcpServer.Addr = addr
|
||||
|
||||
err := s.server.ListenAndServe()
|
||||
if err != nil {
|
||||
log.Errorf("dns server running with %d port returned an error: %v. Will not retry", s.listenPort, err)
|
||||
log.Debugf("starting dns on %s (UDP + TCP)", addr)
|
||||
s.listenerIsRunning = true
|
||||
|
||||
go func() {
|
||||
if err := s.server.ListenAndServe(); err != nil {
|
||||
log.Errorf("failed to run DNS UDP server on port %d: %v", s.listenPort, err)
|
||||
}
|
||||
|
||||
s.listenerFlagLock.Lock()
|
||||
unexpected := s.listenerIsRunning
|
||||
s.listenerIsRunning = false
|
||||
s.listenerFlagLock.Unlock()
|
||||
|
||||
if unexpected {
|
||||
if err := s.tcpServer.Shutdown(); err != nil {
|
||||
log.Debugf("failed to shutdown DNS TCP server: %v", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
if err := s.tcpServer.ListenAndServe(); err != nil {
|
||||
log.Errorf("failed to run DNS TCP server on port %d: %v", s.listenPort, err)
|
||||
}
|
||||
}()
|
||||
|
||||
// When eBPF redirects UDP port 53 to our listen port, TCP still needs
|
||||
// a DNAT rule because eBPF only handles UDP.
|
||||
if s.ebpfService != nil && s.firewall != nil && s.listenPort != DefaultPort {
|
||||
if err := s.firewall.AddOutputDNAT(s.listenIP, firewall.ProtocolTCP, DefaultPort, s.listenPort); err != nil {
|
||||
log.Warnf("failed to add DNS TCP DNAT rule, TCP DNS on port 53 will not work: %v", err)
|
||||
} else {
|
||||
s.tcpDNATConfigured = true
|
||||
log.Infof("added DNS TCP DNAT rule: %s:%d -> %s:%d", s.listenIP, DefaultPort, s.listenIP, s.listenPort)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *serviceViaListener) Stop() {
|
||||
func (s *serviceViaListener) Stop() error {
|
||||
s.listenerFlagLock.Lock()
|
||||
defer s.listenerFlagLock.Unlock()
|
||||
|
||||
if !s.listenerIsRunning {
|
||||
return
|
||||
return nil
|
||||
}
|
||||
s.listenerIsRunning = false
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := s.server.ShutdownContext(ctx)
|
||||
if err != nil {
|
||||
log.Errorf("stopping dns server listener returned an error: %v", err)
|
||||
var merr *multierror.Error
|
||||
|
||||
if err := s.server.ShutdownContext(ctx); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("stop DNS UDP server: %w", err))
|
||||
}
|
||||
|
||||
if err := s.tcpServer.ShutdownContext(ctx); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("stop DNS TCP server: %w", err))
|
||||
}
|
||||
|
||||
if s.tcpDNATConfigured && s.firewall != nil {
|
||||
if err := s.firewall.RemoveOutputDNAT(s.listenIP, firewall.ProtocolTCP, DefaultPort, s.listenPort); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove DNS TCP DNAT rule: %w", err))
|
||||
}
|
||||
s.tcpDNATConfigured = false
|
||||
}
|
||||
|
||||
if s.ebpfService != nil {
|
||||
err = s.ebpfService.FreeDNSFwd()
|
||||
if err != nil {
|
||||
log.Errorf("stopping traffic forwarder returned an error: %v", err)
|
||||
if err := s.ebpfService.FreeDNSFwd(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("stop traffic forwarder: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (s *serviceViaListener) RegisterMux(pattern string, handler dns.Handler) {
|
||||
@@ -133,12 +188,6 @@ func (s *serviceViaListener) RuntimeIP() netip.Addr {
|
||||
return s.listenIP
|
||||
}
|
||||
|
||||
func (s *serviceViaListener) setListenerStatus(running bool) {
|
||||
s.listenerFlagLock.Lock()
|
||||
defer s.listenerFlagLock.Unlock()
|
||||
|
||||
s.listenerIsRunning = running
|
||||
}
|
||||
|
||||
// evalListenAddress figure out the listen address for the DNS server
|
||||
// first check the 53 port availability on WG interface or lo, if not success
|
||||
@@ -187,18 +236,28 @@ func (s *serviceViaListener) testFreePort(port int) (netip.Addr, bool) {
|
||||
}
|
||||
|
||||
func (s *serviceViaListener) tryToBind(ip netip.Addr, port int) bool {
|
||||
addrString := net.JoinHostPort(ip.String(), strconv.Itoa(port))
|
||||
udpAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort(addrString))
|
||||
probeListener, err := net.ListenUDP("udp", udpAddr)
|
||||
addrPort := netip.AddrPortFrom(ip, uint16(port))
|
||||
|
||||
udpAddr := net.UDPAddrFromAddrPort(addrPort)
|
||||
udpLn, err := net.ListenUDP("udp", udpAddr)
|
||||
if err != nil {
|
||||
log.Warnf("binding dns on %s is not available, error: %s", addrString, err)
|
||||
log.Warnf("binding dns UDP on %s is not available: %s", addrPort, err)
|
||||
return false
|
||||
}
|
||||
|
||||
err = probeListener.Close()
|
||||
if err != nil {
|
||||
log.Errorf("got an error closing the probe listener, error: %s", err)
|
||||
if err := udpLn.Close(); err != nil {
|
||||
log.Debugf("close UDP probe listener: %s", err)
|
||||
}
|
||||
|
||||
tcpAddr := net.TCPAddrFromAddrPort(addrPort)
|
||||
tcpLn, err := net.ListenTCP("tcp", tcpAddr)
|
||||
if err != nil {
|
||||
log.Warnf("binding dns TCP on %s is not available: %s", addrPort, err)
|
||||
return false
|
||||
}
|
||||
if err := tcpLn.Close(); err != nil {
|
||||
log.Debugf("close TCP probe listener: %s", err)
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
|
||||
86
client/internal/dns/service_listener_test.go
Normal file
86
client/internal/dns/service_listener_test.go
Normal file
@@ -0,0 +1,86 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestServiceViaListener_TCPAndUDP(t *testing.T) {
|
||||
handler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
m.Answer = append(m.Answer, &dns.A{
|
||||
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP("192.0.2.1"),
|
||||
})
|
||||
if err := w.WriteMsg(m); err != nil {
|
||||
t.Logf("write msg: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
// Create a service using a custom address to avoid needing root
|
||||
svc := newServiceViaListener(nil, nil, nil)
|
||||
svc.dnsMux.Handle(".", handler)
|
||||
|
||||
// Bind both transports up front to avoid TOCTOU races.
|
||||
udpAddr := net.UDPAddrFromAddrPort(netip.AddrPortFrom(customIP, 0))
|
||||
udpConn, err := net.ListenUDP("udp", udpAddr)
|
||||
if err != nil {
|
||||
t.Skip("cannot bind to 127.0.0.153, skipping")
|
||||
}
|
||||
port := uint16(udpConn.LocalAddr().(*net.UDPAddr).Port)
|
||||
|
||||
tcpAddr := net.TCPAddrFromAddrPort(netip.AddrPortFrom(customIP, port))
|
||||
tcpLn, err := net.ListenTCP("tcp", tcpAddr)
|
||||
if err != nil {
|
||||
udpConn.Close()
|
||||
t.Skip("cannot bind TCP on same port, skipping")
|
||||
}
|
||||
|
||||
addr := fmt.Sprintf("%s:%d", customIP, port)
|
||||
svc.server.PacketConn = udpConn
|
||||
svc.tcpServer.Listener = tcpLn
|
||||
svc.listenIP = customIP
|
||||
svc.listenPort = port
|
||||
|
||||
go func() {
|
||||
if err := svc.server.ActivateAndServe(); err != nil {
|
||||
t.Logf("udp server: %v", err)
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
if err := svc.tcpServer.ActivateAndServe(); err != nil {
|
||||
t.Logf("tcp server: %v", err)
|
||||
}
|
||||
}()
|
||||
svc.listenerIsRunning = true
|
||||
|
||||
defer func() {
|
||||
require.NoError(t, svc.Stop())
|
||||
}()
|
||||
|
||||
q := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
|
||||
|
||||
// Test UDP query
|
||||
udpClient := &dns.Client{Net: "udp", Timeout: 2 * time.Second}
|
||||
udpResp, _, err := udpClient.Exchange(q, addr)
|
||||
require.NoError(t, err, "UDP query should succeed")
|
||||
require.NotNil(t, udpResp)
|
||||
require.NotEmpty(t, udpResp.Answer)
|
||||
assert.Contains(t, udpResp.Answer[0].String(), "192.0.2.1", "UDP response should contain expected IP")
|
||||
|
||||
// Test TCP query
|
||||
tcpClient := &dns.Client{Net: "tcp", Timeout: 2 * time.Second}
|
||||
tcpResp, _, err := tcpClient.Exchange(q, addr)
|
||||
require.NoError(t, err, "TCP query should succeed")
|
||||
require.NotNil(t, tcpResp)
|
||||
require.NotEmpty(t, tcpResp.Answer)
|
||||
assert.Contains(t, tcpResp.Answer[0].String(), "192.0.2.1", "TCP response should contain expected IP")
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"sync"
|
||||
@@ -10,6 +11,7 @@ import (
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
@@ -18,7 +20,8 @@ type ServiceViaMemory struct {
|
||||
dnsMux *dns.ServeMux
|
||||
runtimeIP netip.Addr
|
||||
runtimePort int
|
||||
udpFilterHookID string
|
||||
tcpDNS *tcpDNSServer
|
||||
tcpHookSet bool
|
||||
listenerIsRunning bool
|
||||
listenerFlagLock sync.Mutex
|
||||
}
|
||||
@@ -28,14 +31,13 @@ func NewServiceViaMemory(wgIface WGIface) *ServiceViaMemory {
|
||||
if err != nil {
|
||||
log.Errorf("get last ip from network: %v", err)
|
||||
}
|
||||
s := &ServiceViaMemory{
|
||||
|
||||
return &ServiceViaMemory{
|
||||
wgInterface: wgIface,
|
||||
dnsMux: dns.NewServeMux(),
|
||||
|
||||
runtimeIP: lastIP,
|
||||
runtimePort: DefaultPort,
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *ServiceViaMemory) Listen() error {
|
||||
@@ -46,10 +48,8 @@ func (s *ServiceViaMemory) Listen() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
var err error
|
||||
s.udpFilterHookID, err = s.filterDNSTraffic()
|
||||
if err != nil {
|
||||
return fmt.Errorf("filter dns traffice: %w", err)
|
||||
if err := s.filterDNSTraffic(); err != nil {
|
||||
return fmt.Errorf("filter dns traffic: %w", err)
|
||||
}
|
||||
s.listenerIsRunning = true
|
||||
|
||||
@@ -57,19 +57,29 @@ func (s *ServiceViaMemory) Listen() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ServiceViaMemory) Stop() {
|
||||
func (s *ServiceViaMemory) Stop() error {
|
||||
s.listenerFlagLock.Lock()
|
||||
defer s.listenerFlagLock.Unlock()
|
||||
|
||||
if !s.listenerIsRunning {
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := s.wgInterface.GetFilter().RemovePacketHook(s.udpFilterHookID); err != nil {
|
||||
log.Errorf("unable to remove DNS packet hook: %s", err)
|
||||
filter := s.wgInterface.GetFilter()
|
||||
if filter != nil {
|
||||
filter.SetUDPPacketHook(s.runtimeIP, uint16(s.runtimePort), nil)
|
||||
if s.tcpHookSet {
|
||||
filter.SetTCPPacketHook(s.runtimeIP, uint16(s.runtimePort), nil)
|
||||
}
|
||||
}
|
||||
|
||||
if s.tcpDNS != nil {
|
||||
s.tcpDNS.Stop()
|
||||
}
|
||||
|
||||
s.listenerIsRunning = false
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ServiceViaMemory) RegisterMux(pattern string, handler dns.Handler) {
|
||||
@@ -88,10 +98,18 @@ func (s *ServiceViaMemory) RuntimeIP() netip.Addr {
|
||||
return s.runtimeIP
|
||||
}
|
||||
|
||||
func (s *ServiceViaMemory) filterDNSTraffic() (string, error) {
|
||||
func (s *ServiceViaMemory) filterDNSTraffic() error {
|
||||
filter := s.wgInterface.GetFilter()
|
||||
if filter == nil {
|
||||
return "", fmt.Errorf("can't set DNS filter, filter not initialized")
|
||||
return errors.New("DNS filter not initialized")
|
||||
}
|
||||
|
||||
// Create TCP DNS server lazily here since the device may not exist at construction time.
|
||||
if s.tcpDNS == nil {
|
||||
if dev := s.wgInterface.GetDevice(); dev != nil {
|
||||
// MTU only affects TCP segment sizing; DNS messages are small so this has no practical impact.
|
||||
s.tcpDNS = newTCPDNSServer(s.dnsMux, dev.Device, s.runtimeIP, uint16(s.runtimePort), iface.DefaultMTU)
|
||||
}
|
||||
}
|
||||
|
||||
firstLayerDecoder := layers.LayerTypeIPv4
|
||||
@@ -100,12 +118,16 @@ func (s *ServiceViaMemory) filterDNSTraffic() (string, error) {
|
||||
}
|
||||
|
||||
hook := func(packetData []byte) bool {
|
||||
// Decode the packet
|
||||
packet := gopacket.NewPacket(packetData, firstLayerDecoder, gopacket.Default)
|
||||
|
||||
// Get the UDP layer
|
||||
udpLayer := packet.Layer(layers.LayerTypeUDP)
|
||||
udp := udpLayer.(*layers.UDP)
|
||||
if udpLayer == nil {
|
||||
return true
|
||||
}
|
||||
udp, ok := udpLayer.(*layers.UDP)
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
|
||||
msg := new(dns.Msg)
|
||||
if err := msg.Unpack(udp.Payload); err != nil {
|
||||
@@ -113,13 +135,30 @@ func (s *ServiceViaMemory) filterDNSTraffic() (string, error) {
|
||||
return true
|
||||
}
|
||||
|
||||
writer := responseWriter{
|
||||
packet: packet,
|
||||
device: s.wgInterface.GetDevice().Device,
|
||||
dev := s.wgInterface.GetDevice()
|
||||
if dev == nil {
|
||||
return true
|
||||
}
|
||||
go s.dnsMux.ServeDNS(&writer, msg)
|
||||
|
||||
writer := &responseWriter{
|
||||
remote: remoteAddrFromPacket(packet),
|
||||
packet: packet,
|
||||
device: dev.Device,
|
||||
}
|
||||
go s.dnsMux.ServeDNS(writer, msg)
|
||||
return true
|
||||
}
|
||||
|
||||
return filter.AddUDPPacketHook(false, s.runtimeIP, uint16(s.runtimePort), hook), nil
|
||||
filter.SetUDPPacketHook(s.runtimeIP, uint16(s.runtimePort), hook)
|
||||
|
||||
if s.tcpDNS != nil {
|
||||
tcpHook := func(packetData []byte) bool {
|
||||
s.tcpDNS.InjectPacket(packetData)
|
||||
return true
|
||||
}
|
||||
filter.SetTCPPacketHook(s.runtimeIP, uint16(s.runtimePort), tcpHook)
|
||||
s.tcpHookSet = true
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
444
client/internal/dns/tcpstack.go
Normal file
444
client/internal/dns/tcpstack.go
Normal file
@@ -0,0 +1,444 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
"gvisor.dev/gvisor/pkg/buffer"
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
||||
"gvisor.dev/gvisor/pkg/waiter"
|
||||
)
|
||||
|
||||
const (
|
||||
dnsTCPReceiveWindow = 8192
|
||||
dnsTCPMaxInFlight = 16
|
||||
dnsTCPIdleTimeout = 30 * time.Second
|
||||
dnsTCPReadTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
// tcpDNSServer is an on-demand TCP DNS server backed by a minimal gvisor stack.
|
||||
// It is started lazily when a truncated DNS response is detected and shuts down
|
||||
// after a period of inactivity to conserve resources.
|
||||
type tcpDNSServer struct {
|
||||
mu sync.Mutex
|
||||
s *stack.Stack
|
||||
ep *dnsEndpoint
|
||||
mux *dns.ServeMux
|
||||
tunDev tun.Device
|
||||
ip netip.Addr
|
||||
port uint16
|
||||
mtu uint16
|
||||
|
||||
running bool
|
||||
closed bool
|
||||
timerID uint64
|
||||
timer *time.Timer
|
||||
}
|
||||
|
||||
func newTCPDNSServer(mux *dns.ServeMux, tunDev tun.Device, ip netip.Addr, port uint16, mtu uint16) *tcpDNSServer {
|
||||
return &tcpDNSServer{
|
||||
mux: mux,
|
||||
tunDev: tunDev,
|
||||
ip: ip,
|
||||
port: port,
|
||||
mtu: mtu,
|
||||
}
|
||||
}
|
||||
|
||||
// InjectPacket ensures the stack is running and delivers a raw IP packet into
|
||||
// the gvisor stack for TCP processing. Combining both operations under a single
|
||||
// lock prevents a race where the idle timer could stop the stack between
|
||||
// start and delivery.
|
||||
func (t *tcpDNSServer) InjectPacket(payload []byte) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
if t.closed {
|
||||
return
|
||||
}
|
||||
|
||||
if !t.running {
|
||||
if err := t.startLocked(); err != nil {
|
||||
log.Errorf("failed to start TCP DNS stack: %v", err)
|
||||
return
|
||||
}
|
||||
t.running = true
|
||||
log.Debugf("TCP DNS stack started on %s:%d (triggered by %s)", t.ip, t.port, srcAddrFromPacket(payload))
|
||||
}
|
||||
t.resetTimerLocked()
|
||||
|
||||
ep := t.ep
|
||||
if ep == nil || ep.dispatcher == nil {
|
||||
return
|
||||
}
|
||||
|
||||
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
||||
Payload: buffer.MakeWithData(payload),
|
||||
})
|
||||
// DeliverNetworkPacket takes ownership of the packet buffer; do not DecRef.
|
||||
ep.dispatcher.DeliverNetworkPacket(ipv4.ProtocolNumber, pkt)
|
||||
}
|
||||
|
||||
// Stop tears down the gvisor stack and releases resources permanently.
|
||||
// After Stop, InjectPacket becomes a no-op.
|
||||
func (t *tcpDNSServer) Stop() {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
t.stopLocked()
|
||||
t.closed = true
|
||||
}
|
||||
|
||||
func (t *tcpDNSServer) startLocked() error {
|
||||
// TODO: add ipv6.NewProtocol when IPv6 overlay support lands.
|
||||
s := stack.New(stack.Options{
|
||||
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
|
||||
TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
|
||||
HandleLocal: false,
|
||||
})
|
||||
|
||||
nicID := tcpip.NICID(1)
|
||||
ep := &dnsEndpoint{
|
||||
tunDev: t.tunDev,
|
||||
}
|
||||
ep.mtu.Store(uint32(t.mtu))
|
||||
|
||||
if err := s.CreateNIC(nicID, ep); err != nil {
|
||||
s.Close()
|
||||
s.Wait()
|
||||
return fmt.Errorf("create NIC: %v", err)
|
||||
}
|
||||
|
||||
protoAddr := tcpip.ProtocolAddress{
|
||||
Protocol: ipv4.ProtocolNumber,
|
||||
AddressWithPrefix: tcpip.AddressWithPrefix{
|
||||
Address: tcpip.AddrFromSlice(t.ip.AsSlice()),
|
||||
PrefixLen: 32,
|
||||
},
|
||||
}
|
||||
if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil {
|
||||
s.Close()
|
||||
s.Wait()
|
||||
return fmt.Errorf("add protocol address: %s", err)
|
||||
}
|
||||
|
||||
if err := s.SetPromiscuousMode(nicID, true); err != nil {
|
||||
s.Close()
|
||||
s.Wait()
|
||||
return fmt.Errorf("set promiscuous mode: %s", err)
|
||||
}
|
||||
if err := s.SetSpoofing(nicID, true); err != nil {
|
||||
s.Close()
|
||||
s.Wait()
|
||||
return fmt.Errorf("set spoofing: %s", err)
|
||||
}
|
||||
|
||||
defaultSubnet, err := tcpip.NewSubnet(
|
||||
tcpip.AddrFrom4([4]byte{0, 0, 0, 0}),
|
||||
tcpip.MaskFromBytes([]byte{0, 0, 0, 0}),
|
||||
)
|
||||
if err != nil {
|
||||
s.Close()
|
||||
s.Wait()
|
||||
return fmt.Errorf("create default subnet: %w", err)
|
||||
}
|
||||
|
||||
s.SetRouteTable([]tcpip.Route{
|
||||
{Destination: defaultSubnet, NIC: nicID},
|
||||
})
|
||||
|
||||
tcpFwd := tcp.NewForwarder(s, dnsTCPReceiveWindow, dnsTCPMaxInFlight, func(r *tcp.ForwarderRequest) {
|
||||
t.handleTCPDNS(r)
|
||||
})
|
||||
s.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpFwd.HandlePacket)
|
||||
|
||||
t.s = s
|
||||
t.ep = ep
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tcpDNSServer) stopLocked() {
|
||||
if !t.running {
|
||||
return
|
||||
}
|
||||
|
||||
if t.timer != nil {
|
||||
t.timer.Stop()
|
||||
t.timer = nil
|
||||
}
|
||||
|
||||
if t.s != nil {
|
||||
t.s.Close()
|
||||
t.s.Wait()
|
||||
t.s = nil
|
||||
}
|
||||
t.ep = nil
|
||||
t.running = false
|
||||
|
||||
log.Debugf("TCP DNS stack stopped")
|
||||
}
|
||||
|
||||
func (t *tcpDNSServer) resetTimerLocked() {
|
||||
if t.timer != nil {
|
||||
t.timer.Stop()
|
||||
}
|
||||
t.timerID++
|
||||
id := t.timerID
|
||||
t.timer = time.AfterFunc(dnsTCPIdleTimeout, func() {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
// Only stop if this timer is still the active one.
|
||||
// A racing InjectPacket may have replaced it.
|
||||
if t.timerID != id {
|
||||
return
|
||||
}
|
||||
t.stopLocked()
|
||||
})
|
||||
}
|
||||
|
||||
func (t *tcpDNSServer) handleTCPDNS(r *tcp.ForwarderRequest) {
|
||||
id := r.ID()
|
||||
|
||||
wq := waiter.Queue{}
|
||||
ep, epErr := r.CreateEndpoint(&wq)
|
||||
if epErr != nil {
|
||||
log.Debugf("TCP DNS: failed to create endpoint: %v", epErr)
|
||||
r.Complete(true)
|
||||
return
|
||||
}
|
||||
r.Complete(false)
|
||||
|
||||
conn := gonet.NewTCPConn(&wq, ep)
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Tracef("TCP DNS: close conn: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Reset idle timer on activity
|
||||
t.mu.Lock()
|
||||
t.resetTimerLocked()
|
||||
t.mu.Unlock()
|
||||
|
||||
localAddr := &net.TCPAddr{
|
||||
IP: id.LocalAddress.AsSlice(),
|
||||
Port: int(id.LocalPort),
|
||||
}
|
||||
remoteAddr := &net.TCPAddr{
|
||||
IP: id.RemoteAddress.AsSlice(),
|
||||
Port: int(id.RemotePort),
|
||||
}
|
||||
|
||||
for {
|
||||
if err := conn.SetReadDeadline(time.Now().Add(dnsTCPReadTimeout)); err != nil {
|
||||
log.Debugf("TCP DNS: set deadline for %s: %v", remoteAddr, err)
|
||||
break
|
||||
}
|
||||
|
||||
msg, err := readTCPDNSMessage(conn)
|
||||
if err != nil {
|
||||
if !errors.Is(err, io.EOF) && !errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
log.Debugf("TCP DNS: read from %s: %v", remoteAddr, err)
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
writer := &tcpResponseWriter{
|
||||
conn: conn,
|
||||
localAddr: localAddr,
|
||||
remoteAddr: remoteAddr,
|
||||
}
|
||||
t.mux.ServeDNS(writer, msg)
|
||||
}
|
||||
}
|
||||
|
||||
// dnsEndpoint implements stack.LinkEndpoint for writing packets back via the tun device.
|
||||
type dnsEndpoint struct {
|
||||
dispatcher stack.NetworkDispatcher
|
||||
tunDev tun.Device
|
||||
mtu atomic.Uint32
|
||||
}
|
||||
|
||||
func (e *dnsEndpoint) Attach(dispatcher stack.NetworkDispatcher) { e.dispatcher = dispatcher }
|
||||
func (e *dnsEndpoint) IsAttached() bool { return e.dispatcher != nil }
|
||||
func (e *dnsEndpoint) MTU() uint32 { return e.mtu.Load() }
|
||||
func (e *dnsEndpoint) Capabilities() stack.LinkEndpointCapabilities { return stack.CapabilityNone }
|
||||
func (e *dnsEndpoint) MaxHeaderLength() uint16 { return 0 }
|
||||
func (e *dnsEndpoint) LinkAddress() tcpip.LinkAddress { return "" }
|
||||
func (e *dnsEndpoint) Wait() { /* no async work */ }
|
||||
func (e *dnsEndpoint) ARPHardwareType() header.ARPHardwareType { return header.ARPHardwareNone }
|
||||
func (e *dnsEndpoint) AddHeader(*stack.PacketBuffer) { /* IP-level endpoint, no link header */ }
|
||||
func (e *dnsEndpoint) ParseHeader(*stack.PacketBuffer) bool { return true }
|
||||
func (e *dnsEndpoint) Close() { /* lifecycle managed by tcpDNSServer */ }
|
||||
func (e *dnsEndpoint) SetLinkAddress(tcpip.LinkAddress) { /* no link address for tun */ }
|
||||
func (e *dnsEndpoint) SetMTU(mtu uint32) { e.mtu.Store(mtu) }
|
||||
func (e *dnsEndpoint) SetOnCloseAction(func()) { /* not needed */ }
|
||||
|
||||
const tunPacketOffset = 40
|
||||
|
||||
func (e *dnsEndpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) {
|
||||
var written int
|
||||
for _, pkt := range pkts.AsSlice() {
|
||||
data := stack.PayloadSince(pkt.NetworkHeader())
|
||||
if data == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
raw := data.AsSlice()
|
||||
buf := make([]byte, tunPacketOffset, tunPacketOffset+len(raw))
|
||||
buf = append(buf, raw...)
|
||||
data.Release()
|
||||
|
||||
if _, err := e.tunDev.Write([][]byte{buf}, tunPacketOffset); err != nil {
|
||||
log.Tracef("TCP DNS endpoint: failed to write packet: %v", err)
|
||||
continue
|
||||
}
|
||||
written++
|
||||
}
|
||||
return written, nil
|
||||
}
|
||||
|
||||
// tcpResponseWriter implements dns.ResponseWriter for TCP DNS connections.
|
||||
type tcpResponseWriter struct {
|
||||
conn *gonet.TCPConn
|
||||
localAddr net.Addr
|
||||
remoteAddr net.Addr
|
||||
}
|
||||
|
||||
func (w *tcpResponseWriter) LocalAddr() net.Addr {
|
||||
return w.localAddr
|
||||
}
|
||||
|
||||
func (w *tcpResponseWriter) RemoteAddr() net.Addr {
|
||||
return w.remoteAddr
|
||||
}
|
||||
|
||||
func (w *tcpResponseWriter) WriteMsg(msg *dns.Msg) error {
|
||||
data, err := msg.Pack()
|
||||
if err != nil {
|
||||
return fmt.Errorf("pack: %w", err)
|
||||
}
|
||||
|
||||
// DNS TCP: 2-byte length prefix + message
|
||||
buf := make([]byte, 2+len(data))
|
||||
buf[0] = byte(len(data) >> 8)
|
||||
buf[1] = byte(len(data))
|
||||
copy(buf[2:], data)
|
||||
|
||||
if _, err = w.conn.Write(buf); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *tcpResponseWriter) Write(data []byte) (int, error) {
|
||||
buf := make([]byte, 2+len(data))
|
||||
buf[0] = byte(len(data) >> 8)
|
||||
buf[1] = byte(len(data))
|
||||
copy(buf[2:], data)
|
||||
if _, err := w.conn.Write(buf); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
func (w *tcpResponseWriter) Close() error {
|
||||
return w.conn.Close()
|
||||
}
|
||||
|
||||
func (w *tcpResponseWriter) TsigStatus() error { return nil }
|
||||
func (w *tcpResponseWriter) TsigTimersOnly(bool) { /* TSIG not supported */ }
|
||||
func (w *tcpResponseWriter) Hijack() { /* not supported */ }
|
||||
|
||||
// readTCPDNSMessage reads a single DNS message from a TCP connection (length-prefixed).
|
||||
func readTCPDNSMessage(conn *gonet.TCPConn) (*dns.Msg, error) {
|
||||
// DNS over TCP uses a 2-byte length prefix
|
||||
lenBuf := make([]byte, 2)
|
||||
if _, err := io.ReadFull(conn, lenBuf); err != nil {
|
||||
return nil, fmt.Errorf("read length: %w", err)
|
||||
}
|
||||
|
||||
msgLen := int(lenBuf[0])<<8 | int(lenBuf[1])
|
||||
if msgLen == 0 || msgLen > 65535 {
|
||||
return nil, fmt.Errorf("invalid message length: %d", msgLen)
|
||||
}
|
||||
|
||||
msgBuf := make([]byte, msgLen)
|
||||
if _, err := io.ReadFull(conn, msgBuf); err != nil {
|
||||
return nil, fmt.Errorf("read message: %w", err)
|
||||
}
|
||||
|
||||
msg := new(dns.Msg)
|
||||
if err := msg.Unpack(msgBuf); err != nil {
|
||||
return nil, fmt.Errorf("unpack: %w", err)
|
||||
}
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// srcAddrFromPacket extracts the source IP:port from a raw IP+TCP packet for logging.
|
||||
// Supports both IPv4 and IPv6.
|
||||
func srcAddrFromPacket(pkt []byte) netip.AddrPort {
|
||||
if len(pkt) == 0 {
|
||||
return netip.AddrPort{}
|
||||
}
|
||||
|
||||
srcIP, transportOffset := srcIPFromPacket(pkt)
|
||||
if !srcIP.IsValid() || len(pkt) < transportOffset+2 {
|
||||
return netip.AddrPort{}
|
||||
}
|
||||
|
||||
srcPort := uint16(pkt[transportOffset])<<8 | uint16(pkt[transportOffset+1])
|
||||
return netip.AddrPortFrom(srcIP.Unmap(), srcPort)
|
||||
}
|
||||
|
||||
func srcIPFromPacket(pkt []byte) (netip.Addr, int) {
|
||||
switch header.IPVersion(pkt) {
|
||||
case 4:
|
||||
return srcIPv4(pkt)
|
||||
case 6:
|
||||
return srcIPv6(pkt)
|
||||
default:
|
||||
return netip.Addr{}, 0
|
||||
}
|
||||
}
|
||||
|
||||
func srcIPv4(pkt []byte) (netip.Addr, int) {
|
||||
if len(pkt) < header.IPv4MinimumSize {
|
||||
return netip.Addr{}, 0
|
||||
}
|
||||
hdr := header.IPv4(pkt)
|
||||
src := hdr.SourceAddress()
|
||||
ip, ok := netip.AddrFromSlice(src.AsSlice())
|
||||
if !ok {
|
||||
return netip.Addr{}, 0
|
||||
}
|
||||
return ip, int(hdr.HeaderLength())
|
||||
}
|
||||
|
||||
func srcIPv6(pkt []byte) (netip.Addr, int) {
|
||||
if len(pkt) < header.IPv6MinimumSize {
|
||||
return netip.Addr{}, 0
|
||||
}
|
||||
hdr := header.IPv6(pkt)
|
||||
src := hdr.SourceAddress()
|
||||
ip, ok := netip.AddrFromSlice(src.AsSlice())
|
||||
if !ok {
|
||||
return netip.Addr{}, 0
|
||||
}
|
||||
return ip, header.IPv6MinimumSize
|
||||
}
|
||||
@@ -41,10 +41,61 @@ const (
|
||||
|
||||
reactivatePeriod = 30 * time.Second
|
||||
probeTimeout = 2 * time.Second
|
||||
|
||||
// ipv6HeaderSize + udpHeaderSize, used to derive the maximum DNS UDP
|
||||
// payload from the tunnel MTU.
|
||||
ipUDPHeaderSize = 60 + 8
|
||||
)
|
||||
|
||||
const testRecord = "com."
|
||||
|
||||
const (
|
||||
protoUDP = "udp"
|
||||
protoTCP = "tcp"
|
||||
)
|
||||
|
||||
type dnsProtocolKey struct{}
|
||||
|
||||
// contextWithDNSProtocol stores the inbound DNS protocol ("udp" or "tcp") in context.
|
||||
func contextWithDNSProtocol(ctx context.Context, network string) context.Context {
|
||||
return context.WithValue(ctx, dnsProtocolKey{}, network)
|
||||
}
|
||||
|
||||
// dnsProtocolFromContext retrieves the inbound DNS protocol from context.
|
||||
func dnsProtocolFromContext(ctx context.Context) string {
|
||||
if ctx == nil {
|
||||
return ""
|
||||
}
|
||||
if v, ok := ctx.Value(dnsProtocolKey{}).(string); ok {
|
||||
return v
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
type upstreamProtocolKey struct{}
|
||||
|
||||
// upstreamProtocolResult holds the protocol used for the upstream exchange.
|
||||
// Stored as a pointer in context so the exchange function can set it.
|
||||
type upstreamProtocolResult struct {
|
||||
protocol string
|
||||
}
|
||||
|
||||
// contextWithupstreamProtocolResult stores a mutable result holder in the context.
|
||||
func contextWithupstreamProtocolResult(ctx context.Context) (context.Context, *upstreamProtocolResult) {
|
||||
r := &upstreamProtocolResult{}
|
||||
return context.WithValue(ctx, upstreamProtocolKey{}, r), r
|
||||
}
|
||||
|
||||
// setUpstreamProtocol sets the upstream protocol on the result holder in context, if present.
|
||||
func setUpstreamProtocol(ctx context.Context, protocol string) {
|
||||
if ctx == nil {
|
||||
return
|
||||
}
|
||||
if r, ok := ctx.Value(upstreamProtocolKey{}).(*upstreamProtocolResult); ok && r != nil {
|
||||
r.protocol = protocol
|
||||
}
|
||||
}
|
||||
|
||||
type upstreamClient interface {
|
||||
exchange(ctx context.Context, upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error)
|
||||
}
|
||||
@@ -138,7 +189,16 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
return
|
||||
}
|
||||
|
||||
ok, failures := u.tryUpstreamServers(w, r, logger)
|
||||
// Propagate inbound protocol so upstream exchange can use TCP directly
|
||||
// when the request came in over TCP.
|
||||
ctx := u.ctx
|
||||
if addr := w.RemoteAddr(); addr != nil {
|
||||
network := addr.Network()
|
||||
ctx = contextWithDNSProtocol(ctx, network)
|
||||
resutil.SetMeta(w, "protocol", network)
|
||||
}
|
||||
|
||||
ok, failures := u.tryUpstreamServers(ctx, w, r, logger)
|
||||
if len(failures) > 0 {
|
||||
u.logUpstreamFailures(r.Question[0].Name, failures, ok, logger)
|
||||
}
|
||||
@@ -153,7 +213,7 @@ func (u *upstreamResolverBase) prepareRequest(r *dns.Msg) {
|
||||
}
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) (bool, []upstreamFailure) {
|
||||
func (u *upstreamResolverBase) tryUpstreamServers(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) (bool, []upstreamFailure) {
|
||||
timeout := u.upstreamTimeout
|
||||
if len(u.upstreamServers) > 1 {
|
||||
maxTotal := 5 * time.Second
|
||||
@@ -168,7 +228,7 @@ func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.M
|
||||
|
||||
var failures []upstreamFailure
|
||||
for _, upstream := range u.upstreamServers {
|
||||
if failure := u.queryUpstream(w, r, upstream, timeout, logger); failure != nil {
|
||||
if failure := u.queryUpstream(ctx, w, r, upstream, timeout, logger); failure != nil {
|
||||
failures = append(failures, *failure)
|
||||
} else {
|
||||
return true, failures
|
||||
@@ -178,15 +238,17 @@ func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.M
|
||||
}
|
||||
|
||||
// queryUpstream queries a single upstream server. Returns nil on success, or failure info to try next upstream.
|
||||
func (u *upstreamResolverBase) queryUpstream(w dns.ResponseWriter, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration, logger *log.Entry) *upstreamFailure {
|
||||
func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, w dns.ResponseWriter, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration, logger *log.Entry) *upstreamFailure {
|
||||
var rm *dns.Msg
|
||||
var t time.Duration
|
||||
var err error
|
||||
|
||||
var startTime time.Time
|
||||
var upstreamProto *upstreamProtocolResult
|
||||
func() {
|
||||
ctx, cancel := context.WithTimeout(u.ctx, timeout)
|
||||
ctx, cancel := context.WithTimeout(parentCtx, timeout)
|
||||
defer cancel()
|
||||
ctx, upstreamProto = contextWithupstreamProtocolResult(ctx)
|
||||
startTime = time.Now()
|
||||
rm, t, err = u.upstreamClient.exchange(ctx, upstream.String(), r)
|
||||
}()
|
||||
@@ -203,7 +265,7 @@ func (u *upstreamResolverBase) queryUpstream(w dns.ResponseWriter, r *dns.Msg, u
|
||||
return &upstreamFailure{upstream: upstream, reason: dns.RcodeToString[rm.Rcode]}
|
||||
}
|
||||
|
||||
u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, logger)
|
||||
u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, upstreamProto, logger)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -220,10 +282,13 @@ func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.Add
|
||||
return &upstreamFailure{upstream: upstream, reason: reason}
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, logger *log.Entry) bool {
|
||||
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, upstreamProto *upstreamProtocolResult, logger *log.Entry) bool {
|
||||
u.successCount.Add(1)
|
||||
|
||||
resutil.SetMeta(w, "upstream", upstream.String())
|
||||
if upstreamProto != nil && upstreamProto.protocol != "" {
|
||||
resutil.SetMeta(w, "upstream_protocol", upstreamProto.protocol)
|
||||
}
|
||||
|
||||
// Clear Zero bit from external responses to prevent upstream servers from
|
||||
// manipulating our internal fallthrough signaling mechanism
|
||||
@@ -428,13 +493,42 @@ func (u *upstreamResolverBase) testNameserver(baseCtx context.Context, externalC
|
||||
return err
|
||||
}
|
||||
|
||||
// clientUDPMaxSize returns the maximum UDP response size the client accepts.
|
||||
func clientUDPMaxSize(r *dns.Msg) int {
|
||||
if opt := r.IsEdns0(); opt != nil {
|
||||
return int(opt.UDPSize())
|
||||
}
|
||||
return dns.MinMsgSize
|
||||
}
|
||||
|
||||
// ExchangeWithFallback exchanges a DNS message with the upstream server.
|
||||
// It first tries to use UDP, and if it is truncated, it falls back to TCP.
|
||||
// If the inbound request came over TCP (via context), it skips the UDP attempt.
|
||||
// If the passed context is nil, this will use Exchange instead of ExchangeContext.
|
||||
func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, upstream string) (*dns.Msg, time.Duration, error) {
|
||||
// MTU - ip + udp headers
|
||||
// Note: this could be sent out on an interface that is not ours, but higher MTU settings could break truncation handling.
|
||||
client.UDPSize = uint16(currentMTU - (60 + 8))
|
||||
// If the request came in over TCP, go straight to TCP upstream.
|
||||
if dnsProtocolFromContext(ctx) == protoTCP {
|
||||
tcpClient := *client
|
||||
tcpClient.Net = protoTCP
|
||||
rm, t, err := tcpClient.ExchangeContext(ctx, r, upstream)
|
||||
if err != nil {
|
||||
return nil, t, fmt.Errorf("with tcp: %w", err)
|
||||
}
|
||||
setUpstreamProtocol(ctx, protoTCP)
|
||||
return rm, t, nil
|
||||
}
|
||||
|
||||
clientMaxSize := clientUDPMaxSize(r)
|
||||
|
||||
// Cap EDNS0 to our tunnel MTU so the upstream doesn't send a
|
||||
// response larger than our read buffer.
|
||||
// Note: the query could be sent out on an interface that is not ours,
|
||||
// but higher MTU settings could break truncation handling.
|
||||
maxUDPPayload := uint16(currentMTU - ipUDPHeaderSize)
|
||||
client.UDPSize = maxUDPPayload
|
||||
if opt := r.IsEdns0(); opt != nil && opt.UDPSize() > maxUDPPayload {
|
||||
opt.SetUDPSize(maxUDPPayload)
|
||||
}
|
||||
|
||||
var (
|
||||
rm *dns.Msg
|
||||
@@ -453,25 +547,32 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
|
||||
}
|
||||
|
||||
if rm == nil || !rm.MsgHdr.Truncated {
|
||||
setUpstreamProtocol(ctx, protoUDP)
|
||||
return rm, t, nil
|
||||
}
|
||||
|
||||
log.Tracef("udp response for domain=%s type=%v class=%v is truncated, trying TCP.",
|
||||
r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
||||
// TODO: if the upstream's truncated UDP response already contains more
|
||||
// data than the client's buffer, we could truncate locally and skip
|
||||
// the TCP retry.
|
||||
|
||||
client.Net = "tcp"
|
||||
tcpClient := *client
|
||||
tcpClient.Net = protoTCP
|
||||
|
||||
if ctx == nil {
|
||||
rm, t, err = client.Exchange(r, upstream)
|
||||
rm, t, err = tcpClient.Exchange(r, upstream)
|
||||
} else {
|
||||
rm, t, err = client.ExchangeContext(ctx, r, upstream)
|
||||
rm, t, err = tcpClient.ExchangeContext(ctx, r, upstream)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, t, fmt.Errorf("with tcp: %w", err)
|
||||
}
|
||||
|
||||
// TODO: once TCP is implemented, rm.Truncate() if the request came in over UDP
|
||||
setUpstreamProtocol(ctx, protoTCP)
|
||||
|
||||
if rm.Len() > clientMaxSize {
|
||||
rm.Truncate(clientMaxSize)
|
||||
}
|
||||
|
||||
return rm, t, nil
|
||||
}
|
||||
@@ -479,18 +580,46 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
|
||||
// ExchangeWithNetstack performs a DNS exchange using netstack for dialing.
|
||||
// This is needed when netstack is enabled to reach peer IPs through the tunnel.
|
||||
func ExchangeWithNetstack(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upstream string) (*dns.Msg, error) {
|
||||
reply, err := netstackExchange(ctx, nsNet, r, upstream, "udp")
|
||||
// If request came in over TCP, go straight to TCP upstream
|
||||
if dnsProtocolFromContext(ctx) == protoTCP {
|
||||
rm, err := netstackExchange(ctx, nsNet, r, upstream, protoTCP)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
setUpstreamProtocol(ctx, protoTCP)
|
||||
return rm, nil
|
||||
}
|
||||
|
||||
clientMaxSize := clientUDPMaxSize(r)
|
||||
|
||||
// Cap EDNS0 to our tunnel MTU so the upstream doesn't send a
|
||||
// response larger than what we can read over UDP.
|
||||
maxUDPPayload := uint16(currentMTU - ipUDPHeaderSize)
|
||||
if opt := r.IsEdns0(); opt != nil && opt.UDPSize() > maxUDPPayload {
|
||||
opt.SetUDPSize(maxUDPPayload)
|
||||
}
|
||||
|
||||
reply, err := netstackExchange(ctx, nsNet, r, upstream, protoUDP)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If response is truncated, retry with TCP
|
||||
if reply != nil && reply.MsgHdr.Truncated {
|
||||
log.Tracef("udp response for domain=%s type=%v class=%v is truncated, trying TCP",
|
||||
r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
||||
return netstackExchange(ctx, nsNet, r, upstream, "tcp")
|
||||
rm, err := netstackExchange(ctx, nsNet, r, upstream, protoTCP)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
setUpstreamProtocol(ctx, protoTCP)
|
||||
if rm.Len() > clientMaxSize {
|
||||
rm.Truncate(clientMaxSize)
|
||||
}
|
||||
|
||||
return rm, nil
|
||||
}
|
||||
|
||||
setUpstreamProtocol(ctx, protoUDP)
|
||||
|
||||
return reply, nil
|
||||
}
|
||||
|
||||
@@ -511,7 +640,7 @@ func netstackExchange(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upst
|
||||
}
|
||||
}
|
||||
|
||||
dnsConn := &dns.Conn{Conn: conn}
|
||||
dnsConn := &dns.Conn{Conn: conn, UDPSize: uint16(currentMTU - ipUDPHeaderSize)}
|
||||
|
||||
if err := dnsConn.WriteMsg(r); err != nil {
|
||||
return nil, fmt.Errorf("write %s message: %w", network, err)
|
||||
|
||||
@@ -51,7 +51,7 @@ func (u *upstreamResolver) exchangeWithinVPN(ctx context.Context, upstream strin
|
||||
upstreamExchangeClient := &dns.Client{
|
||||
Timeout: ClientTimeout,
|
||||
}
|
||||
return upstreamExchangeClient.ExchangeContext(ctx, r, upstream)
|
||||
return ExchangeWithFallback(ctx, upstreamExchangeClient, r, upstream)
|
||||
}
|
||||
|
||||
// exchangeWithoutVPN protect the UDP socket by Android SDK to avoid to goes through the VPN
|
||||
@@ -76,7 +76,7 @@ func (u *upstreamResolver) exchangeWithoutVPN(ctx context.Context, upstream stri
|
||||
Timeout: timeout,
|
||||
}
|
||||
|
||||
return upstreamExchangeClient.ExchangeContext(ctx, r, upstream)
|
||||
return ExchangeWithFallback(ctx, upstreamExchangeClient, r, upstream)
|
||||
}
|
||||
|
||||
func (u *upstreamResolver) isLocalResolver(upstream string) bool {
|
||||
|
||||
@@ -475,3 +475,298 @@ func TestFormatFailures(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSProtocolContext(t *testing.T) {
|
||||
t.Run("roundtrip udp", func(t *testing.T) {
|
||||
ctx := contextWithDNSProtocol(context.Background(), protoUDP)
|
||||
assert.Equal(t, protoUDP, dnsProtocolFromContext(ctx))
|
||||
})
|
||||
|
||||
t.Run("roundtrip tcp", func(t *testing.T) {
|
||||
ctx := contextWithDNSProtocol(context.Background(), protoTCP)
|
||||
assert.Equal(t, protoTCP, dnsProtocolFromContext(ctx))
|
||||
})
|
||||
|
||||
t.Run("missing returns empty", func(t *testing.T) {
|
||||
assert.Equal(t, "", dnsProtocolFromContext(context.Background()))
|
||||
})
|
||||
}
|
||||
|
||||
func TestExchangeWithFallback_TCPContext(t *testing.T) {
|
||||
// Start a local DNS server that responds on TCP only
|
||||
tcpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
m.Answer = append(m.Answer, &dns.A{
|
||||
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP("10.0.0.1"),
|
||||
})
|
||||
if err := w.WriteMsg(m); err != nil {
|
||||
t.Logf("write msg: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
tcpServer := &dns.Server{
|
||||
Addr: "127.0.0.1:0",
|
||||
Net: "tcp",
|
||||
Handler: tcpHandler,
|
||||
}
|
||||
|
||||
tcpLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
tcpServer.Listener = tcpLn
|
||||
|
||||
go func() {
|
||||
if err := tcpServer.ActivateAndServe(); err != nil {
|
||||
t.Logf("tcp server: %v", err)
|
||||
}
|
||||
}()
|
||||
defer func() {
|
||||
_ = tcpServer.Shutdown()
|
||||
}()
|
||||
|
||||
upstream := tcpLn.Addr().String()
|
||||
|
||||
// With TCP context, should connect directly via TCP without trying UDP
|
||||
ctx := contextWithDNSProtocol(context.Background(), protoTCP)
|
||||
client := &dns.Client{Timeout: 2 * time.Second}
|
||||
r := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
|
||||
|
||||
rm, _, err := ExchangeWithFallback(ctx, client, r, upstream)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rm)
|
||||
require.NotEmpty(t, rm.Answer)
|
||||
assert.Contains(t, rm.Answer[0].String(), "10.0.0.1")
|
||||
}
|
||||
|
||||
func TestExchangeWithFallback_UDPFallbackToTCP(t *testing.T) {
|
||||
// UDP handler returns a truncated response to trigger TCP retry.
|
||||
udpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
m.Truncated = true
|
||||
if err := w.WriteMsg(m); err != nil {
|
||||
t.Logf("write msg: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
// TCP handler returns the full answer.
|
||||
tcpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
m.Answer = append(m.Answer, &dns.A{
|
||||
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP("10.0.0.3"),
|
||||
})
|
||||
if err := w.WriteMsg(m); err != nil {
|
||||
t.Logf("write msg: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
udpPC, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
addr := udpPC.LocalAddr().String()
|
||||
|
||||
udpServer := &dns.Server{
|
||||
PacketConn: udpPC,
|
||||
Net: "udp",
|
||||
Handler: udpHandler,
|
||||
}
|
||||
|
||||
tcpLn, err := net.Listen("tcp", addr)
|
||||
require.NoError(t, err)
|
||||
|
||||
tcpServer := &dns.Server{
|
||||
Listener: tcpLn,
|
||||
Net: "tcp",
|
||||
Handler: tcpHandler,
|
||||
}
|
||||
|
||||
go func() {
|
||||
if err := udpServer.ActivateAndServe(); err != nil {
|
||||
t.Logf("udp server: %v", err)
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
if err := tcpServer.ActivateAndServe(); err != nil {
|
||||
t.Logf("tcp server: %v", err)
|
||||
}
|
||||
}()
|
||||
defer func() {
|
||||
_ = udpServer.Shutdown()
|
||||
_ = tcpServer.Shutdown()
|
||||
}()
|
||||
|
||||
ctx := context.Background()
|
||||
client := &dns.Client{Timeout: 2 * time.Second}
|
||||
r := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
|
||||
|
||||
rm, _, err := ExchangeWithFallback(ctx, client, r, addr)
|
||||
require.NoError(t, err, "should fall back to TCP after truncated UDP response")
|
||||
require.NotNil(t, rm)
|
||||
require.NotEmpty(t, rm.Answer, "TCP response should contain the full answer")
|
||||
assert.Contains(t, rm.Answer[0].String(), "10.0.0.3")
|
||||
assert.False(t, rm.Truncated, "TCP response should not be truncated")
|
||||
}
|
||||
|
||||
func TestExchangeWithFallback_TCPContextSkipsUDP(t *testing.T) {
|
||||
// Start only a TCP server (no UDP). With TCP context it should succeed.
|
||||
tcpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
m.Answer = append(m.Answer, &dns.A{
|
||||
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP("10.0.0.2"),
|
||||
})
|
||||
if err := w.WriteMsg(m); err != nil {
|
||||
t.Logf("write msg: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
tcpLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
tcpServer := &dns.Server{
|
||||
Listener: tcpLn,
|
||||
Net: "tcp",
|
||||
Handler: tcpHandler,
|
||||
}
|
||||
|
||||
go func() {
|
||||
if err := tcpServer.ActivateAndServe(); err != nil {
|
||||
t.Logf("tcp server: %v", err)
|
||||
}
|
||||
}()
|
||||
defer func() {
|
||||
_ = tcpServer.Shutdown()
|
||||
}()
|
||||
|
||||
upstream := tcpLn.Addr().String()
|
||||
|
||||
// TCP context: should skip UDP entirely and go directly to TCP
|
||||
ctx := contextWithDNSProtocol(context.Background(), protoTCP)
|
||||
client := &dns.Client{Timeout: 2 * time.Second}
|
||||
r := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
|
||||
|
||||
rm, _, err := ExchangeWithFallback(ctx, client, r, upstream)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rm)
|
||||
require.NotEmpty(t, rm.Answer)
|
||||
assert.Contains(t, rm.Answer[0].String(), "10.0.0.2")
|
||||
|
||||
// Without TCP context, trying to reach a TCP-only server via UDP should fail
|
||||
ctx2 := context.Background()
|
||||
client2 := &dns.Client{Timeout: 500 * time.Millisecond}
|
||||
_, _, err = ExchangeWithFallback(ctx2, client2, r, upstream)
|
||||
assert.Error(t, err, "should fail when no UDP server and no TCP context")
|
||||
}
|
||||
|
||||
func TestExchangeWithFallback_EDNS0Capped(t *testing.T) {
|
||||
// Verify that a client EDNS0 larger than our MTU-derived limit gets
|
||||
// capped in the outgoing request so the upstream doesn't send a
|
||||
// response larger than our read buffer.
|
||||
var receivedUDPSize uint16
|
||||
udpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
if opt := r.IsEdns0(); opt != nil {
|
||||
receivedUDPSize = opt.UDPSize()
|
||||
}
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
m.Answer = append(m.Answer, &dns.A{
|
||||
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP("10.0.0.1"),
|
||||
})
|
||||
if err := w.WriteMsg(m); err != nil {
|
||||
t.Logf("write msg: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
udpPC, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
addr := udpPC.LocalAddr().String()
|
||||
|
||||
udpServer := &dns.Server{PacketConn: udpPC, Net: "udp", Handler: udpHandler}
|
||||
go func() { _ = udpServer.ActivateAndServe() }()
|
||||
t.Cleanup(func() { _ = udpServer.Shutdown() })
|
||||
|
||||
ctx := context.Background()
|
||||
client := &dns.Client{Timeout: 2 * time.Second}
|
||||
r := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
|
||||
r.SetEdns0(4096, false)
|
||||
|
||||
rm, _, err := ExchangeWithFallback(ctx, client, r, addr)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rm)
|
||||
|
||||
expectedMax := uint16(currentMTU - ipUDPHeaderSize)
|
||||
assert.Equal(t, expectedMax, receivedUDPSize,
|
||||
"upstream should see capped EDNS0, not the client's 4096")
|
||||
}
|
||||
|
||||
func TestExchangeWithFallback_TCPTruncatesToClientSize(t *testing.T) {
|
||||
// When the client advertises a large EDNS0 (4096) and the upstream
|
||||
// truncates, the TCP response should NOT be truncated since the full
|
||||
// answer fits within the client's original buffer.
|
||||
udpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
m.Truncated = true
|
||||
if err := w.WriteMsg(m); err != nil {
|
||||
t.Logf("write msg: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
tcpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
// Add enough records to exceed MTU but fit within 4096
|
||||
for i := range 20 {
|
||||
m.Answer = append(m.Answer, &dns.TXT{
|
||||
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 60},
|
||||
Txt: []string{fmt.Sprintf("record-%d-padding-data-to-make-it-longer", i)},
|
||||
})
|
||||
}
|
||||
if err := w.WriteMsg(m); err != nil {
|
||||
t.Logf("write msg: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
udpPC, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
addr := udpPC.LocalAddr().String()
|
||||
|
||||
udpServer := &dns.Server{PacketConn: udpPC, Net: "udp", Handler: udpHandler}
|
||||
tcpLn, err := net.Listen("tcp", addr)
|
||||
require.NoError(t, err)
|
||||
tcpServer := &dns.Server{Listener: tcpLn, Net: "tcp", Handler: tcpHandler}
|
||||
|
||||
go func() { _ = udpServer.ActivateAndServe() }()
|
||||
go func() { _ = tcpServer.ActivateAndServe() }()
|
||||
t.Cleanup(func() {
|
||||
_ = udpServer.Shutdown()
|
||||
_ = tcpServer.Shutdown()
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
client := &dns.Client{Timeout: 2 * time.Second}
|
||||
|
||||
// Client with large buffer: should get all records without truncation
|
||||
r := new(dns.Msg).SetQuestion("example.com.", dns.TypeTXT)
|
||||
r.SetEdns0(4096, false)
|
||||
|
||||
rm, _, err := ExchangeWithFallback(ctx, client, r, addr)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rm)
|
||||
assert.Len(t, rm.Answer, 20, "large EDNS0 client should get all records")
|
||||
assert.False(t, rm.Truncated, "response should not be truncated for large buffer client")
|
||||
|
||||
// Client with small buffer: should get truncated response
|
||||
r2 := new(dns.Msg).SetQuestion("example.com.", dns.TypeTXT)
|
||||
r2.SetEdns0(512, false)
|
||||
|
||||
rm2, _, err := ExchangeWithFallback(ctx, &dns.Client{Timeout: 2 * time.Second}, r2, addr)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rm2)
|
||||
assert.Less(t, len(rm2.Answer), 20, "small EDNS0 client should get fewer records")
|
||||
assert.True(t, rm2.Truncated, "response should be truncated for small buffer client")
|
||||
}
|
||||
|
||||
@@ -237,8 +237,8 @@ func (f *DNSForwarder) writeResponse(logger *log.Entry, w dns.ResponseWriter, re
|
||||
return
|
||||
}
|
||||
|
||||
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
|
||||
qname, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
|
||||
logger.Tracef("response: domain=%s rcode=%s answers=%s size=%dB took=%s",
|
||||
qname, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), resp.Len(), time.Since(startTime))
|
||||
}
|
||||
|
||||
// udpResponseWriter wraps a dns.ResponseWriter to handle UDP-specific truncation.
|
||||
@@ -263,20 +263,28 @@ func (u *udpResponseWriter) WriteMsg(resp *dns.Msg) error {
|
||||
|
||||
func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
|
||||
startTime := time.Now()
|
||||
logger := log.WithFields(log.Fields{
|
||||
fields := log.Fields{
|
||||
"request_id": resutil.GenerateRequestID(),
|
||||
"dns_id": fmt.Sprintf("%04x", query.Id),
|
||||
})
|
||||
}
|
||||
if addr := w.RemoteAddr(); addr != nil {
|
||||
fields["client"] = addr.String()
|
||||
}
|
||||
logger := log.WithFields(fields)
|
||||
|
||||
f.handleDNSQuery(logger, &udpResponseWriter{ResponseWriter: w, query: query}, query, startTime)
|
||||
}
|
||||
|
||||
func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
|
||||
startTime := time.Now()
|
||||
logger := log.WithFields(log.Fields{
|
||||
fields := log.Fields{
|
||||
"request_id": resutil.GenerateRequestID(),
|
||||
"dns_id": fmt.Sprintf("%04x", query.Id),
|
||||
})
|
||||
}
|
||||
if addr := w.RemoteAddr(); addr != nil {
|
||||
fields["client"] = addr.String()
|
||||
}
|
||||
logger := log.WithFields(fields)
|
||||
|
||||
f.handleDNSQuery(logger, w, query, startTime)
|
||||
}
|
||||
|
||||
@@ -31,6 +31,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||
"github.com/netbirdio/netbird/client/inspect"
|
||||
"github.com/netbirdio/netbird/client/internal/acl"
|
||||
"github.com/netbirdio/netbird/client/internal/debug"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
@@ -46,6 +47,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
||||
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||
"github.com/netbirdio/netbird/client/internal/portforward"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/relay"
|
||||
"github.com/netbirdio/netbird/client/internal/rosenpass"
|
||||
@@ -135,6 +137,12 @@ type EngineConfig struct {
|
||||
|
||||
MTU uint16
|
||||
|
||||
// InspectionCACertPath is a local CA cert for transparent proxy MITM.
|
||||
// Takes priority over management-pushed CA.
|
||||
InspectionCACertPath string
|
||||
// InspectionCAKeyPath is the corresponding private key.
|
||||
InspectionCAKeyPath string
|
||||
|
||||
// for debug bundle generation
|
||||
ProfileConfig *profilemanager.Config
|
||||
|
||||
@@ -210,9 +218,10 @@ type Engine struct {
|
||||
// checks are the client-applied posture checks that need to be evaluated on the client
|
||||
checks []*mgmProto.Checks
|
||||
|
||||
relayManager *relayClient.Manager
|
||||
stateManager *statemanager.Manager
|
||||
srWatcher *guard.SRWatcher
|
||||
relayManager *relayClient.Manager
|
||||
stateManager *statemanager.Manager
|
||||
portForwardManager *portforward.Manager
|
||||
srWatcher *guard.SRWatcher
|
||||
|
||||
// Sync response persistence (protected by syncRespMux)
|
||||
syncRespMux sync.RWMutex
|
||||
@@ -220,6 +229,10 @@ type Engine struct {
|
||||
latestSyncResponse *mgmProto.SyncResponse
|
||||
flowManager nftypes.FlowManager
|
||||
|
||||
// transparentProxy is the transparent forward proxy for traffic inspection.
|
||||
transparentProxy *inspect.Proxy
|
||||
udpInspectionHookID string
|
||||
|
||||
// auto-update
|
||||
updateManager *updater.Manager
|
||||
|
||||
@@ -259,26 +272,27 @@ func NewEngine(
|
||||
mobileDep MobileDependency,
|
||||
) *Engine {
|
||||
engine := &Engine{
|
||||
clientCtx: clientCtx,
|
||||
clientCancel: clientCancel,
|
||||
signal: services.SignalClient,
|
||||
signaler: peer.NewSignaler(services.SignalClient, config.WgPrivateKey),
|
||||
mgmClient: services.MgmClient,
|
||||
relayManager: services.RelayManager,
|
||||
peerStore: peerstore.NewConnStore(),
|
||||
syncMsgMux: &sync.Mutex{},
|
||||
config: config,
|
||||
mobileDep: mobileDep,
|
||||
STUNs: []*stun.URI{},
|
||||
TURNs: []*stun.URI{},
|
||||
networkSerial: 0,
|
||||
statusRecorder: services.StatusRecorder,
|
||||
stateManager: services.StateManager,
|
||||
checks: services.Checks,
|
||||
probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL),
|
||||
jobExecutor: jobexec.NewExecutor(),
|
||||
clientMetrics: services.ClientMetrics,
|
||||
updateManager: services.UpdateManager,
|
||||
clientCtx: clientCtx,
|
||||
clientCancel: clientCancel,
|
||||
signal: services.SignalClient,
|
||||
signaler: peer.NewSignaler(services.SignalClient, config.WgPrivateKey),
|
||||
mgmClient: services.MgmClient,
|
||||
relayManager: services.RelayManager,
|
||||
peerStore: peerstore.NewConnStore(),
|
||||
syncMsgMux: &sync.Mutex{},
|
||||
config: config,
|
||||
mobileDep: mobileDep,
|
||||
STUNs: []*stun.URI{},
|
||||
TURNs: []*stun.URI{},
|
||||
networkSerial: 0,
|
||||
statusRecorder: services.StatusRecorder,
|
||||
stateManager: services.StateManager,
|
||||
portForwardManager: portforward.NewManager(),
|
||||
checks: services.Checks,
|
||||
probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL),
|
||||
jobExecutor: jobexec.NewExecutor(),
|
||||
clientMetrics: services.ClientMetrics,
|
||||
updateManager: services.UpdateManager,
|
||||
}
|
||||
|
||||
log.Infof("I am: %s", config.WgPrivateKey.PublicKey().String())
|
||||
@@ -500,7 +514,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener)
|
||||
|
||||
e.dnsServer.SetRouteChecker(func(ip netip.Addr) bool {
|
||||
for _, routes := range e.routeManager.GetClientRoutes() {
|
||||
for _, routes := range e.routeManager.GetSelectedClientRoutes() {
|
||||
for _, r := range routes {
|
||||
if r.Network.Contains(ip) {
|
||||
return true
|
||||
@@ -521,6 +535,11 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
return err
|
||||
}
|
||||
|
||||
// Inject firewall into DNS server now that it's available.
|
||||
// The DNS server is created before the firewall because the route manager
|
||||
// depends on the DNS server, and the firewall depends on the wg interface.
|
||||
e.dnsServer.SetFirewall(e.firewall)
|
||||
|
||||
e.udpMux, err = e.wgInterface.Up()
|
||||
if err != nil {
|
||||
log.Errorf("failed to pull up wgInterface [%s]: %s", e.wgInterface.Name(), err.Error())
|
||||
@@ -532,6 +551,13 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
// conntrack entries from being created before the rules are in place
|
||||
e.setupWGProxyNoTrack()
|
||||
|
||||
// Start after interface is up since port may have been resolved from 0 or changed if occupied
|
||||
e.shutdownWg.Add(1)
|
||||
go func() {
|
||||
defer e.shutdownWg.Done()
|
||||
e.portForwardManager.Start(e.ctx, uint16(e.config.WgPort))
|
||||
}()
|
||||
|
||||
// Set the WireGuard interface for rosenpass after interface is up
|
||||
if e.rpManager != nil {
|
||||
e.rpManager.SetInterface(e.wgInterface)
|
||||
@@ -1257,6 +1283,9 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
fwdEntries := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), routes)
|
||||
e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries)
|
||||
|
||||
// Transparent proxy
|
||||
e.updateTransparentProxy(networkMap.GetTransparentProxyConfig())
|
||||
|
||||
// Ingress forward rules
|
||||
forwardingRules, err := e.updateForwardRules(networkMap.GetForwardingRules())
|
||||
if err != nil {
|
||||
@@ -1535,12 +1564,13 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
|
||||
}
|
||||
|
||||
serviceDependencies := peer.ServiceDependencies{
|
||||
StatusRecorder: e.statusRecorder,
|
||||
Signaler: e.signaler,
|
||||
IFaceDiscover: e.mobileDep.IFaceDiscover,
|
||||
RelayManager: e.relayManager,
|
||||
SrWatcher: e.srWatcher,
|
||||
MetricsRecorder: e.clientMetrics,
|
||||
StatusRecorder: e.statusRecorder,
|
||||
Signaler: e.signaler,
|
||||
IFaceDiscover: e.mobileDep.IFaceDiscover,
|
||||
RelayManager: e.relayManager,
|
||||
SrWatcher: e.srWatcher,
|
||||
PortForwardManager: e.portForwardManager,
|
||||
MetricsRecorder: e.clientMetrics,
|
||||
}
|
||||
peerConn, err := peer.NewConn(config, serviceDependencies)
|
||||
if err != nil {
|
||||
@@ -1679,6 +1709,8 @@ func (e *Engine) parseNATExternalIPMappings() []string {
|
||||
func (e *Engine) close() {
|
||||
log.Debugf("removing Netbird interface %s", e.config.WgIfaceName)
|
||||
|
||||
e.stopTransparentProxy()
|
||||
|
||||
if e.wgInterface != nil {
|
||||
if err := e.wgInterface.Close(); err != nil {
|
||||
log.Errorf("failed closing Netbird interface %s %v", e.config.WgIfaceName, err)
|
||||
@@ -1697,6 +1729,12 @@ func (e *Engine) close() {
|
||||
if e.rpManager != nil {
|
||||
_ = e.rpManager.Close()
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := e.portForwardManager.GracefullyStop(ctx); err != nil {
|
||||
log.Warnf("failed to gracefully stop port forwarding manager: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, error) {
|
||||
@@ -1800,7 +1838,7 @@ func (e *Engine) newDnsServer(dnsConfig *nbdns.Config) (dns.Server, error) {
|
||||
return dnsServer, nil
|
||||
|
||||
case "ios":
|
||||
dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.statusRecorder, e.config.DisableDNS)
|
||||
dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.mobileDep.HostDNSAddresses, e.statusRecorder, e.config.DisableDNS)
|
||||
return dnsServer, nil
|
||||
|
||||
default:
|
||||
@@ -1837,6 +1875,11 @@ func (e *Engine) GetExposeManager() *expose.Manager {
|
||||
return e.exposeManager
|
||||
}
|
||||
|
||||
// IsBlockInbound returns whether inbound connections are blocked.
|
||||
func (e *Engine) IsBlockInbound() bool {
|
||||
return e.config.BlockInbound
|
||||
}
|
||||
|
||||
// GetClientMetrics returns the client metrics
|
||||
func (e *Engine) GetClientMetrics() *metrics.ClientMetrics {
|
||||
return e.clientMetrics
|
||||
|
||||
@@ -828,7 +828,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
MTU: iface.DefaultMTU,
|
||||
}, EngineServices{
|
||||
}, EngineServices{
|
||||
SignalClient: &signal.MockClient{},
|
||||
MgmClient: &mgmt.MockClient{},
|
||||
RelayManager: relayMgr,
|
||||
@@ -1035,7 +1035,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
MTU: iface.DefaultMTU,
|
||||
}, EngineServices{
|
||||
}, EngineServices{
|
||||
SignalClient: &signal.MockClient{},
|
||||
MgmClient: &mgmt.MockClient{},
|
||||
RelayManager: relayMgr,
|
||||
@@ -1538,13 +1538,8 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
|
||||
return nil, err
|
||||
}
|
||||
|
||||
publicKey, err := mgmtClient.GetServerPublicKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
info := system.GetInfo(ctx)
|
||||
resp, err := mgmtClient.Register(*publicKey, setupKey, "", info, nil, nil)
|
||||
resp, err := mgmtClient.Register(setupKey, "", info, nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1566,7 +1561,7 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
|
||||
}
|
||||
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||
e, err := NewEngine(ctx, cancel, conf, EngineServices{
|
||||
e, err := NewEngine(ctx, cancel, conf, EngineServices{
|
||||
SignalClient: signalClient,
|
||||
MgmClient: mgmtClient,
|
||||
RelayManager: relayMgr,
|
||||
|
||||
571
client/internal/engine_tproxy.go
Normal file
571
client/internal/engine_tproxy.go
Normal file
@@ -0,0 +1,571 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
|
||||
"github.com/netbirdio/netbird/client/inspect"
|
||||
"github.com/netbirdio/netbird/client/internal/acl/id"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// updateTransparentProxy processes transparent proxy configuration from the network map.
|
||||
func (e *Engine) updateTransparentProxy(cfg *mgmProto.TransparentProxyConfig) {
|
||||
if cfg == nil || !cfg.Enabled {
|
||||
if cfg == nil {
|
||||
log.Tracef("inspect: config is nil")
|
||||
} else {
|
||||
log.Tracef("inspect: config disabled")
|
||||
}
|
||||
// Only stop if explicitly disabled. Don't stop on nil config to avoid
|
||||
// a gap during policy edits where management briefly pushes empty config.
|
||||
if cfg != nil && !cfg.Enabled {
|
||||
e.stopTransparentProxy()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("inspect: config received: enabled=%v mode=%v default_action=%v rules=%d has_ca=%v",
|
||||
cfg.Enabled, cfg.Mode, cfg.DefaultAction, len(cfg.Rules), len(cfg.CaCertPem) > 0)
|
||||
|
||||
// BlockInbound prevents adding TPROXY rules since kernel TPROXY bypasses ACLs.
|
||||
// The userspace forwarder path still works as it operates within the forwarder hook.
|
||||
if e.config.BlockInbound {
|
||||
log.Warnf("inspect: BlockInbound is set, skipping redirect rules (userspace path still active)")
|
||||
}
|
||||
|
||||
proxyConfig, err := toProxyConfig(cfg)
|
||||
if err != nil {
|
||||
log.Errorf("inspect: parse config: %v", err)
|
||||
e.stopTransparentProxy()
|
||||
return
|
||||
}
|
||||
|
||||
// CA priority: local config > management-pushed > auto-generated self-signed.
|
||||
// Local wins over mgmt to prevent compromised management from injecting a CA.
|
||||
e.resolveInspectionCA(&proxyConfig)
|
||||
|
||||
if e.transparentProxy != nil {
|
||||
// Mode change requires full recreate (envoy lifecycle, listener changes).
|
||||
if proxyConfig.Mode != e.transparentProxy.Mode() {
|
||||
log.Infof("inspect: mode changed to %s, recreating engine", proxyConfig.Mode)
|
||||
e.stopTransparentProxy()
|
||||
} else {
|
||||
e.transparentProxy.UpdateConfig(proxyConfig)
|
||||
e.syncTProxyRules(proxyConfig)
|
||||
e.syncUDPInspectionHook()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if e.wgInterface != nil {
|
||||
proxyConfig.WGNetwork = e.wgInterface.Address().Network
|
||||
proxyConfig.ListenAddr = netip.AddrPortFrom(
|
||||
e.wgInterface.Address().IP.Unmap(),
|
||||
proxyConfig.ListenAddr.Port(),
|
||||
)
|
||||
}
|
||||
|
||||
// Pass local IP checker for SSRF prevention
|
||||
if checker, ok := e.firewall.(inspect.LocalIPChecker); ok {
|
||||
proxyConfig.LocalIPChecker = checker
|
||||
}
|
||||
|
||||
p, err := inspect.New(e.ctx, log.WithField("component", "inspect"), proxyConfig)
|
||||
if err != nil {
|
||||
log.Errorf("inspect: start engine: %v", err)
|
||||
return
|
||||
}
|
||||
e.transparentProxy = p
|
||||
|
||||
e.attachProxyToForwarder(p)
|
||||
e.syncTProxyRules(proxyConfig)
|
||||
e.syncUDPInspectionHook()
|
||||
|
||||
log.Infof("inspect: engine started (mode=%s, rules=%d)", proxyConfig.Mode, len(proxyConfig.Rules))
|
||||
}
|
||||
|
||||
// stopTransparentProxy shuts down the transparent proxy and removes interception.
|
||||
func (e *Engine) stopTransparentProxy() {
|
||||
if e.transparentProxy == nil {
|
||||
return
|
||||
}
|
||||
|
||||
e.attachProxyToForwarder(nil)
|
||||
e.removeTProxyRule()
|
||||
e.removeUDPInspectionHook()
|
||||
|
||||
if err := e.transparentProxy.Close(); err != nil {
|
||||
log.Debugf("inspect: close engine: %v", err)
|
||||
}
|
||||
e.transparentProxy = nil
|
||||
|
||||
log.Info("inspect: engine stopped")
|
||||
}
|
||||
|
||||
const tproxyRuleID = "tproxy-redirect"
|
||||
|
||||
// syncTProxyRules adds a TPROXY rule via the firewall manager to intercept
|
||||
// matching traffic on the WG interface and redirect it to the proxy socket.
|
||||
func (e *Engine) syncTProxyRules(config inspect.Config) {
|
||||
if e.config.BlockInbound {
|
||||
e.removeTProxyRule()
|
||||
return
|
||||
}
|
||||
|
||||
var listenPort uint16
|
||||
if e.transparentProxy != nil {
|
||||
listenPort = e.transparentProxy.ListenPort()
|
||||
}
|
||||
if listenPort == 0 {
|
||||
e.removeTProxyRule()
|
||||
return
|
||||
}
|
||||
|
||||
if e.firewall == nil {
|
||||
return
|
||||
}
|
||||
|
||||
dstPorts := make([]uint16, len(config.RedirectPorts))
|
||||
copy(dstPorts, config.RedirectPorts)
|
||||
|
||||
log.Debugf("inspect: syncing redirect rules: listen port %d, redirect ports %v, sources %v",
|
||||
listenPort, dstPorts, config.RedirectSources)
|
||||
|
||||
if err := e.firewall.AddTProxyRule(tproxyRuleID, config.RedirectSources, dstPorts, listenPort); err != nil {
|
||||
log.Errorf("inspect: add redirect rule: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// removeTProxyRule removes the TPROXY redirect rule.
|
||||
func (e *Engine) removeTProxyRule() {
|
||||
if e.firewall == nil {
|
||||
return
|
||||
}
|
||||
if err := e.firewall.RemoveTProxyRule(tproxyRuleID); err != nil {
|
||||
log.Debugf("inspect: remove redirect rule: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// syncUDPInspectionHook registers a UDP packet hook on port 443 for QUIC SNI blocking.
|
||||
// The hook is called by the USP filter for each UDP packet matching the port,
|
||||
// allowing the inspection engine to extract QUIC SNI and block by domain.
|
||||
func (e *Engine) syncUDPInspectionHook() {
|
||||
e.removeUDPInspectionHook()
|
||||
|
||||
if e.firewall == nil || e.transparentProxy == nil {
|
||||
return
|
||||
}
|
||||
|
||||
p := e.transparentProxy
|
||||
hookID := e.firewall.AddUDPInspectionHook(443, func(packet []byte) bool {
|
||||
srcIP, dstIP, dstPort, udpPayload, ok := parseUDPPacket(packet)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
src := inspect.SourceInfo{IP: srcIP}
|
||||
dst := netip.AddrPortFrom(dstIP, dstPort)
|
||||
action := p.HandleUDPPacket(udpPayload, dst, src)
|
||||
return action == inspect.ActionBlock
|
||||
})
|
||||
|
||||
e.udpInspectionHookID = hookID
|
||||
log.Debugf("inspect: registered UDP inspection hook on port 443 (id=%s)", hookID)
|
||||
}
|
||||
|
||||
// removeUDPInspectionHook removes the QUIC inspection hook.
|
||||
func (e *Engine) removeUDPInspectionHook() {
|
||||
if e.udpInspectionHookID == "" || e.firewall == nil {
|
||||
return
|
||||
}
|
||||
e.firewall.RemoveUDPInspectionHook(e.udpInspectionHookID)
|
||||
e.udpInspectionHookID = ""
|
||||
}
|
||||
|
||||
// parseUDPPacket extracts source/destination IP, destination port, and UDP
|
||||
// payload from a raw IP packet. Supports both IPv4 and IPv6.
|
||||
func parseUDPPacket(packet []byte) (srcIP, dstIP netip.Addr, dstPort uint16, payload []byte, ok bool) {
|
||||
if len(packet) < 1 {
|
||||
return srcIP, dstIP, 0, nil, false
|
||||
}
|
||||
|
||||
version := packet[0] >> 4
|
||||
|
||||
var udpOffset int
|
||||
switch version {
|
||||
case 4:
|
||||
if len(packet) < 20 {
|
||||
return srcIP, dstIP, 0, nil, false
|
||||
}
|
||||
ihl := int(packet[0]&0x0f) * 4
|
||||
if len(packet) < ihl+8 {
|
||||
return srcIP, dstIP, 0, nil, false
|
||||
}
|
||||
var srcOK, dstOK bool
|
||||
srcIP, srcOK = netip.AddrFromSlice(packet[12:16])
|
||||
dstIP, dstOK = netip.AddrFromSlice(packet[16:20])
|
||||
if !srcOK || !dstOK {
|
||||
return srcIP, dstIP, 0, nil, false
|
||||
}
|
||||
udpOffset = ihl
|
||||
|
||||
case 6:
|
||||
// IPv6 fixed header is 40 bytes. Next header must be UDP (17).
|
||||
if len(packet) < 48 { // 40 header + 8 UDP
|
||||
return srcIP, dstIP, 0, nil, false
|
||||
}
|
||||
nextHeader := packet[6]
|
||||
if nextHeader != 17 { // not UDP (may have extension headers)
|
||||
return srcIP, dstIP, 0, nil, false
|
||||
}
|
||||
var srcOK, dstOK bool
|
||||
srcIP, srcOK = netip.AddrFromSlice(packet[8:24])
|
||||
dstIP, dstOK = netip.AddrFromSlice(packet[24:40])
|
||||
if !srcOK || !dstOK {
|
||||
return srcIP, dstIP, 0, nil, false
|
||||
}
|
||||
udpOffset = 40
|
||||
|
||||
default:
|
||||
return srcIP, dstIP, 0, nil, false
|
||||
}
|
||||
|
||||
srcIP = srcIP.Unmap()
|
||||
dstIP = dstIP.Unmap()
|
||||
dstPort = uint16(packet[udpOffset+2])<<8 | uint16(packet[udpOffset+3])
|
||||
payload = packet[udpOffset+8:]
|
||||
|
||||
return srcIP, dstIP, dstPort, payload, true
|
||||
}
|
||||
|
||||
// attachProxyToForwarder sets or clears the proxy on the userspace forwarder.
|
||||
func (e *Engine) attachProxyToForwarder(p *inspect.Proxy) {
|
||||
type forwarderGetter interface {
|
||||
GetForwarder() *forwarder.Forwarder
|
||||
}
|
||||
|
||||
if fg, ok := e.firewall.(forwarderGetter); ok {
|
||||
if fwd := fg.GetForwarder(); fwd != nil {
|
||||
fwd.SetProxy(p)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// toProxyConfig converts a proto TransparentProxyConfig to the inspect.Config type.
|
||||
func toProxyConfig(cfg *mgmProto.TransparentProxyConfig) (inspect.Config, error) {
|
||||
config := inspect.Config{
|
||||
Enabled: cfg.Enabled,
|
||||
DefaultAction: toProxyAction(cfg.DefaultAction),
|
||||
}
|
||||
|
||||
switch cfg.Mode {
|
||||
case mgmProto.TransparentProxyMode_TP_MODE_ENVOY:
|
||||
config.Mode = inspect.ModeEnvoy
|
||||
case mgmProto.TransparentProxyMode_TP_MODE_EXTERNAL:
|
||||
config.Mode = inspect.ModeExternal
|
||||
default:
|
||||
config.Mode = inspect.ModeBuiltin
|
||||
}
|
||||
|
||||
if cfg.ExternalProxyUrl != "" {
|
||||
u, err := url.Parse(cfg.ExternalProxyUrl)
|
||||
if err != nil {
|
||||
return inspect.Config{}, fmt.Errorf("parse external proxy URL: %w", err)
|
||||
}
|
||||
config.ExternalURL = u
|
||||
}
|
||||
|
||||
for _, s := range cfg.RedirectSources {
|
||||
prefix, err := netip.ParsePrefix(s)
|
||||
if err != nil {
|
||||
return inspect.Config{}, fmt.Errorf("parse redirect source %q: %w", s, err)
|
||||
}
|
||||
config.RedirectSources = append(config.RedirectSources, prefix)
|
||||
}
|
||||
|
||||
for _, p := range cfg.RedirectPorts {
|
||||
config.RedirectPorts = append(config.RedirectPorts, uint16(p))
|
||||
}
|
||||
|
||||
// TPROXY listen port: fixed default, overridable via env var.
|
||||
if config.Mode == inspect.ModeBuiltin {
|
||||
port := uint16(inspect.DefaultTProxyPort)
|
||||
if v := os.Getenv("NB_TPROXY_PORT"); v != "" {
|
||||
if p, err := strconv.ParseUint(v, 10, 16); err == nil {
|
||||
port = uint16(p)
|
||||
} else {
|
||||
log.Warnf("invalid NB_TPROXY_PORT %q, using default %d", v, inspect.DefaultTProxyPort)
|
||||
}
|
||||
}
|
||||
config.ListenAddr = netip.AddrPortFrom(netip.IPv4Unspecified(), port)
|
||||
}
|
||||
|
||||
for _, r := range cfg.Rules {
|
||||
rule, err := toProxyRule(r)
|
||||
if err != nil {
|
||||
return inspect.Config{}, fmt.Errorf("parse rule %q: %w", r.Id, err)
|
||||
}
|
||||
config.Rules = append(config.Rules, rule)
|
||||
}
|
||||
|
||||
if cfg.Icap != nil {
|
||||
icapCfg, err := toICAPConfig(cfg.Icap)
|
||||
if err != nil {
|
||||
return inspect.Config{}, fmt.Errorf("parse ICAP config: %w", err)
|
||||
}
|
||||
config.ICAP = icapCfg
|
||||
}
|
||||
|
||||
if len(cfg.CaCertPem) > 0 && len(cfg.CaKeyPem) > 0 {
|
||||
tlsCfg, err := parseTLSConfig(cfg.CaCertPem, cfg.CaKeyPem)
|
||||
if err != nil {
|
||||
return inspect.Config{}, fmt.Errorf("parse TLS config: %w", err)
|
||||
}
|
||||
config.TLS = tlsCfg
|
||||
}
|
||||
|
||||
if config.Mode == inspect.ModeEnvoy {
|
||||
envCfg := &inspect.EnvoyConfig{
|
||||
BinaryPath: cfg.EnvoyBinaryPath,
|
||||
AdminPort: uint16(cfg.EnvoyAdminPort),
|
||||
}
|
||||
if cfg.EnvoySnippets != nil {
|
||||
envCfg.Snippets = &inspect.EnvoySnippets{
|
||||
HTTPFilters: cfg.EnvoySnippets.HttpFilters,
|
||||
NetworkFilters: cfg.EnvoySnippets.NetworkFilters,
|
||||
Clusters: cfg.EnvoySnippets.Clusters,
|
||||
}
|
||||
}
|
||||
config.Envoy = envCfg
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
func toProxyRule(r *mgmProto.TransparentProxyRule) (inspect.Rule, error) {
|
||||
rule := inspect.Rule{
|
||||
ID: id.RuleID(r.Id),
|
||||
Action: toProxyAction(r.Action),
|
||||
Priority: int(r.Priority),
|
||||
}
|
||||
|
||||
for _, d := range r.Domains {
|
||||
dom, err := domain.FromString(d)
|
||||
if err != nil {
|
||||
return inspect.Rule{}, fmt.Errorf("parse domain %q: %w", d, err)
|
||||
}
|
||||
rule.Domains = append(rule.Domains, dom)
|
||||
}
|
||||
|
||||
for _, n := range r.Networks {
|
||||
prefix, err := netip.ParsePrefix(n)
|
||||
if err != nil {
|
||||
return inspect.Rule{}, fmt.Errorf("parse network %q: %w", n, err)
|
||||
}
|
||||
rule.Networks = append(rule.Networks, prefix)
|
||||
}
|
||||
|
||||
for _, p := range r.Ports {
|
||||
rule.Ports = append(rule.Ports, uint16(p))
|
||||
}
|
||||
|
||||
for _, proto := range r.Protocols {
|
||||
rule.Protocols = append(rule.Protocols, toProxyProtoType(proto))
|
||||
}
|
||||
|
||||
rule.Paths = r.Paths
|
||||
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
func toProxyProtoType(p mgmProto.TransparentProxyProtocol) inspect.ProtoType {
|
||||
switch p {
|
||||
case mgmProto.TransparentProxyProtocol_TP_PROTO_HTTP:
|
||||
return inspect.ProtoHTTP
|
||||
case mgmProto.TransparentProxyProtocol_TP_PROTO_HTTPS:
|
||||
return inspect.ProtoHTTPS
|
||||
case mgmProto.TransparentProxyProtocol_TP_PROTO_H2:
|
||||
return inspect.ProtoH2
|
||||
case mgmProto.TransparentProxyProtocol_TP_PROTO_H3:
|
||||
return inspect.ProtoH3
|
||||
case mgmProto.TransparentProxyProtocol_TP_PROTO_WEBSOCKET:
|
||||
return inspect.ProtoWebSocket
|
||||
case mgmProto.TransparentProxyProtocol_TP_PROTO_OTHER:
|
||||
return inspect.ProtoOther
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func toProxyAction(a mgmProto.TransparentProxyAction) inspect.Action {
|
||||
switch a {
|
||||
case mgmProto.TransparentProxyAction_TP_ACTION_BLOCK:
|
||||
return inspect.ActionBlock
|
||||
case mgmProto.TransparentProxyAction_TP_ACTION_INSPECT:
|
||||
return inspect.ActionInspect
|
||||
default:
|
||||
return inspect.ActionAllow
|
||||
}
|
||||
}
|
||||
|
||||
func toICAPConfig(cfg *mgmProto.TransparentProxyICAPConfig) (*inspect.ICAPConfig, error) {
|
||||
icap := &inspect.ICAPConfig{
|
||||
MaxConnections: int(cfg.MaxConnections),
|
||||
}
|
||||
|
||||
if cfg.ReqmodUrl != "" {
|
||||
u, err := url.Parse(cfg.ReqmodUrl)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse ICAP reqmod URL: %w", err)
|
||||
}
|
||||
icap.ReqModURL = u
|
||||
}
|
||||
|
||||
if cfg.RespmodUrl != "" {
|
||||
u, err := url.Parse(cfg.RespmodUrl)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse ICAP respmod URL: %w", err)
|
||||
}
|
||||
icap.RespModURL = u
|
||||
}
|
||||
|
||||
return icap, nil
|
||||
}
|
||||
|
||||
func parseTLSConfig(certPEM, keyPEM []byte) (*inspect.TLSConfig, error) {
|
||||
block, _ := pem.Decode(certPEM)
|
||||
if block == nil {
|
||||
return nil, fmt.Errorf("decode CA certificate PEM")
|
||||
}
|
||||
|
||||
cert, err := x509.ParseCertificate(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse CA certificate: %w", err)
|
||||
}
|
||||
|
||||
keyBlock, _ := pem.Decode(keyPEM)
|
||||
if keyBlock == nil {
|
||||
return nil, fmt.Errorf("decode CA key PEM")
|
||||
}
|
||||
|
||||
key, err := x509.ParseECPrivateKey(keyBlock.Bytes)
|
||||
if err != nil {
|
||||
// Try PKCS8 as fallback
|
||||
pkcs8Key, pkcs8Err := x509.ParsePKCS8PrivateKey(keyBlock.Bytes)
|
||||
if pkcs8Err != nil {
|
||||
return nil, fmt.Errorf("parse CA private key (tried EC and PKCS8): %w", err)
|
||||
}
|
||||
return &inspect.TLSConfig{CA: cert, CAKey: pkcs8Key}, nil
|
||||
}
|
||||
|
||||
return &inspect.TLSConfig{CA: cert, CAKey: key}, nil
|
||||
}
|
||||
|
||||
// resolveInspectionCA sets the TLS config on the proxy config using priority:
|
||||
// 1. Local config file CA (InspectionCACertPath/InspectionCAKeyPath)
|
||||
// 2. Management-pushed CA (already parsed in toProxyConfig)
|
||||
// 3. Auto-generated self-signed CA (ephemeral, for testing)
|
||||
// Local always wins to prevent a compromised management server from injecting a CA.
|
||||
func (e *Engine) resolveInspectionCA(config *inspect.Config) {
|
||||
// 1. Local CA from config file or env vars
|
||||
certPath := e.config.InspectionCACertPath
|
||||
keyPath := e.config.InspectionCAKeyPath
|
||||
if certPath == "" {
|
||||
certPath = os.Getenv("NB_INSPECTION_CA_CERT")
|
||||
}
|
||||
if keyPath == "" {
|
||||
keyPath = os.Getenv("NB_INSPECTION_CA_KEY")
|
||||
}
|
||||
if certPath != "" && keyPath != "" {
|
||||
certPEM, err := os.ReadFile(certPath)
|
||||
if err != nil {
|
||||
log.Errorf("read local inspection CA cert %s: %v", certPath, err)
|
||||
return
|
||||
}
|
||||
keyPEM, err := os.ReadFile(keyPath)
|
||||
if err != nil {
|
||||
log.Errorf("read local inspection CA key %s: %v", keyPath, err)
|
||||
return
|
||||
}
|
||||
tlsCfg, err := parseTLSConfig(certPEM, keyPEM)
|
||||
if err != nil {
|
||||
log.Errorf("parse local inspection CA: %v", err)
|
||||
return
|
||||
}
|
||||
log.Infof("inspect: using local CA from %s", certPath)
|
||||
config.TLS = tlsCfg
|
||||
return
|
||||
}
|
||||
|
||||
// 2. Management-pushed CA (already set by toProxyConfig)
|
||||
if config.TLS != nil {
|
||||
log.Infof("inspect: using management-pushed CA")
|
||||
return
|
||||
}
|
||||
|
||||
// 3. Auto-generate self-signed CA for testing / accept-cert UX
|
||||
tlsCfg, err := generateSelfSignedCA()
|
||||
if err != nil {
|
||||
log.Errorf("generate self-signed inspection CA: %v", err)
|
||||
return
|
||||
}
|
||||
log.Infof("inspect: using auto-generated self-signed CA (clients will see certificate warnings)")
|
||||
config.TLS = tlsCfg
|
||||
}
|
||||
|
||||
// generateSelfSignedCA creates an ephemeral ECDSA P-256 CA certificate.
|
||||
// Clients will see certificate warnings but can choose to accept.
|
||||
func generateSelfSignedCA() (*inspect.TLSConfig, error) {
|
||||
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate CA key: %w", err)
|
||||
}
|
||||
|
||||
serial, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate serial: %w", err)
|
||||
}
|
||||
|
||||
template := &x509.Certificate{
|
||||
SerialNumber: serial,
|
||||
Subject: pkix.Name{
|
||||
Organization: []string{"NetBird Transparent Proxy"},
|
||||
CommonName: "NetBird Inspection CA (auto-generated)",
|
||||
},
|
||||
NotBefore: time.Now().Add(-1 * time.Hour),
|
||||
NotAfter: time.Now().Add(365 * 24 * time.Hour),
|
||||
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
|
||||
BasicConstraintsValid: true,
|
||||
IsCA: true,
|
||||
MaxPathLen: 0,
|
||||
}
|
||||
|
||||
certDER, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create CA certificate: %w", err)
|
||||
}
|
||||
|
||||
cert, err := x509.ParseCertificate(certDER)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse generated CA certificate: %w", err)
|
||||
}
|
||||
|
||||
return &inspect.TLSConfig{CA: cert, CAKey: key}, nil
|
||||
}
|
||||
279
client/internal/engine_tproxy_test.go
Normal file
279
client/internal/engine_tproxy_test.go
Normal file
@@ -0,0 +1,279 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/inspect"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
func TestToProxyConfig_Basic(t *testing.T) {
|
||||
cfg := &mgmProto.TransparentProxyConfig{
|
||||
Enabled: true,
|
||||
Mode: mgmProto.TransparentProxyMode_TP_MODE_BUILTIN,
|
||||
DefaultAction: mgmProto.TransparentProxyAction_TP_ACTION_ALLOW,
|
||||
RedirectSources: []string{
|
||||
"10.0.0.0/24",
|
||||
"192.168.1.0/24",
|
||||
},
|
||||
RedirectPorts: []uint32{80, 443},
|
||||
Rules: []*mgmProto.TransparentProxyRule{
|
||||
{
|
||||
Id: "block-evil",
|
||||
Domains: []string{"*.evil.com", "malware.example.com"},
|
||||
Action: mgmProto.TransparentProxyAction_TP_ACTION_BLOCK,
|
||||
Priority: 1,
|
||||
},
|
||||
{
|
||||
Id: "inspect-internal",
|
||||
Domains: []string{"*.internal.corp"},
|
||||
Networks: []string{"10.1.0.0/16"},
|
||||
Ports: []uint32{443, 8443},
|
||||
Action: mgmProto.TransparentProxyAction_TP_ACTION_INSPECT,
|
||||
Priority: 10,
|
||||
},
|
||||
},
|
||||
ListenPort: 8443,
|
||||
}
|
||||
|
||||
config, err := toProxyConfig(cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, config.Enabled)
|
||||
assert.Equal(t, inspect.ModeBuiltin, config.Mode)
|
||||
assert.Equal(t, inspect.ActionAllow, config.DefaultAction)
|
||||
|
||||
require.Len(t, config.RedirectSources, 2)
|
||||
assert.Equal(t, "10.0.0.0/24", config.RedirectSources[0].String())
|
||||
assert.Equal(t, "192.168.1.0/24", config.RedirectSources[1].String())
|
||||
|
||||
require.Len(t, config.RedirectPorts, 2)
|
||||
assert.Equal(t, uint16(80), config.RedirectPorts[0])
|
||||
assert.Equal(t, uint16(443), config.RedirectPorts[1])
|
||||
|
||||
require.Len(t, config.Rules, 2)
|
||||
|
||||
// Rule 1: block evil domains
|
||||
assert.Equal(t, "block-evil", string(config.Rules[0].ID))
|
||||
assert.Equal(t, inspect.ActionBlock, config.Rules[0].Action)
|
||||
assert.Equal(t, 1, config.Rules[0].Priority)
|
||||
require.Len(t, config.Rules[0].Domains, 2)
|
||||
assert.Equal(t, "*.evil.com", config.Rules[0].Domains[0].PunycodeString())
|
||||
assert.Equal(t, "malware.example.com", config.Rules[0].Domains[1].PunycodeString())
|
||||
|
||||
// Rule 2: inspect internal
|
||||
assert.Equal(t, "inspect-internal", string(config.Rules[1].ID))
|
||||
assert.Equal(t, inspect.ActionInspect, config.Rules[1].Action)
|
||||
assert.Equal(t, 10, config.Rules[1].Priority)
|
||||
require.Len(t, config.Rules[1].Networks, 1)
|
||||
assert.Equal(t, "10.1.0.0/16", config.Rules[1].Networks[0].String())
|
||||
require.Len(t, config.Rules[1].Ports, 2)
|
||||
|
||||
// Listen address
|
||||
assert.True(t, config.ListenAddr.IsValid())
|
||||
assert.Equal(t, uint16(8443), config.ListenAddr.Port())
|
||||
}
|
||||
|
||||
func TestToProxyConfig_ExternalMode(t *testing.T) {
|
||||
cfg := &mgmProto.TransparentProxyConfig{
|
||||
Enabled: true,
|
||||
Mode: mgmProto.TransparentProxyMode_TP_MODE_EXTERNAL,
|
||||
ExternalProxyUrl: "http://proxy.corp:8080",
|
||||
DefaultAction: mgmProto.TransparentProxyAction_TP_ACTION_BLOCK,
|
||||
}
|
||||
|
||||
config, err := toProxyConfig(cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, inspect.ModeExternal, config.Mode)
|
||||
assert.Equal(t, inspect.ActionBlock, config.DefaultAction)
|
||||
require.NotNil(t, config.ExternalURL)
|
||||
assert.Equal(t, "http", config.ExternalURL.Scheme)
|
||||
assert.Equal(t, "proxy.corp:8080", config.ExternalURL.Host)
|
||||
}
|
||||
|
||||
func TestToProxyConfig_ICAP(t *testing.T) {
|
||||
cfg := &mgmProto.TransparentProxyConfig{
|
||||
Enabled: true,
|
||||
Icap: &mgmProto.TransparentProxyICAPConfig{
|
||||
ReqmodUrl: "icap://icap-server:1344/reqmod",
|
||||
RespmodUrl: "icap://icap-server:1344/respmod",
|
||||
MaxConnections: 16,
|
||||
},
|
||||
}
|
||||
|
||||
config, err := toProxyConfig(cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NotNil(t, config.ICAP)
|
||||
assert.Equal(t, "icap", config.ICAP.ReqModURL.Scheme)
|
||||
assert.Equal(t, "icap-server:1344", config.ICAP.ReqModURL.Host)
|
||||
assert.Equal(t, "/reqmod", config.ICAP.ReqModURL.Path)
|
||||
assert.Equal(t, "/respmod", config.ICAP.RespModURL.Path)
|
||||
assert.Equal(t, 16, config.ICAP.MaxConnections)
|
||||
}
|
||||
|
||||
func TestToProxyConfig_Empty(t *testing.T) {
|
||||
cfg := &mgmProto.TransparentProxyConfig{
|
||||
Enabled: true,
|
||||
}
|
||||
|
||||
config, err := toProxyConfig(cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, config.Enabled)
|
||||
assert.Equal(t, inspect.ModeBuiltin, config.Mode)
|
||||
assert.Equal(t, inspect.ActionAllow, config.DefaultAction)
|
||||
assert.Empty(t, config.RedirectSources)
|
||||
assert.Empty(t, config.RedirectPorts)
|
||||
assert.Empty(t, config.Rules)
|
||||
assert.Nil(t, config.ICAP)
|
||||
assert.Nil(t, config.TLS)
|
||||
assert.False(t, config.ListenAddr.IsValid())
|
||||
}
|
||||
|
||||
func TestToProxyConfig_InvalidSource(t *testing.T) {
|
||||
cfg := &mgmProto.TransparentProxyConfig{
|
||||
Enabled: true,
|
||||
RedirectSources: []string{"not-a-cidr"},
|
||||
}
|
||||
|
||||
_, err := toProxyConfig(cfg)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "parse redirect source")
|
||||
}
|
||||
|
||||
func TestToProxyConfig_InvalidNetwork(t *testing.T) {
|
||||
cfg := &mgmProto.TransparentProxyConfig{
|
||||
Enabled: true,
|
||||
Rules: []*mgmProto.TransparentProxyRule{
|
||||
{
|
||||
Id: "bad",
|
||||
Networks: []string{"not-a-cidr"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := toProxyConfig(cfg)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "parse network")
|
||||
}
|
||||
|
||||
func TestToProxyAction(t *testing.T) {
|
||||
assert.Equal(t, inspect.ActionAllow, toProxyAction(mgmProto.TransparentProxyAction_TP_ACTION_ALLOW))
|
||||
assert.Equal(t, inspect.ActionBlock, toProxyAction(mgmProto.TransparentProxyAction_TP_ACTION_BLOCK))
|
||||
assert.Equal(t, inspect.ActionInspect, toProxyAction(mgmProto.TransparentProxyAction_TP_ACTION_INSPECT))
|
||||
// Unknown defaults to allow
|
||||
assert.Equal(t, inspect.ActionAllow, toProxyAction(99))
|
||||
}
|
||||
|
||||
func TestParseUDPPacket_IPv4(t *testing.T) {
|
||||
// Build a minimal IPv4/UDP packet: 20-byte IPv4 header + 8-byte UDP header + payload
|
||||
packet := make([]byte, 20+8+4)
|
||||
|
||||
// IPv4 header: version=4, IHL=5 (20 bytes)
|
||||
packet[0] = 0x45
|
||||
// Protocol = UDP (17)
|
||||
packet[9] = 17
|
||||
// Source IP: 10.0.0.1
|
||||
packet[12], packet[13], packet[14], packet[15] = 10, 0, 0, 1
|
||||
// Dest IP: 192.168.1.1
|
||||
packet[16], packet[17], packet[18], packet[19] = 192, 168, 1, 1
|
||||
// UDP source port: 54321 (0xD431)
|
||||
packet[20] = 0xD4
|
||||
packet[21] = 0x31
|
||||
// UDP dest port: 443 (0x01BB)
|
||||
packet[22] = 0x01
|
||||
packet[23] = 0xBB
|
||||
// Payload
|
||||
packet[28] = 0xDE
|
||||
packet[29] = 0xAD
|
||||
packet[30] = 0xBE
|
||||
packet[31] = 0xEF
|
||||
|
||||
srcIP, dstIP, dstPort, payload, ok := parseUDPPacket(packet)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "10.0.0.1", srcIP.String())
|
||||
assert.Equal(t, "192.168.1.1", dstIP.String())
|
||||
assert.Equal(t, uint16(443), dstPort)
|
||||
assert.Equal(t, []byte{0xDE, 0xAD, 0xBE, 0xEF}, payload)
|
||||
}
|
||||
|
||||
func TestParseUDPPacket_IPv6(t *testing.T) {
|
||||
// Build a minimal IPv6/UDP packet: 40-byte IPv6 header + 8-byte UDP header + payload
|
||||
packet := make([]byte, 40+8+4)
|
||||
|
||||
// Version = 6 (0x60 in high nibble)
|
||||
packet[0] = 0x60
|
||||
// Payload length: 8 (UDP header) + 4 (payload)
|
||||
packet[4] = 0
|
||||
packet[5] = 12
|
||||
// Next header: UDP (17)
|
||||
packet[6] = 17
|
||||
// Source: 2001:db8::1
|
||||
packet[8] = 0x20
|
||||
packet[9] = 0x01
|
||||
packet[10] = 0x0d
|
||||
packet[11] = 0xb8
|
||||
packet[23] = 0x01
|
||||
// Dest: 2001:db8::2
|
||||
packet[24] = 0x20
|
||||
packet[25] = 0x01
|
||||
packet[26] = 0x0d
|
||||
packet[27] = 0xb8
|
||||
packet[39] = 0x02
|
||||
// UDP source port: 54321 (0xD431)
|
||||
packet[40] = 0xD4
|
||||
packet[41] = 0x31
|
||||
// UDP dest port: 443 (0x01BB)
|
||||
packet[42] = 0x01
|
||||
packet[43] = 0xBB
|
||||
// Payload
|
||||
packet[48] = 0xCA
|
||||
packet[49] = 0xFE
|
||||
packet[50] = 0xBA
|
||||
packet[51] = 0xBE
|
||||
|
||||
srcIP, dstIP, dstPort, payload, ok := parseUDPPacket(packet)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "2001:db8::1", srcIP.String())
|
||||
assert.Equal(t, "2001:db8::2", dstIP.String())
|
||||
assert.Equal(t, uint16(443), dstPort)
|
||||
assert.Equal(t, []byte{0xCA, 0xFE, 0xBA, 0xBE}, payload)
|
||||
}
|
||||
|
||||
func TestParseUDPPacket_TooShort(t *testing.T) {
|
||||
_, _, _, _, ok := parseUDPPacket(nil)
|
||||
assert.False(t, ok)
|
||||
|
||||
_, _, _, _, ok = parseUDPPacket([]byte{0x45, 0x00})
|
||||
assert.False(t, ok)
|
||||
}
|
||||
|
||||
func TestParseUDPPacket_IPv6ExtensionHeader(t *testing.T) {
|
||||
// IPv6 with next header != UDP should be rejected
|
||||
packet := make([]byte, 48)
|
||||
packet[0] = 0x60
|
||||
packet[6] = 6 // TCP, not UDP
|
||||
_, _, _, _, ok := parseUDPPacket(packet)
|
||||
assert.False(t, ok, "should reject IPv6 packets with non-UDP next header")
|
||||
}
|
||||
|
||||
func TestParseUDPPacket_IPv4MappedIPv6(t *testing.T) {
|
||||
// IPv4 packet with normal addresses should Unmap correctly
|
||||
packet := make([]byte, 28)
|
||||
packet[0] = 0x45
|
||||
packet[9] = 17
|
||||
packet[12], packet[13], packet[14], packet[15] = 127, 0, 0, 1
|
||||
packet[16], packet[17], packet[18], packet[19] = 10, 0, 0, 1
|
||||
packet[22] = 0x01
|
||||
packet[23] = 0xBB
|
||||
|
||||
srcIP, dstIP, _, _, ok := parseUDPPacket(packet)
|
||||
require.True(t, ok)
|
||||
assert.True(t, srcIP.Is4(), "should be plain IPv4, not mapped")
|
||||
assert.True(t, dstIP.Is4(), "should be plain IPv4, not mapped")
|
||||
}
|
||||
@@ -22,6 +22,7 @@ import (
|
||||
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||
"github.com/netbirdio/netbird/client/internal/peer/id"
|
||||
"github.com/netbirdio/netbird/client/internal/peer/worker"
|
||||
"github.com/netbirdio/netbird/client/internal/portforward"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||
@@ -45,6 +46,7 @@ type ServiceDependencies struct {
|
||||
RelayManager *relayClient.Manager
|
||||
SrWatcher *guard.SRWatcher
|
||||
PeerConnDispatcher *dispatcher.ConnectionDispatcher
|
||||
PortForwardManager *portforward.Manager
|
||||
MetricsRecorder MetricsRecorder
|
||||
}
|
||||
|
||||
@@ -87,16 +89,17 @@ type ConnConfig struct {
|
||||
}
|
||||
|
||||
type Conn struct {
|
||||
Log *log.Entry
|
||||
mu sync.Mutex
|
||||
ctx context.Context
|
||||
ctxCancel context.CancelFunc
|
||||
config ConnConfig
|
||||
statusRecorder *Status
|
||||
signaler *Signaler
|
||||
iFaceDiscover stdnet.ExternalIFaceDiscover
|
||||
relayManager *relayClient.Manager
|
||||
srWatcher *guard.SRWatcher
|
||||
Log *log.Entry
|
||||
mu sync.Mutex
|
||||
ctx context.Context
|
||||
ctxCancel context.CancelFunc
|
||||
config ConnConfig
|
||||
statusRecorder *Status
|
||||
signaler *Signaler
|
||||
iFaceDiscover stdnet.ExternalIFaceDiscover
|
||||
relayManager *relayClient.Manager
|
||||
srWatcher *guard.SRWatcher
|
||||
portForwardManager *portforward.Manager
|
||||
|
||||
onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)
|
||||
onDisconnected func(remotePeer string)
|
||||
@@ -145,19 +148,20 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
|
||||
|
||||
dumpState := newStateDump(config.Key, connLog, services.StatusRecorder)
|
||||
var conn = &Conn{
|
||||
Log: connLog,
|
||||
config: config,
|
||||
statusRecorder: services.StatusRecorder,
|
||||
signaler: services.Signaler,
|
||||
iFaceDiscover: services.IFaceDiscover,
|
||||
relayManager: services.RelayManager,
|
||||
srWatcher: services.SrWatcher,
|
||||
statusRelay: worker.NewAtomicStatus(),
|
||||
statusICE: worker.NewAtomicStatus(),
|
||||
dumpState: dumpState,
|
||||
endpointUpdater: NewEndpointUpdater(connLog, config.WgConfig, isController(config)),
|
||||
wgWatcher: NewWGWatcher(connLog, config.WgConfig.WgInterface, config.Key, dumpState),
|
||||
metricsRecorder: services.MetricsRecorder,
|
||||
Log: connLog,
|
||||
config: config,
|
||||
statusRecorder: services.StatusRecorder,
|
||||
signaler: services.Signaler,
|
||||
iFaceDiscover: services.IFaceDiscover,
|
||||
relayManager: services.RelayManager,
|
||||
srWatcher: services.SrWatcher,
|
||||
portForwardManager: services.PortForwardManager,
|
||||
statusRelay: worker.NewAtomicStatus(),
|
||||
statusICE: worker.NewAtomicStatus(),
|
||||
dumpState: dumpState,
|
||||
endpointUpdater: NewEndpointUpdater(connLog, config.WgConfig, isController(config)),
|
||||
wgWatcher: NewWGWatcher(connLog, config.WgConfig.WgInterface, config.Key, dumpState),
|
||||
metricsRecorder: services.MetricsRecorder,
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||
"github.com/netbirdio/netbird/client/internal/peer/conntype"
|
||||
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||
"github.com/netbirdio/netbird/client/internal/portforward"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
@@ -61,6 +62,9 @@ type WorkerICE struct {
|
||||
|
||||
// we record the last known state of the ICE agent to avoid duplicate on disconnected events
|
||||
lastKnownState ice.ConnectionState
|
||||
|
||||
// portForwardAttempted tracks if we've already tried port forwarding this session
|
||||
portForwardAttempted bool
|
||||
}
|
||||
|
||||
func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, conn *Conn, signaler *Signaler, ifaceDiscover stdnet.ExternalIFaceDiscover, statusRecorder *Status, hasRelayOnLocally bool) (*WorkerICE, error) {
|
||||
@@ -214,6 +218,8 @@ func (w *WorkerICE) Close() {
|
||||
}
|
||||
|
||||
func (w *WorkerICE) reCreateAgent(dialerCancel context.CancelFunc, candidates []ice.CandidateType) (*icemaker.ThreadSafeAgent, error) {
|
||||
w.portForwardAttempted = false
|
||||
|
||||
agent, err := icemaker.NewAgent(w.ctx, w.iFaceDiscover, w.config.ICEConfig, candidates, w.localUfrag, w.localPwd)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create agent: %w", err)
|
||||
@@ -370,6 +376,93 @@ func (w *WorkerICE) onICECandidate(candidate ice.Candidate) {
|
||||
w.log.Errorf("failed signaling candidate to the remote peer %s %s", w.config.Key, err)
|
||||
}
|
||||
}()
|
||||
|
||||
if candidate.Type() == ice.CandidateTypeServerReflexive {
|
||||
w.injectPortForwardedCandidate(candidate)
|
||||
}
|
||||
}
|
||||
|
||||
// injectPortForwardedCandidate signals an additional candidate using the pre-created port mapping.
|
||||
func (w *WorkerICE) injectPortForwardedCandidate(srflxCandidate ice.Candidate) {
|
||||
pfManager := w.conn.portForwardManager
|
||||
if pfManager == nil {
|
||||
return
|
||||
}
|
||||
|
||||
mapping := pfManager.GetMapping()
|
||||
if mapping == nil {
|
||||
return
|
||||
}
|
||||
|
||||
w.muxAgent.Lock()
|
||||
if w.portForwardAttempted {
|
||||
w.muxAgent.Unlock()
|
||||
return
|
||||
}
|
||||
w.portForwardAttempted = true
|
||||
w.muxAgent.Unlock()
|
||||
|
||||
forwardedCandidate, err := w.createForwardedCandidate(srflxCandidate, mapping)
|
||||
if err != nil {
|
||||
w.log.Warnf("create forwarded candidate: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
w.log.Debugf("injecting port-forwarded candidate: %s (mapping: %d -> %d via %s, priority: %d)",
|
||||
forwardedCandidate.String(), mapping.InternalPort, mapping.ExternalPort, mapping.NATType, forwardedCandidate.Priority())
|
||||
|
||||
go func() {
|
||||
if err := w.signaler.SignalICECandidate(forwardedCandidate, w.config.Key); err != nil {
|
||||
w.log.Errorf("signal port-forwarded candidate: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// createForwardedCandidate creates a new server reflexive candidate with the forwarded port.
|
||||
// It uses the NAT gateway's external IP with the forwarded port.
|
||||
func (w *WorkerICE) createForwardedCandidate(srflxCandidate ice.Candidate, mapping *portforward.Mapping) (ice.Candidate, error) {
|
||||
var externalIP string
|
||||
if mapping.ExternalIP != nil && !mapping.ExternalIP.IsUnspecified() {
|
||||
externalIP = mapping.ExternalIP.String()
|
||||
} else {
|
||||
// Fallback to STUN-discovered address if NAT didn't provide external IP
|
||||
externalIP = srflxCandidate.Address()
|
||||
}
|
||||
|
||||
// Per RFC 8445, the related address for srflx is the base (host candidate address).
|
||||
// If the original srflx has unspecified related address, use its own address as base.
|
||||
relAddr := srflxCandidate.RelatedAddress().Address
|
||||
if relAddr == "" || relAddr == "0.0.0.0" || relAddr == "::" {
|
||||
relAddr = srflxCandidate.Address()
|
||||
}
|
||||
|
||||
// Arbitrary +1000 boost on top of RFC 8445 priority to favor port-forwarded candidates
|
||||
// over regular srflx during ICE connectivity checks.
|
||||
priority := srflxCandidate.Priority() + 1000
|
||||
|
||||
candidate, err := ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{
|
||||
Network: srflxCandidate.NetworkType().String(),
|
||||
Address: externalIP,
|
||||
Port: int(mapping.ExternalPort),
|
||||
Component: srflxCandidate.Component(),
|
||||
Priority: priority,
|
||||
RelAddr: relAddr,
|
||||
RelPort: int(mapping.InternalPort),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create candidate: %w", err)
|
||||
}
|
||||
|
||||
for _, e := range srflxCandidate.Extensions() {
|
||||
if e.Key == ice.ExtensionKeyCandidateID {
|
||||
e.Value = srflxCandidate.ID()
|
||||
}
|
||||
if err := candidate.AddExtension(e); err != nil {
|
||||
return nil, fmt.Errorf("add extension: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return candidate, nil
|
||||
}
|
||||
|
||||
func (w *WorkerICE) onICESelectedCandidatePair(agent *icemaker.ThreadSafeAgent, c1, c2 ice.Candidate) {
|
||||
@@ -411,10 +504,10 @@ func (w *WorkerICE) logSuccessfulPaths(agent *icemaker.ThreadSafeAgent) {
|
||||
if !lok || !rok {
|
||||
continue
|
||||
}
|
||||
w.log.Debugf("successful ICE path %s: [%s %s %s] <-> [%s %s %s] rtt=%.3fms",
|
||||
w.log.Debugf("successful ICE path %s: [%s %s %s:%d] <-> [%s %s %s:%d] rtt=%.3fms",
|
||||
sessionID,
|
||||
local.NetworkType(), local.Type(), local.Address(),
|
||||
remote.NetworkType(), remote.Type(), remote.Address(),
|
||||
local.NetworkType(), local.Type(), local.Address(), local.Port(),
|
||||
remote.NetworkType(), remote.Type(), remote.Address(), remote.Port(),
|
||||
stat.CurrentRoundTripTime*1000)
|
||||
}
|
||||
}
|
||||
|
||||
26
client/internal/portforward/env.go
Normal file
26
client/internal/portforward/env.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package portforward
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
envDisableNATMapper = "NB_DISABLE_NAT_MAPPER"
|
||||
)
|
||||
|
||||
func isDisabledByEnv() bool {
|
||||
val := os.Getenv(envDisableNATMapper)
|
||||
if val == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
disabled, err := strconv.ParseBool(val)
|
||||
if err != nil {
|
||||
log.Warnf("failed to parse %s: %v", envDisableNATMapper, err)
|
||||
return false
|
||||
}
|
||||
return disabled
|
||||
}
|
||||
280
client/internal/portforward/manager.go
Normal file
280
client/internal/portforward/manager.go
Normal file
@@ -0,0 +1,280 @@
|
||||
//go:build !js
|
||||
|
||||
package portforward
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"regexp"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/libp2p/go-nat"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultMappingTTL = 2 * time.Hour
|
||||
discoveryTimeout = 10 * time.Second
|
||||
mappingDescription = "NetBird"
|
||||
)
|
||||
|
||||
// upnpErrPermanentLeaseOnly matches UPnP error 725 in SOAP fault XML,
|
||||
// allowing for whitespace/newlines between tags from different router firmware.
|
||||
var upnpErrPermanentLeaseOnly = regexp.MustCompile(`<errorCode>\s*725\s*</errorCode>`)
|
||||
|
||||
// Mapping represents an active NAT port mapping.
|
||||
type Mapping struct {
|
||||
Protocol string
|
||||
InternalPort uint16
|
||||
ExternalPort uint16
|
||||
ExternalIP net.IP
|
||||
NATType string
|
||||
// TTL is the lease duration. Zero means a permanent lease that never expires.
|
||||
TTL time.Duration
|
||||
}
|
||||
|
||||
// TODO: persist mapping state for crash recovery cleanup of permanent leases.
|
||||
// Currently not done because State.Cleanup requires NAT gateway re-discovery,
|
||||
// which blocks startup for ~10s when no gateway is present (affects all clients).
|
||||
|
||||
type Manager struct {
|
||||
cancel context.CancelFunc
|
||||
|
||||
mapping *Mapping
|
||||
mappingLock sync.Mutex
|
||||
|
||||
wgPort uint16
|
||||
|
||||
done chan struct{}
|
||||
stopCtx chan context.Context
|
||||
|
||||
// protect exported functions
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// NewManager creates a new port forwarding manager.
|
||||
func NewManager() *Manager {
|
||||
return &Manager{
|
||||
stopCtx: make(chan context.Context, 1),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) Start(ctx context.Context, wgPort uint16) {
|
||||
m.mu.Lock()
|
||||
if m.cancel != nil {
|
||||
m.mu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
if isDisabledByEnv() {
|
||||
log.Infof("NAT port mapper disabled via %s", envDisableNATMapper)
|
||||
m.mu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
if wgPort == 0 {
|
||||
log.Warnf("invalid WireGuard port 0; NAT mapping disabled")
|
||||
m.mu.Unlock()
|
||||
return
|
||||
}
|
||||
m.wgPort = wgPort
|
||||
|
||||
m.done = make(chan struct{})
|
||||
defer close(m.done)
|
||||
|
||||
ctx, m.cancel = context.WithCancel(ctx)
|
||||
m.mu.Unlock()
|
||||
|
||||
gateway, mapping, err := m.setup(ctx)
|
||||
if err != nil {
|
||||
log.Infof("port forwarding setup: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
m.mappingLock.Lock()
|
||||
m.mapping = mapping
|
||||
m.mappingLock.Unlock()
|
||||
|
||||
m.renewLoop(ctx, gateway, mapping.TTL)
|
||||
|
||||
select {
|
||||
case cleanupCtx := <-m.stopCtx:
|
||||
// block the Start while cleaned up gracefully
|
||||
m.cleanup(cleanupCtx, gateway)
|
||||
default:
|
||||
// return Start immediately and cleanup in background
|
||||
cleanupCtx, cleanupCancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
go func() {
|
||||
defer cleanupCancel()
|
||||
m.cleanup(cleanupCtx, gateway)
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
// GetMapping returns the current mapping if ready, nil otherwise
|
||||
func (m *Manager) GetMapping() *Mapping {
|
||||
m.mappingLock.Lock()
|
||||
defer m.mappingLock.Unlock()
|
||||
|
||||
if m.mapping == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
mapping := *m.mapping
|
||||
return &mapping
|
||||
}
|
||||
|
||||
// GracefullyStop cancels the manager and attempts to delete the port mapping.
|
||||
// After GracefullyStop returns, the manager cannot be restarted.
|
||||
func (m *Manager) GracefullyStop(ctx context.Context) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if m.cancel == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Send cleanup context before cancelling, so Start picks it up after renewLoop exits.
|
||||
m.startTearDown(ctx)
|
||||
|
||||
m.cancel()
|
||||
m.cancel = nil
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-m.done:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) setup(ctx context.Context) (nat.NAT, *Mapping, error) {
|
||||
discoverCtx, discoverCancel := context.WithTimeout(ctx, discoveryTimeout)
|
||||
defer discoverCancel()
|
||||
|
||||
gateway, err := nat.DiscoverGateway(discoverCtx)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("discover gateway: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("discovered NAT gateway: %s", gateway.Type())
|
||||
|
||||
mapping, err := m.createMapping(ctx, gateway)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("create port mapping: %w", err)
|
||||
}
|
||||
return gateway, mapping, nil
|
||||
}
|
||||
|
||||
func (m *Manager) createMapping(ctx context.Context, gateway nat.NAT) (*Mapping, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
ttl := defaultMappingTTL
|
||||
externalPort, err := gateway.AddPortMapping(ctx, "udp", int(m.wgPort), mappingDescription, ttl)
|
||||
if err != nil {
|
||||
if !isPermanentLeaseRequired(err) {
|
||||
return nil, err
|
||||
}
|
||||
log.Infof("gateway only supports permanent leases, retrying with indefinite duration")
|
||||
ttl = 0
|
||||
externalPort, err = gateway.AddPortMapping(ctx, "udp", int(m.wgPort), mappingDescription, ttl)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
externalIP, err := gateway.GetExternalAddress()
|
||||
if err != nil {
|
||||
log.Debugf("failed to get external address: %v", err)
|
||||
// todo return with err?
|
||||
}
|
||||
|
||||
mapping := &Mapping{
|
||||
Protocol: "udp",
|
||||
InternalPort: m.wgPort,
|
||||
ExternalPort: uint16(externalPort),
|
||||
ExternalIP: externalIP,
|
||||
NATType: gateway.Type(),
|
||||
TTL: ttl,
|
||||
}
|
||||
|
||||
log.Infof("created port mapping: %d -> %d via %s (external IP: %s)",
|
||||
m.wgPort, externalPort, gateway.Type(), externalIP)
|
||||
return mapping, nil
|
||||
}
|
||||
|
||||
func (m *Manager) renewLoop(ctx context.Context, gateway nat.NAT, ttl time.Duration) {
|
||||
if ttl == 0 {
|
||||
// Permanent mappings don't expire, just wait for cancellation.
|
||||
<-ctx.Done()
|
||||
return
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(ttl / 2)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := m.renewMapping(ctx, gateway); err != nil {
|
||||
log.Warnf("failed to renew port mapping: %v", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) renewMapping(ctx context.Context, gateway nat.NAT) error {
|
||||
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
externalPort, err := gateway.AddPortMapping(ctx, m.mapping.Protocol, int(m.mapping.InternalPort), mappingDescription, m.mapping.TTL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("add port mapping: %w", err)
|
||||
}
|
||||
|
||||
if uint16(externalPort) != m.mapping.ExternalPort {
|
||||
log.Warnf("external port changed on renewal: %d -> %d (candidate may be stale)", m.mapping.ExternalPort, externalPort)
|
||||
m.mappingLock.Lock()
|
||||
m.mapping.ExternalPort = uint16(externalPort)
|
||||
m.mappingLock.Unlock()
|
||||
}
|
||||
|
||||
log.Debugf("renewed port mapping: %d -> %d", m.mapping.InternalPort, m.mapping.ExternalPort)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) cleanup(ctx context.Context, gateway nat.NAT) {
|
||||
m.mappingLock.Lock()
|
||||
mapping := m.mapping
|
||||
m.mapping = nil
|
||||
m.mappingLock.Unlock()
|
||||
|
||||
if mapping == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err := gateway.DeletePortMapping(ctx, mapping.Protocol, int(mapping.InternalPort)); err != nil {
|
||||
log.Warnf("delete port mapping on stop: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Infof("deleted port mapping for port %d", mapping.InternalPort)
|
||||
}
|
||||
|
||||
func (m *Manager) startTearDown(ctx context.Context) {
|
||||
select {
|
||||
case m.stopCtx <- ctx:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// isPermanentLeaseRequired checks if a UPnP error indicates the gateway only supports permanent leases (error 725).
|
||||
func isPermanentLeaseRequired(err error) bool {
|
||||
return err != nil && upnpErrPermanentLeaseOnly.MatchString(err.Error())
|
||||
}
|
||||
39
client/internal/portforward/manager_js.go
Normal file
39
client/internal/portforward/manager_js.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package portforward
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Mapping represents an active NAT port mapping.
|
||||
type Mapping struct {
|
||||
Protocol string
|
||||
InternalPort uint16
|
||||
ExternalPort uint16
|
||||
ExternalIP net.IP
|
||||
NATType string
|
||||
// TTL is the lease duration. Zero means a permanent lease that never expires.
|
||||
TTL time.Duration
|
||||
}
|
||||
|
||||
// Manager is a stub for js/wasm builds where NAT-PMP/UPnP is not supported.
|
||||
type Manager struct{}
|
||||
|
||||
// NewManager returns a stub manager for js/wasm builds.
|
||||
func NewManager() *Manager {
|
||||
return &Manager{}
|
||||
}
|
||||
|
||||
// Start is a no-op on js/wasm: NAT-PMP/UPnP is not available in browser environments.
|
||||
func (m *Manager) Start(context.Context, uint16) {
|
||||
// no NAT traversal in wasm
|
||||
}
|
||||
|
||||
// GracefullyStop is a no-op on js/wasm.
|
||||
func (m *Manager) GracefullyStop(context.Context) error { return nil }
|
||||
|
||||
// GetMapping always returns nil on js/wasm.
|
||||
func (m *Manager) GetMapping() *Mapping {
|
||||
return nil
|
||||
}
|
||||
201
client/internal/portforward/manager_test.go
Normal file
201
client/internal/portforward/manager_test.go
Normal file
@@ -0,0 +1,201 @@
|
||||
//go:build !js
|
||||
|
||||
package portforward
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type mockNAT struct {
|
||||
natType string
|
||||
deviceAddr net.IP
|
||||
externalAddr net.IP
|
||||
internalAddr net.IP
|
||||
mappings map[int]int
|
||||
addMappingErr error
|
||||
deleteMappingErr error
|
||||
onlyPermanentLeases bool
|
||||
lastTimeout time.Duration
|
||||
}
|
||||
|
||||
func newMockNAT() *mockNAT {
|
||||
return &mockNAT{
|
||||
natType: "Mock-NAT",
|
||||
deviceAddr: net.ParseIP("192.168.1.1"),
|
||||
externalAddr: net.ParseIP("203.0.113.50"),
|
||||
internalAddr: net.ParseIP("192.168.1.100"),
|
||||
mappings: make(map[int]int),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockNAT) Type() string {
|
||||
return m.natType
|
||||
}
|
||||
|
||||
func (m *mockNAT) GetDeviceAddress() (net.IP, error) {
|
||||
return m.deviceAddr, nil
|
||||
}
|
||||
|
||||
func (m *mockNAT) GetExternalAddress() (net.IP, error) {
|
||||
return m.externalAddr, nil
|
||||
}
|
||||
|
||||
func (m *mockNAT) GetInternalAddress() (net.IP, error) {
|
||||
return m.internalAddr, nil
|
||||
}
|
||||
|
||||
func (m *mockNAT) AddPortMapping(ctx context.Context, protocol string, internalPort int, description string, timeout time.Duration) (int, error) {
|
||||
if m.addMappingErr != nil {
|
||||
return 0, m.addMappingErr
|
||||
}
|
||||
if m.onlyPermanentLeases && timeout != 0 {
|
||||
return 0, fmt.Errorf("SOAP fault. Code: | Explanation: | Detail: <UPnPError xmlns=\"urn:schemas-upnp-org:control-1-0\"><errorCode>725</errorCode><errorDescription>OnlyPermanentLeasesSupported</errorDescription></UPnPError>")
|
||||
}
|
||||
externalPort := internalPort
|
||||
m.mappings[internalPort] = externalPort
|
||||
m.lastTimeout = timeout
|
||||
return externalPort, nil
|
||||
}
|
||||
|
||||
func (m *mockNAT) DeletePortMapping(ctx context.Context, protocol string, internalPort int) error {
|
||||
if m.deleteMappingErr != nil {
|
||||
return m.deleteMappingErr
|
||||
}
|
||||
delete(m.mappings, internalPort)
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestManager_CreateMapping(t *testing.T) {
|
||||
m := NewManager()
|
||||
m.wgPort = 51820
|
||||
|
||||
gateway := newMockNAT()
|
||||
mapping, err := m.createMapping(context.Background(), gateway)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, mapping)
|
||||
|
||||
assert.Equal(t, "udp", mapping.Protocol)
|
||||
assert.Equal(t, uint16(51820), mapping.InternalPort)
|
||||
assert.Equal(t, uint16(51820), mapping.ExternalPort)
|
||||
assert.Equal(t, "Mock-NAT", mapping.NATType)
|
||||
assert.Equal(t, net.ParseIP("203.0.113.50").To4(), mapping.ExternalIP.To4())
|
||||
assert.Equal(t, defaultMappingTTL, mapping.TTL)
|
||||
}
|
||||
|
||||
func TestManager_GetMapping_ReturnsNilWhenNotReady(t *testing.T) {
|
||||
m := NewManager()
|
||||
assert.Nil(t, m.GetMapping())
|
||||
}
|
||||
|
||||
func TestManager_GetMapping_ReturnsCopy(t *testing.T) {
|
||||
m := NewManager()
|
||||
m.mapping = &Mapping{
|
||||
Protocol: "udp",
|
||||
InternalPort: 51820,
|
||||
ExternalPort: 51820,
|
||||
}
|
||||
|
||||
mapping := m.GetMapping()
|
||||
require.NotNil(t, mapping)
|
||||
assert.Equal(t, uint16(51820), mapping.InternalPort)
|
||||
|
||||
// Mutating the returned copy should not affect the manager's mapping.
|
||||
mapping.ExternalPort = 9999
|
||||
assert.Equal(t, uint16(51820), m.GetMapping().ExternalPort)
|
||||
}
|
||||
|
||||
func TestManager_Cleanup_DeletesMapping(t *testing.T) {
|
||||
m := NewManager()
|
||||
m.mapping = &Mapping{
|
||||
Protocol: "udp",
|
||||
InternalPort: 51820,
|
||||
ExternalPort: 51820,
|
||||
}
|
||||
|
||||
gateway := newMockNAT()
|
||||
// Seed the mock so we can verify deletion.
|
||||
gateway.mappings[51820] = 51820
|
||||
|
||||
m.cleanup(context.Background(), gateway)
|
||||
|
||||
_, exists := gateway.mappings[51820]
|
||||
assert.False(t, exists, "mapping should be deleted from gateway")
|
||||
assert.Nil(t, m.GetMapping(), "in-memory mapping should be cleared")
|
||||
}
|
||||
|
||||
func TestManager_Cleanup_NilMapping(t *testing.T) {
|
||||
m := NewManager()
|
||||
gateway := newMockNAT()
|
||||
|
||||
// Should not panic or call gateway.
|
||||
m.cleanup(context.Background(), gateway)
|
||||
}
|
||||
|
||||
|
||||
func TestManager_CreateMapping_PermanentLeaseFallback(t *testing.T) {
|
||||
m := NewManager()
|
||||
m.wgPort = 51820
|
||||
|
||||
gateway := newMockNAT()
|
||||
gateway.onlyPermanentLeases = true
|
||||
|
||||
mapping, err := m.createMapping(context.Background(), gateway)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, mapping)
|
||||
|
||||
assert.Equal(t, uint16(51820), mapping.InternalPort)
|
||||
assert.Equal(t, time.Duration(0), mapping.TTL, "should return zero TTL for permanent lease")
|
||||
assert.Equal(t, time.Duration(0), gateway.lastTimeout, "should have retried with zero duration")
|
||||
}
|
||||
|
||||
func TestIsPermanentLeaseRequired(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "nil error",
|
||||
err: nil,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "UPnP error 725",
|
||||
err: fmt.Errorf("SOAP fault. Code: | Detail: <UPnPError><errorCode>725</errorCode><errorDescription>OnlyPermanentLeasesSupported</errorDescription></UPnPError>"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "wrapped error with 725",
|
||||
err: fmt.Errorf("add port mapping: %w", fmt.Errorf("Detail: <errorCode>725</errorCode>")),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "error 725 with newlines in XML",
|
||||
err: fmt.Errorf("<errorCode>\n 725\n</errorCode>"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "bare 725 without XML tag",
|
||||
err: fmt.Errorf("error code 725"),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "unrelated error",
|
||||
err: fmt.Errorf("connection refused"),
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equal(t, tt.expected, isPermanentLeaseRequired(tt.err))
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -41,7 +41,7 @@ const (
|
||||
|
||||
// mgmProber is the subset of management client needed for URL migration probes.
|
||||
type mgmProber interface {
|
||||
GetServerPublicKey() (*wgtypes.Key, error)
|
||||
HealthCheck() error
|
||||
Close() error
|
||||
}
|
||||
|
||||
@@ -97,6 +97,9 @@ type ConfigInput struct {
|
||||
LazyConnectionEnabled *bool
|
||||
|
||||
MTU *uint16
|
||||
|
||||
InspectionCACertPath string
|
||||
InspectionCAKeyPath string
|
||||
}
|
||||
|
||||
// Config Configuration type
|
||||
@@ -171,6 +174,13 @@ type Config struct {
|
||||
LazyConnectionEnabled bool
|
||||
|
||||
MTU uint16
|
||||
|
||||
// InspectionCACertPath is the path to a PEM CA certificate for transparent proxy MITM.
|
||||
// Local CA takes priority over management-pushed CA.
|
||||
InspectionCACertPath string
|
||||
|
||||
// InspectionCAKeyPath is the path to the PEM CA private key for transparent proxy MITM.
|
||||
InspectionCAKeyPath string
|
||||
}
|
||||
|
||||
var ConfigDirOverride string
|
||||
@@ -603,6 +613,17 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.InspectionCACertPath != "" && input.InspectionCACertPath != config.InspectionCACertPath {
|
||||
log.Infof("updating inspection CA cert path to %s", input.InspectionCACertPath)
|
||||
config.InspectionCACertPath = input.InspectionCACertPath
|
||||
updated = true
|
||||
}
|
||||
if input.InspectionCAKeyPath != "" && input.InspectionCAKeyPath != config.InspectionCAKeyPath {
|
||||
log.Infof("updating inspection CA key path to %s", input.InspectionCAKeyPath)
|
||||
config.InspectionCAKeyPath = input.InspectionCAKeyPath
|
||||
updated = true
|
||||
}
|
||||
|
||||
return updated, nil
|
||||
}
|
||||
|
||||
@@ -777,8 +798,7 @@ func UpdateOldManagementURL(ctx context.Context, config *Config, configPath stri
|
||||
}()
|
||||
|
||||
// gRPC check
|
||||
_, err = client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
if err = client.HealthCheck(); err != nil {
|
||||
log.Infof("couldn't switch to the new Management %s", newURL.String())
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -17,12 +17,10 @@ import (
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
type mockMgmProber struct {
|
||||
key wgtypes.Key
|
||||
}
|
||||
type mockMgmProber struct{}
|
||||
|
||||
func (m *mockMgmProber) GetServerPublicKey() (*wgtypes.Key, error) {
|
||||
return &m.key, nil
|
||||
func (m *mockMgmProber) HealthCheck() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockMgmProber) Close() error { return nil }
|
||||
@@ -247,11 +245,7 @@ func TestWireguardPortDefaultVsExplicit(t *testing.T) {
|
||||
func TestUpdateOldManagementURL(t *testing.T) {
|
||||
origProber := newMgmProber
|
||||
newMgmProber = func(_ context.Context, _ string, _ wgtypes.Key, _ bool) (mgmProber, error) {
|
||||
key, err := wgtypes.GenerateKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &mockMgmProber{key: key.PublicKey()}, nil
|
||||
return &mockMgmProber{}, nil
|
||||
}
|
||||
t.Cleanup(func() { newMgmProber = origProber })
|
||||
|
||||
|
||||
@@ -52,6 +52,7 @@ type Manager interface {
|
||||
TriggerSelection(route.HAMap)
|
||||
GetRouteSelector() *routeselector.RouteSelector
|
||||
GetClientRoutes() route.HAMap
|
||||
GetSelectedClientRoutes() route.HAMap
|
||||
GetClientRoutesWithNetID() map[route.NetID][]*route.Route
|
||||
SetRouteChangeListener(listener listener.NetworkChangeListener)
|
||||
InitialRouteRange() []string
|
||||
@@ -465,6 +466,16 @@ func (m *DefaultManager) GetClientRoutes() route.HAMap {
|
||||
return maps.Clone(m.clientRoutes)
|
||||
}
|
||||
|
||||
// GetSelectedClientRoutes returns only the currently selected/active client routes,
|
||||
// filtering out deselected exit nodes. Use this instead of GetClientRoutes when checking
|
||||
// if traffic should be routed through the tunnel.
|
||||
func (m *DefaultManager) GetSelectedClientRoutes() route.HAMap {
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
|
||||
return m.routeSelector.FilterSelectedExitNodes(maps.Clone(m.clientRoutes))
|
||||
}
|
||||
|
||||
// GetClientRoutesWithNetID returns the current routes from the route map, but the keys consist of the network ID only
|
||||
func (m *DefaultManager) GetClientRoutesWithNetID() map[route.NetID][]*route.Route {
|
||||
m.mux.Lock()
|
||||
|
||||
@@ -18,6 +18,7 @@ type MockManager struct {
|
||||
TriggerSelectionFunc func(haMap route.HAMap)
|
||||
GetRouteSelectorFunc func() *routeselector.RouteSelector
|
||||
GetClientRoutesFunc func() route.HAMap
|
||||
GetSelectedClientRoutesFunc func() route.HAMap
|
||||
GetClientRoutesWithNetIDFunc func() map[route.NetID][]*route.Route
|
||||
StopFunc func(manager *statemanager.Manager)
|
||||
}
|
||||
@@ -61,7 +62,7 @@ func (m *MockManager) GetRouteSelector() *routeselector.RouteSelector {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetClientRoutes mock implementation of GetClientRoutes from Manager interface
|
||||
// GetClientRoutes mock implementation of GetClientRoutes from the Manager interface
|
||||
func (m *MockManager) GetClientRoutes() route.HAMap {
|
||||
if m.GetClientRoutesFunc != nil {
|
||||
return m.GetClientRoutesFunc()
|
||||
@@ -69,6 +70,14 @@ func (m *MockManager) GetClientRoutes() route.HAMap {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetSelectedClientRoutes mock implementation of GetSelectedClientRoutes from the Manager interface
|
||||
func (m *MockManager) GetSelectedClientRoutes() route.HAMap {
|
||||
if m.GetSelectedClientRoutesFunc != nil {
|
||||
return m.GetSelectedClientRoutesFunc()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetClientRoutesWithNetID mock implementation of GetClientRoutesWithNetID from Manager interface
|
||||
func (m *MockManager) GetClientRoutesWithNetID() map[route.NetID][]*route.Route {
|
||||
if m.GetClientRoutesWithNetIDFunc != nil {
|
||||
|
||||
@@ -53,7 +53,6 @@ func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) {
|
||||
n.currentPrefixes = newNets
|
||||
n.notify()
|
||||
}
|
||||
|
||||
func (n *Notifier) notify() {
|
||||
n.listenerMux.Lock()
|
||||
defer n.listenerMux.Unlock()
|
||||
|
||||
@@ -161,7 +161,11 @@ func (c *Client) Run(fd int32, interfaceName string, envList *EnvList) error {
|
||||
cfg.WgIface = interfaceName
|
||||
|
||||
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||
return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, c.stateFile)
|
||||
hostDNS := []netip.AddrPort{
|
||||
netip.MustParseAddrPort("9.9.9.9:53"),
|
||||
netip.MustParseAddrPort("149.112.112.112:53"),
|
||||
}
|
||||
return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, hostDNS, c.stateFile)
|
||||
}
|
||||
|
||||
// Stop the internal client and free the resources
|
||||
|
||||
@@ -1359,6 +1359,10 @@ func (s *Server) ExposeService(req *proto.ExposeServiceRequest, srv proto.Daemon
|
||||
return gstatus.Errorf(codes.FailedPrecondition, "engine not initialized")
|
||||
}
|
||||
|
||||
if engine.IsBlockInbound() {
|
||||
return gstatus.Errorf(codes.FailedPrecondition, "expose requires inbound connections but 'block inbound' is enabled, disable it first")
|
||||
}
|
||||
|
||||
mgr := engine.GetExposeManager()
|
||||
if mgr == nil {
|
||||
return gstatus.Errorf(codes.Internal, "expose manager not available")
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/ssh/config"
|
||||
)
|
||||
|
||||
// registerStates registers all states that need crash recovery cleanup.
|
||||
func registerStates(mgr *statemanager.Manager) {
|
||||
mgr.RegisterState(&dns.ShutdownState{})
|
||||
mgr.RegisterState(&systemops.ShutdownState{})
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/ssh/config"
|
||||
)
|
||||
|
||||
// registerStates registers all states that need crash recovery cleanup.
|
||||
func registerStates(mgr *statemanager.Manager) {
|
||||
mgr.RegisterState(&dns.ShutdownState{})
|
||||
mgr.RegisterState(&systemops.ShutdownState{})
|
||||
|
||||
@@ -141,7 +141,7 @@ func (p *SSHProxy) runProxySSHServer(jwtToken string) error {
|
||||
|
||||
func (p *SSHProxy) handleSSHSession(session ssh.Session) {
|
||||
ptyReq, winCh, isPty := session.Pty()
|
||||
hasCommand := len(session.Command()) > 0
|
||||
hasCommand := session.RawCommand() != ""
|
||||
|
||||
sshClient, err := p.getOrCreateBackendClient(session.Context(), session.User())
|
||||
if err != nil {
|
||||
@@ -180,7 +180,7 @@ func (p *SSHProxy) handleSSHSession(session ssh.Session) {
|
||||
}
|
||||
|
||||
if hasCommand {
|
||||
if err := serverSession.Run(strings.Join(session.Command(), " ")); err != nil {
|
||||
if err := serverSession.Run(session.RawCommand()); err != nil {
|
||||
log.Debugf("run command: %v", err)
|
||||
p.handleProxyExitCode(session, err)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
@@ -245,6 +246,191 @@ func TestSSHProxy_Connect(t *testing.T) {
|
||||
cancel()
|
||||
}
|
||||
|
||||
// TestSSHProxy_CommandQuoting verifies that the proxy preserves shell quoting
|
||||
// when forwarding commands to the backend. This is critical for tools like
|
||||
// Ansible that send commands such as:
|
||||
//
|
||||
// /bin/sh -c '( umask 77 && mkdir -p ... ) && sleep 0'
|
||||
//
|
||||
// The single quotes must be preserved so the backend shell receives the
|
||||
// subshell expression as a single argument to -c.
|
||||
func TestSSHProxy_CommandQuoting(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping integration test in short mode")
|
||||
}
|
||||
|
||||
sshClient, cleanup := setupProxySSHClient(t)
|
||||
defer cleanup()
|
||||
|
||||
// These commands simulate what the SSH protocol delivers as exec payloads.
|
||||
// When a user types: ssh host '/bin/sh -c "( echo hello )"'
|
||||
// the local shell strips the outer single quotes, and the SSH exec request
|
||||
// contains the raw string: /bin/sh -c "( echo hello )"
|
||||
//
|
||||
// The proxy must forward this string verbatim. Using session.Command()
|
||||
// (shlex.Split + strings.Join) strips the inner double quotes, breaking
|
||||
// the command on the backend.
|
||||
tests := []struct {
|
||||
name string
|
||||
command string
|
||||
expect string
|
||||
}{
|
||||
{
|
||||
name: "subshell_in_double_quotes",
|
||||
command: `/bin/sh -c "( echo from-subshell ) && echo outer"`,
|
||||
expect: "from-subshell\nouter\n",
|
||||
},
|
||||
{
|
||||
name: "printf_with_special_chars",
|
||||
command: `/bin/sh -c "printf '%s\n' 'hello world'"`,
|
||||
expect: "hello world\n",
|
||||
},
|
||||
{
|
||||
name: "nested_command_substitution",
|
||||
command: `/bin/sh -c "echo $(echo nested)"`,
|
||||
expect: "nested\n",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
session, err := sshClient.NewSession()
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = session.Close() }()
|
||||
|
||||
var stderrBuf bytes.Buffer
|
||||
session.Stderr = &stderrBuf
|
||||
|
||||
outputCh := make(chan []byte, 1)
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
output, err := session.Output(tc.command)
|
||||
outputCh <- output
|
||||
errCh <- err
|
||||
}()
|
||||
|
||||
select {
|
||||
case output := <-outputCh:
|
||||
err := <-errCh
|
||||
if stderrBuf.Len() > 0 {
|
||||
t.Logf("stderr: %s", stderrBuf.String())
|
||||
}
|
||||
require.NoError(t, err, "command should succeed: %s", tc.command)
|
||||
assert.Equal(t, tc.expect, string(output), "output mismatch for: %s", tc.command)
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatalf("command timed out: %s", tc.command)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// setupProxySSHClient creates a full proxy test environment and returns
|
||||
// an SSH client connected through the proxy to a backend NetBird SSH server.
|
||||
func setupProxySSHClient(t *testing.T) (*cryptossh.Client, func()) {
|
||||
t.Helper()
|
||||
|
||||
const (
|
||||
issuer = "https://test-issuer.example.com"
|
||||
audience = "test-audience"
|
||||
)
|
||||
|
||||
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
|
||||
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
hostPubKey, err := nbssh.GeneratePublicKey(hostKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
serverConfig := &server.Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: &server.JWTConfig{
|
||||
Issuer: issuer,
|
||||
Audiences: []string{audience},
|
||||
KeysLocation: jwksURL,
|
||||
},
|
||||
}
|
||||
sshServer := server.New(serverConfig)
|
||||
sshServer.SetAllowRootLogin(true)
|
||||
|
||||
testUsername := testutil.GetTestUsername(t)
|
||||
testJWTUser := "test-username"
|
||||
testUserHash, err := sshuserhash.HashUserID(testJWTUser)
|
||||
require.NoError(t, err)
|
||||
|
||||
authConfig := &sshauth.Config{
|
||||
UserIDClaim: sshauth.DefaultUserIDClaim,
|
||||
AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash},
|
||||
MachineUsers: map[string][]uint32{
|
||||
testUsername: {0},
|
||||
},
|
||||
}
|
||||
sshServer.UpdateSSHAuth(authConfig)
|
||||
|
||||
sshServerAddr := server.StartTestServer(t, sshServer)
|
||||
|
||||
mockDaemon := startMockDaemon(t)
|
||||
|
||||
host, portStr, err := net.SplitHostPort(sshServerAddr)
|
||||
require.NoError(t, err)
|
||||
port, err := strconv.Atoi(portStr)
|
||||
require.NoError(t, err)
|
||||
|
||||
mockDaemon.setHostKey(host, hostPubKey)
|
||||
|
||||
validToken := generateValidJWT(t, privateKey, issuer, audience, testJWTUser)
|
||||
mockDaemon.setJWTToken(validToken)
|
||||
|
||||
proxyInstance, err := New(mockDaemon.addr, host, port, io.Discard, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
origStdin := os.Stdin
|
||||
origStdout := os.Stdout
|
||||
|
||||
stdinReader, stdinWriter, err := os.Pipe()
|
||||
require.NoError(t, err)
|
||||
stdoutReader, stdoutWriter, err := os.Pipe()
|
||||
require.NoError(t, err)
|
||||
|
||||
os.Stdin = stdinReader
|
||||
os.Stdout = stdoutWriter
|
||||
|
||||
clientConn, proxyConn := net.Pipe()
|
||||
|
||||
go func() { _, _ = io.Copy(stdinWriter, proxyConn) }()
|
||||
go func() { _, _ = io.Copy(proxyConn, stdoutReader) }()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
|
||||
go func() {
|
||||
_ = proxyInstance.Connect(ctx)
|
||||
}()
|
||||
|
||||
sshConfig := &cryptossh.ClientConfig{
|
||||
User: testutil.GetTestUsername(t),
|
||||
Auth: []cryptossh.AuthMethod{},
|
||||
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
sshClientConn, chans, reqs, err := cryptossh.NewClientConn(clientConn, "test", sshConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
client := cryptossh.NewClient(sshClientConn, chans, reqs)
|
||||
|
||||
cleanupFn := func() {
|
||||
_ = client.Close()
|
||||
_ = clientConn.Close()
|
||||
cancel()
|
||||
os.Stdin = origStdin
|
||||
os.Stdout = origStdout
|
||||
_ = sshServer.Stop()
|
||||
mockDaemon.stop()
|
||||
jwksServer.Close()
|
||||
}
|
||||
|
||||
return client, cleanupFn
|
||||
}
|
||||
|
||||
type mockDaemonServer struct {
|
||||
proto.UnimplementedDaemonServiceServer
|
||||
hostKeys map[string][]byte
|
||||
|
||||
@@ -284,19 +284,21 @@ func (s *Server) closeListener(ln net.Listener) {
|
||||
// Stop closes the SSH server
|
||||
func (s *Server) Stop() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.sshServer == nil {
|
||||
sshServer := s.sshServer
|
||||
if sshServer == nil {
|
||||
s.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
s.sshServer = nil
|
||||
s.listener = nil
|
||||
s.mu.Unlock()
|
||||
|
||||
if err := s.sshServer.Close(); err != nil {
|
||||
// Close outside the lock: session handlers need s.mu for unregisterSession.
|
||||
if err := sshServer.Close(); err != nil {
|
||||
log.Debugf("close SSH server: %v", err)
|
||||
}
|
||||
|
||||
s.sshServer = nil
|
||||
s.listener = nil
|
||||
|
||||
s.mu.Lock()
|
||||
maps.Clear(s.sessions)
|
||||
maps.Clear(s.pendingAuthJWT)
|
||||
maps.Clear(s.connections)
|
||||
@@ -307,6 +309,7 @@ func (s *Server) Stop() error {
|
||||
}
|
||||
}
|
||||
maps.Clear(s.remoteForwardListeners)
|
||||
s.mu.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -60,7 +60,7 @@ func (s *Server) sessionHandler(session ssh.Session) {
|
||||
}
|
||||
|
||||
ptyReq, winCh, isPty := session.Pty()
|
||||
hasCommand := len(session.Command()) > 0
|
||||
hasCommand := session.RawCommand() != ""
|
||||
|
||||
if isPty && !hasCommand {
|
||||
// ssh <host> - PTY interactive session (login)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user