Compare commits

..

3 Commits

Author SHA1 Message Date
Zoltán Papp
85e4afe100 Refactor route manager to use sync.Once for wgInterface access
- Added `toInterface` helper to ensure `wgInterface.ToInterface()` is called once.
- Updated route addition/removal methods to use `toInterface` for efficiency.
2026-02-03 15:26:37 +01:00
Zoltán Papp
a72e0e67e5 Optimize Windows DNS performance with domain batching and batch mode
Implement two-layer optimization to reduce Windows NRPT registry operations:

1. Domain Batching (host_windows.go):
  - Batch up to 50 domains per NRPT rule (Windows undocumented limit)
  - Reduces NRPT rules by ~97% (e.g., 184 domains: 184 rules → 4 rules)
  - Modified addDNSMatchPolicy() to create batched NRPT entries
  - Added comprehensive tests in host_windows_test.go

2. Batch Mode (server.go):
  - Added BeginBatch/EndBatch methods to defer DNS updates
  - Modified RegisterHandler/DeregisterHandler to skip applyHostConfig in batch mode
  - Protected all applyHostConfig() calls with batch mode checks
  - Updated route manager to wrap route operations with batch calls
2026-02-03 00:15:04 +01:00
Zoltán Papp
a33f60e3fd Adds timing measurement to handleSync to help diagnose sync performance issues 2026-02-01 00:12:11 +01:00
281 changed files with 1384 additions and 35600 deletions

View File

@@ -1,6 +0,0 @@
.env
.env.*
*.pem
*.key
*.crt
*.p12

View File

@@ -23,7 +23,7 @@ jobs:
- name: Check for problematic license dependencies
run: |
echo "Checking for dependencies on management/, signal/, relay/, and proxy/ packages..."
echo "Checking for dependencies on management/, signal/, and relay/ packages..."
echo ""
# Find all directories except the problematic ones and system dirs
@@ -31,7 +31,7 @@ jobs:
while IFS= read -r dir; do
echo "=== Checking $dir ==="
# Search for problematic imports, excluding test files
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
if [ -n "$RESULTS" ]; then
echo "❌ Found problematic dependencies:"
echo "$RESULTS"
@@ -39,11 +39,11 @@ jobs:
else
echo "✓ No problematic dependencies found"
fi
done < <(find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name "proxy" -not -name ".git*" | sort)
done < <(find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name ".git*" | sort)
echo ""
if [ $FOUND_ISSUES -eq 1 ]; then
echo "❌ Found dependencies on management/, signal/, relay/, or proxy/ packages"
echo "❌ Found dependencies on management/, signal/, or relay/ packages"
echo "These packages are licensed under AGPLv3 and must not be imported by BSD-licensed code"
exit 1
else
@@ -88,7 +88,7 @@ jobs:
IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath")
# Check if any importer is NOT in management/signal/relay
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\)" | head -1)
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\)" | head -1)
if [ -n "$BSD_IMPORTER" ]; then
echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER"

View File

@@ -43,5 +43,5 @@ jobs:
run: git --no-pager diff --exit-code
- name: Test
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy)
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v /management)

View File

@@ -46,5 +46,6 @@ jobs:
time go test -timeout 1m -failfast ./client/iface/...
time go test -timeout 1m -failfast ./route/...
time go test -timeout 1m -failfast ./sharedsock/...
time go test -timeout 1m -failfast ./signal/...
time go test -timeout 1m -failfast ./util/...
time go test -timeout 1m -failfast ./version/...

View File

@@ -144,7 +144,7 @@ jobs:
run: git --no-pager diff --exit-code
- name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy)
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay)
test_client_on_docker:
name: "Client (Docker) / Unit"
@@ -204,7 +204,7 @@ jobs:
sh -c ' \
apk update; apk add --no-cache \
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /client/ui -e /upload-server)
go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /client/ui -e /upload-server)
'
test_relay:
@@ -261,53 +261,6 @@ jobs:
-exec 'sudo' \
-timeout 10m -p 1 ./relay/... ./shared/relay/...
test_proxy:
name: "Proxy / Unit"
needs: [build-cache]
strategy:
fail-fast: false
matrix:
arch: [ '386','amd64' ]
runs-on: ubuntu-22.04
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@v5
with:
go-version-file: "go.mod"
cache: false
- name: Install dependencies
run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Install modules
run: go mod tidy
- name: check git status
run: git --no-pager diff --exit-code
- name: Test
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
go test -timeout 10m -p 1 ./proxy/...
test_signal:
name: "Signal / Unit"
needs: [build-cache]

View File

@@ -63,7 +63,7 @@ jobs:
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=${{ env.cache }}
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=${{ env.modcache }}
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe mod tidy
- run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' } | Where-Object { $_ -notmatch '/proxy' })" >> $env:GITHUB_ENV
- run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' })" >> $env:GITHUB_ENV
- name: test
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -tags=devcert -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1"

View File

@@ -20,7 +20,7 @@ jobs:
uses: codespell-project/actions-codespell@v2
with:
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans
skip: go.mod,go.sum,**/proxy/web/**
skip: go.mod,go.sum
golangci:
strategy:
fail-fast: false

View File

@@ -9,7 +9,7 @@ on:
pull_request:
env:
SIGN_PIPE_VER: "v0.1.1"
SIGN_PIPE_VER: "v0.1.0"
GORELEASER_VER: "v2.3.2"
PRODUCT_NAME: "NetBird"
COPYRIGHT: "NetBird GmbH"

1
.gitignore vendored
View File

@@ -2,7 +2,6 @@
.run
*.iml
dist/
!proxy/web/dist/
bin/
.env
conf.json

View File

@@ -106,26 +106,6 @@ builds:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}"
- id: netbird-server
dir: combined
env:
- CGO_ENABLED=1
- >-
{{- if eq .Runtime.Goos "linux" }}
{{- if eq .Arch "arm64"}}CC=aarch64-linux-gnu-gcc{{- end }}
{{- if eq .Arch "arm"}}CC=arm-linux-gnueabihf-gcc{{- end }}
{{- end }}
binary: netbird-server
goos:
- linux
goarch:
- amd64
- arm64
- arm
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}"
- id: netbird-upload
dir: upload-server
env: [CGO_ENABLED=0]
@@ -540,55 +520,6 @@ dockers:
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/netbird-server:{{ .Version }}-amd64
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-amd64
ids:
- netbird-server
goarch: amd64
use: buildx
dockerfile: combined/Dockerfile
build_flag_templates:
- "--platform=linux/amd64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/netbird-server:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm64v8
ids:
- netbird-server
goarch: arm64
use: buildx
dockerfile: combined/Dockerfile
build_flag_templates:
- "--platform=linux/arm64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/netbird-server:{{ .Version }}-arm
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm
ids:
- netbird-server
goarch: arm
goarm: 6
use: buildx
dockerfile: combined/Dockerfile
build_flag_templates:
- "--platform=linux/arm"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
docker_manifests:
- name_template: netbirdio/netbird:{{ .Version }}
image_templates:
@@ -667,18 +598,6 @@ docker_manifests:
- netbirdio/upload:{{ .Version }}-arm
- netbirdio/upload:{{ .Version }}-amd64
- name_template: netbirdio/netbird-server:{{ .Version }}
image_templates:
- netbirdio/netbird-server:{{ .Version }}-arm64v8
- netbirdio/netbird-server:{{ .Version }}-arm
- netbirdio/netbird-server:{{ .Version }}-amd64
- name_template: netbirdio/netbird-server:latest
image_templates:
- netbirdio/netbird-server:{{ .Version }}-arm64v8
- netbirdio/netbird-server:{{ .Version }}-arm
- netbirdio/netbird-server:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/netbird:{{ .Version }}
image_templates:
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm64v8
@@ -756,19 +675,6 @@ docker_manifests:
- ghcr.io/netbirdio/upload:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/upload:{{ .Version }}-arm
- ghcr.io/netbirdio/upload:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/netbird-server:{{ .Version }}
image_templates:
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/netbird-server:latest
image_templates:
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-amd64
brews:
- ids:
- default

View File

@@ -60,8 +60,8 @@
https://github.com/user-attachments/assets/10cec749-bb56-4ab3-97af-4e38850108d2
### Self-Host NetBird (Video)
[![Watch the video](https://img.youtube.com/vi/bZAgpT6nzaQ/0.jpg)](https://youtu.be/bZAgpT6nzaQ)
### NetBird on Lawrence Systems (Video)
[![Watch the video](https://img.youtube.com/vi/Kwrff6h0rEw/0.jpg)](https://www.youtube.com/watch?v=Kwrff6h0rEw)
### Key features

View File

@@ -282,9 +282,13 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman
}
defer authClient.Close()
needsLogin, err := authClient.IsLoginRequired(ctx)
if err != nil {
return fmt.Errorf("check login required: %v", err)
needsLogin := false
err, isAuthError := authClient.Login(ctx, "", "")
if isAuthError {
needsLogin = true
} else if err != nil {
return fmt.Errorf("login check failed: %v", err)
}
jwtToken := ""

View File

@@ -31,14 +31,6 @@ var (
ErrConfigNotInitialized = errors.New("config not initialized")
)
// PeerConnStatus is a peer's connection status.
type PeerConnStatus = peer.ConnStatus
const (
// PeerStatusConnected indicates the peer is in connected state.
PeerStatusConnected = peer.StatusConnected
)
// Client manages a netbird embedded client instance.
type Client struct {
deviceName string
@@ -79,8 +71,6 @@ type Options struct {
DisableClientRoutes bool
// BlockInbound blocks all inbound connections from peers
BlockInbound bool
// WireguardPort is the port for the WireGuard interface. Use 0 for a random port.
WireguardPort *int
}
// validateCredentials checks that exactly one credential type is provided
@@ -150,7 +140,6 @@ func New(opts Options) (*Client, error) {
DisableServerRoutes: &t,
DisableClientRoutes: &opts.DisableClientRoutes,
BlockInbound: &opts.BlockInbound,
WireguardPort: opts.WireguardPort,
}
if opts.ConfigPath != "" {
config, err = profilemanager.UpdateOrCreateConfig(input)
@@ -170,7 +159,6 @@ func New(opts Options) (*Client, error) {
setupKey: opts.SetupKey,
jwtToken: opts.JWTToken,
config: config,
recorder: peer.NewRecorder(config.ManagementURL.String()),
}, nil
}
@@ -192,7 +180,6 @@ func (c *Client) Start(startCtx context.Context) error {
// nolint:staticcheck
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName)
authClient, err := auth.NewAuth(ctx, c.config.PrivateKey, c.config.ManagementURL, c.config)
if err != nil {
return fmt.Errorf("create auth client: %w", err)
@@ -202,7 +189,10 @@ func (c *Client) Start(startCtx context.Context) error {
if err, _ := authClient.Login(ctx, c.setupKey, c.jwtToken); err != nil {
return fmt.Errorf("login: %w", err)
}
client := internal.NewConnectClient(ctx, c.config, c.recorder, false)
recorder := peer.NewRecorder(c.config.ManagementURL.String())
c.recorder = recorder
client := internal.NewConnectClient(ctx, c.config, recorder, false)
client.SetSyncResponsePersistence(true)
// either startup error (permanent backoff err) or nil err (successful engine up)
@@ -355,9 +345,14 @@ func (c *Client) NewHTTPClient() *http.Client {
// Status returns the current status of the client.
func (c *Client) Status() (peer.FullStatus, error) {
c.mu.Lock()
recorder := c.recorder
connect := c.connect
c.mu.Unlock()
if recorder == nil {
return peer.FullStatus{}, errors.New("client not started")
}
if connect != nil {
engine := connect.Engine()
if engine != nil {
@@ -365,7 +360,7 @@ func (c *Client) Status() (peer.FullStatus, error) {
}
}
return c.recorder.GetFullStatus(), nil
return recorder.GetFullStatus(), nil
}
// GetLatestSyncResponse returns the latest sync response from the management server.

View File

@@ -483,12 +483,7 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
}
if nftRule.Handle == 0 {
log.Warnf("route rule %s has no handle, removing stale entry", ruleKey)
if err := r.decrementSetCounter(nftRule); err != nil {
log.Warnf("decrement set counter for stale rule %s: %v", ruleKey, err)
}
delete(r.rules, ruleKey)
return nil
return fmt.Errorf("route rule %s has no handle", ruleKey)
}
if err := r.deleteNftRule(nftRule, ruleKey); err != nil {
@@ -665,32 +660,13 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
}
if err := r.conn.Flush(); err != nil {
r.rollbackRules(pair)
return fmt.Errorf("insert rules for %s: %w", pair.Destination, err)
// TODO: rollback ipset counter
return fmt.Errorf("insert rules for %s: %v", pair.Destination, err)
}
return nil
}
// rollbackRules cleans up unflushed rules and their set counters after a flush failure.
func (r *router) rollbackRules(pair firewall.RouterPair) {
keys := []string{
firewall.GenKey(firewall.ForwardingFormat, pair),
firewall.GenKey(firewall.PreroutingFormat, pair),
firewall.GenKey(firewall.PreroutingFormat, firewall.GetInversePair(pair)),
}
for _, key := range keys {
rule, ok := r.rules[key]
if !ok {
continue
}
if err := r.decrementSetCounter(rule); err != nil {
log.Warnf("rollback set counter for %s: %v", key, err)
}
delete(r.rules, key)
}
}
// addNatRule inserts a nftables rule to the conn client flush queue
func (r *router) addNatRule(pair firewall.RouterPair) error {
sourceExp, err := r.applyNetwork(pair.Source, nil, true)
@@ -952,30 +928,18 @@ func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
rule, exists := r.rules[ruleKey]
if !exists {
return nil
}
if rule.Handle == 0 {
log.Warnf("legacy forwarding rule %s has no handle, removing stale entry", ruleKey)
if err := r.decrementSetCounter(rule); err != nil {
log.Warnf("decrement set counter for stale rule %s: %v", ruleKey, err)
if rule, exists := r.rules[ruleKey]; exists {
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
}
log.Debugf("removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination)
delete(r.rules, ruleKey)
return nil
}
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %w", pair.Source, pair.Destination, err)
}
log.Debugf("removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination)
delete(r.rules, ruleKey)
if err := r.decrementSetCounter(rule); err != nil {
return fmt.Errorf("decrement set counter: %w", err)
if err := r.decrementSetCounter(rule); err != nil {
return fmt.Errorf("decrement set counter: %w", err)
}
}
return nil
@@ -1365,89 +1329,65 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
return fmt.Errorf(refreshRulesMapError, err)
}
var merr *multierror.Error
if pair.Masquerade {
if err := r.removeNatRule(pair); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove prerouting rule: %w", err))
return fmt.Errorf("remove prerouting rule: %w", err)
}
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove inverse prerouting rule: %w", err))
return fmt.Errorf("remove inverse prerouting rule: %w", err)
}
}
if err := r.removeLegacyRouteRule(pair); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove legacy routing rule: %w", err))
return fmt.Errorf("remove legacy routing rule: %w", err)
}
// Set counters are decremented in the sub-methods above before flush. If flush fails,
// counters will be off until the next successful removal or refresh cycle.
if err := r.conn.Flush(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("flush remove nat rules %s: %w", pair.Destination, err))
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *router) removeNatRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
rule, exists := r.rules[ruleKey]
if !exists {
log.Debugf("prerouting rule %s not found", ruleKey)
return nil
}
if rule.Handle == 0 {
log.Warnf("prerouting rule %s has no handle, removing stale entry", ruleKey)
if err := r.decrementSetCounter(rule); err != nil {
log.Warnf("decrement set counter for stale rule %s: %v", ruleKey, err)
}
delete(r.rules, ruleKey)
return nil
}
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("remove prerouting rule %s -> %s: %w", pair.Source, pair.Destination, err)
}
log.Debugf("removed prerouting rule %s -> %s", pair.Source, pair.Destination)
delete(r.rules, ruleKey)
if err := r.decrementSetCounter(rule); err != nil {
return fmt.Errorf("decrement set counter: %w", err)
// TODO: rollback set counter
return fmt.Errorf("remove nat rules rule %s: %v", pair.Destination, err)
}
return nil
}
// refreshRulesMap rebuilds the rule map from the kernel. This removes stale entries
// (e.g. from failed flushes) and updates handles for all existing rules.
func (r *router) removeNatRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
if rule, exists := r.rules[ruleKey]; exists {
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("remove prerouting rule %s -> %s: %v", pair.Source, pair.Destination, err)
}
log.Debugf("removed prerouting rule %s -> %s", pair.Source, pair.Destination)
delete(r.rules, ruleKey)
if err := r.decrementSetCounter(rule); err != nil {
return fmt.Errorf("decrement set counter: %w", err)
}
} else {
log.Debugf("prerouting rule %s not found", ruleKey)
}
return nil
}
// refreshRulesMap refreshes the rule map with the latest rules. this is useful to avoid
// duplicates and to get missing attributes that we don't have when adding new rules
func (r *router) refreshRulesMap() error {
var merr *multierror.Error
newRules := make(map[string]*nftables.Rule)
for _, chain := range r.chains {
rules, err := r.conn.GetRules(chain.Table, chain)
if err != nil {
merr = multierror.Append(merr, fmt.Errorf("list rules for chain %s: %w", chain.Name, err))
// preserve existing entries for this chain since we can't verify their state
for k, v := range r.rules {
if v.Chain != nil && v.Chain.Name == chain.Name {
newRules[k] = v
}
}
continue
return fmt.Errorf("list rules: %w", err)
}
for _, rule := range rules {
if len(rule.UserData) > 0 {
newRules[string(rule.UserData)] = rule
r.rules[string(rule.UserData)] = rule
}
}
}
r.rules = newRules
return nberrors.FormatErrorOrNil(merr)
return nil
}
func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
@@ -1689,34 +1629,20 @@ func (r *router) DeleteDNATRule(rule firewall.Rule) error {
}
var merr *multierror.Error
var needsFlush bool
if dnatRule, exists := r.rules[ruleKey+dnatSuffix]; exists {
if dnatRule.Handle == 0 {
log.Warnf("dnat rule %s has no handle, removing stale entry", ruleKey+dnatSuffix)
delete(r.rules, ruleKey+dnatSuffix)
} else if err := r.conn.DelRule(dnatRule); err != nil {
if err := r.conn.DelRule(dnatRule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete dnat rule: %w", err))
} else {
needsFlush = true
}
}
if masqRule, exists := r.rules[ruleKey+snatSuffix]; exists {
if masqRule.Handle == 0 {
log.Warnf("snat rule %s has no handle, removing stale entry", ruleKey+snatSuffix)
delete(r.rules, ruleKey+snatSuffix)
} else if err := r.conn.DelRule(masqRule); err != nil {
if err := r.conn.DelRule(masqRule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete snat rule: %w", err))
} else {
needsFlush = true
}
}
if needsFlush {
if err := r.conn.Flush(); err != nil {
merr = multierror.Append(merr, fmt.Errorf(flushError, err))
}
if err := r.conn.Flush(); err != nil {
merr = multierror.Append(merr, fmt.Errorf(flushError, err))
}
if merr == nil {
@@ -1831,25 +1757,16 @@ func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Proto
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
rule, exists := r.rules[ruleID]
if !exists {
return nil
}
if rule.Handle == 0 {
log.Warnf("inbound DNAT rule %s has no handle, removing stale entry", ruleID)
if rule, exists := r.rules[ruleID]; exists {
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("delete inbound DNAT rule %s: %w", ruleID, err)
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("flush delete inbound DNAT rule: %w", err)
}
delete(r.rules, ruleID)
return nil
}
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("delete inbound DNAT rule %s: %w", ruleID, err)
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("flush delete inbound DNAT rule: %w", err)
}
delete(r.rules, ruleID)
return nil
}

View File

@@ -18,7 +18,6 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/test"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/acl/id"
)
const (
@@ -720,137 +719,3 @@ func deleteWorkTable() {
}
}
}
func TestRouter_RefreshRulesMap_RemovesStaleEntries(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
}
workTable, err := createWorkTable()
require.NoError(t, err)
defer deleteWorkTable()
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
require.NoError(t, err)
require.NoError(t, r.init(workTable))
defer func() { require.NoError(t, r.Reset()) }()
// Add a real rule to the kernel
ruleKey, err := r.AddRouteFiltering(
nil,
[]netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
firewall.Network{Prefix: netip.MustParsePrefix("10.0.0.0/24")},
firewall.ProtocolTCP,
nil,
&firewall.Port{Values: []uint16{80}},
firewall.ActionAccept,
)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, r.DeleteRouteRule(ruleKey))
})
// Inject a stale entry with Handle=0 (simulates store-before-flush failure)
staleKey := "stale-rule-that-does-not-exist"
r.rules[staleKey] = &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingFw],
Handle: 0,
UserData: []byte(staleKey),
}
require.Contains(t, r.rules, staleKey, "stale entry should be in map before refresh")
err = r.refreshRulesMap()
require.NoError(t, err)
assert.NotContains(t, r.rules, staleKey, "stale entry should be removed after refresh")
realRule, ok := r.rules[ruleKey.ID()]
assert.True(t, ok, "real rule should still exist after refresh")
assert.NotZero(t, realRule.Handle, "real rule should have a valid handle")
}
func TestRouter_DeleteRouteRule_StaleHandle(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
}
workTable, err := createWorkTable()
require.NoError(t, err)
defer deleteWorkTable()
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
require.NoError(t, err)
require.NoError(t, r.init(workTable))
defer func() { require.NoError(t, r.Reset()) }()
// Inject a stale entry with Handle=0
staleKey := "stale-route-rule"
r.rules[staleKey] = &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingFw],
Handle: 0,
UserData: []byte(staleKey),
}
// DeleteRouteRule should not return an error for stale handles
err = r.DeleteRouteRule(id.RuleID(staleKey))
assert.NoError(t, err, "deleting a stale rule should not error")
assert.NotContains(t, r.rules, staleKey, "stale entry should be cleaned up")
}
func TestRouter_AddNatRule_WithStaleEntry(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
}
manager, err := Create(ifaceMock, iface.DefaultMTU)
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
})
pair := firewall.RouterPair{
ID: "staletest",
Source: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.1/32")},
Destination: firewall.Network{Prefix: netip.MustParsePrefix("100.100.200.0/24")},
Masquerade: true,
}
rtr := manager.router
// First add succeeds
err = rtr.AddNatRule(pair)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, rtr.RemoveNatRule(pair))
})
// Corrupt the handle to simulate stale state
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
if rule, exists := rtr.rules[natRuleKey]; exists {
rule.Handle = 0
}
inverseKey := firewall.GenKey(firewall.PreroutingFormat, firewall.GetInversePair(pair))
if rule, exists := rtr.rules[inverseKey]; exists {
rule.Handle = 0
}
// Adding the same rule again should succeed despite stale handles
err = rtr.AddNatRule(pair)
assert.NoError(t, err, "AddNatRule should succeed even with stale entries")
// Verify rules exist in kernel
rules, err := rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameManglePrerouting])
require.NoError(t, err)
found := 0
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
found++
}
}
assert.Equal(t, 1, found, "NAT rule should exist in kernel")
}

View File

@@ -3,6 +3,12 @@
package uspfilter
import (
"context"
"net/netip"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
@@ -11,7 +17,33 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
m.mutex.Lock()
defer m.mutex.Unlock()
m.resetState()
m.outgoingRules = make(map[netip.Addr]RuleSet)
m.incomingDenyRules = make(map[netip.Addr]RuleSet)
m.incomingRules = make(map[netip.Addr]RuleSet)
if m.udpTracker != nil {
m.udpTracker.Close()
}
if m.icmpTracker != nil {
m.icmpTracker.Close()
}
if m.tcpTracker != nil {
m.tcpTracker.Close()
}
if fwder := m.forwarder.Load(); fwder != nil {
fwder.Stop()
}
if m.logger != nil {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if err := m.logger.Stop(ctx); err != nil {
log.Errorf("failed to shutdown logger: %v", err)
}
}
if m.nativeFirewall != nil {
return m.nativeFirewall.Close(stateManager)

View File

@@ -1,9 +1,12 @@
package uspfilter
import (
"context"
"fmt"
"net/netip"
"os/exec"
"syscall"
"time"
log "github.com/sirupsen/logrus"
@@ -23,7 +26,33 @@ func (m *Manager) Close(*statemanager.Manager) error {
m.mutex.Lock()
defer m.mutex.Unlock()
m.resetState()
m.outgoingRules = make(map[netip.Addr]RuleSet)
m.incomingDenyRules = make(map[netip.Addr]RuleSet)
m.incomingRules = make(map[netip.Addr]RuleSet)
if m.udpTracker != nil {
m.udpTracker.Close()
}
if m.icmpTracker != nil {
m.icmpTracker.Close()
}
if m.tcpTracker != nil {
m.tcpTracker.Close()
}
if fwder := m.forwarder.Load(); fwder != nil {
fwder.Stop()
}
if m.logger != nil {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if err := m.logger.Stop(ctx); err != nil {
log.Errorf("failed to shutdown logger: %v", err)
}
}
if !isWindowsFirewallReachable() {
return nil

View File

@@ -115,17 +115,6 @@ func (t *TCPConnTrack) IsTombstone() bool {
return t.tombstone.Load()
}
// IsSupersededBy returns true if this connection should be replaced by a new one
// carrying the given flags. Tombstoned connections are always superseded; TIME-WAIT
// connections are superseded by a pure SYN (a new connection attempt for the same
// four-tuple, as contemplated by RFC 1122 §4.2.2.13 and RFC 6191).
func (t *TCPConnTrack) IsSupersededBy(flags uint8) bool {
if t.tombstone.Load() {
return true
}
return flags&TCPSyn != 0 && flags&TCPAck == 0 && TCPState(t.state.Load()) == TCPStateTimeWait
}
// SetTombstone safely marks the connection for deletion
func (t *TCPConnTrack) SetTombstone() {
t.tombstone.Store(true)
@@ -180,7 +169,7 @@ func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort ui
conn, exists := t.connections[key]
t.mutex.RUnlock()
if exists && !conn.IsSupersededBy(flags) {
if exists {
t.updateState(key, conn, flags, direction, size)
return key, uint16(conn.DNATOrigPort.Load()), true
}
@@ -252,7 +241,7 @@ func (t *TCPTracker) IsValidInbound(srcIP, dstIP netip.Addr, srcPort, dstPort ui
conn, exists := t.connections[key]
t.mutex.RUnlock()
if !exists || conn.IsSupersededBy(flags) {
if !exists || conn.IsTombstone() {
return false
}

View File

@@ -485,261 +485,6 @@ func TestTCPAbnormalSequences(t *testing.T) {
})
}
// TestTCPPortReuseTombstone verifies that a new connection on a port with a
// tombstoned (closed) conntrack entry is properly tracked. Without the fix,
// updateIfExists treats tombstoned entries as live, causing track() to skip
// creating a new connection. The subsequent SYN-ACK then fails IsValidInbound
// because the entry is tombstoned, and the response packet gets dropped by ACL.
func TestTCPPortReuseTombstone(t *testing.T) {
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
t.Run("Outbound port reuse after graceful close", func(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
// Establish and gracefully close a connection (server-initiated close)
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Server sends FIN
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
require.True(t, valid)
// Client sends FIN-ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
// Server sends final ACK
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.True(t, valid)
// Connection should be tombstoned
conn := tracker.connections[key]
require.NotNil(t, conn, "old connection should still be in map")
require.True(t, conn.IsTombstone(), "old connection should be tombstoned")
// Now reuse the same port for a new connection
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 100)
// The old tombstoned entry should be replaced with a new one
newConn := tracker.connections[key]
require.NotNil(t, newConn, "new connection should exist")
require.False(t, newConn.IsTombstone(), "new connection should not be tombstoned")
require.Equal(t, TCPStateSynSent, newConn.GetState())
// SYN-ACK for the new connection should be valid
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 100)
require.True(t, valid, "SYN-ACK for new connection on reused port should be accepted")
require.Equal(t, TCPStateEstablished, newConn.GetState())
// Data transfer should work
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 100)
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck, 500)
require.True(t, valid, "data should be allowed on new connection")
})
t.Run("Outbound port reuse after RST", func(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
// Establish and RST a connection
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst|TCPAck, 0)
require.True(t, valid)
conn := tracker.connections[key]
require.True(t, conn.IsTombstone(), "RST connection should be tombstoned")
// Reuse the same port
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 100)
newConn := tracker.connections[key]
require.NotNil(t, newConn)
require.False(t, newConn.IsTombstone())
require.Equal(t, TCPStateSynSent, newConn.GetState())
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 100)
require.True(t, valid, "SYN-ACK should be accepted after RST tombstone")
})
t.Run("Inbound port reuse after close", func(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
clientIP := srcIP
serverIP := dstIP
clientPort := srcPort
serverPort := dstPort
key := ConnKey{SrcIP: clientIP, DstIP: serverIP, SrcPort: clientPort, DstPort: serverPort}
// Inbound connection: client SYN → server SYN-ACK → client ACK
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100, 0)
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPSyn|TCPAck, 100)
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100, 0)
conn := tracker.connections[key]
require.Equal(t, TCPStateEstablished, conn.GetState())
// Server-initiated close to reach Closed/tombstoned:
// Server FIN (opposite dir) → CloseWait
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPFin|TCPAck, 100)
require.Equal(t, TCPStateCloseWait, conn.GetState())
// Client FIN-ACK (same dir as conn) → LastAck
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPFin|TCPAck, nil, 100, 0)
require.Equal(t, TCPStateLastAck, conn.GetState())
// Server final ACK (opposite dir) → Closed → tombstoned
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPAck, 100)
require.True(t, conn.IsTombstone())
// New inbound connection on same ports
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100, 0)
newConn := tracker.connections[key]
require.NotNil(t, newConn)
require.False(t, newConn.IsTombstone())
require.Equal(t, TCPStateSynReceived, newConn.GetState())
// Complete handshake: server SYN-ACK, then client ACK
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPSyn|TCPAck, 100)
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100, 0)
require.Equal(t, TCPStateEstablished, newConn.GetState())
})
t.Run("Late ACK on tombstoned connection is harmless", func(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
// Establish and close via passive close (server-initiated FIN → Closed → tombstoned)
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0) // CloseWait
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) // LastAck
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0) // Closed
conn := tracker.connections[key]
require.True(t, conn.IsTombstone())
// Late ACK should be rejected (tombstoned)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.False(t, valid, "late ACK on tombstoned connection should be rejected")
// Late outbound ACK should not create a new connection (not a SYN)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
require.True(t, tracker.connections[key].IsTombstone(), "late outbound ACK should not replace tombstoned entry")
})
}
func TestTCPPortReuseTimeWait(t *testing.T) {
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
t.Run("Outbound port reuse during TIME-WAIT (active close)", func(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
// Establish connection
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Active close: client (outbound initiator) sends FIN first
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
conn := tracker.connections[key]
require.Equal(t, TCPStateFinWait1, conn.GetState())
// Server ACKs the FIN
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.True(t, valid)
require.Equal(t, TCPStateFinWait2, conn.GetState())
// Server sends its own FIN
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
require.True(t, valid)
require.Equal(t, TCPStateTimeWait, conn.GetState())
// Client sends final ACK (TIME-WAIT stays, not tombstoned)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
require.False(t, conn.IsTombstone(), "TIME-WAIT should not be tombstoned")
// New outbound SYN on the same port (port reuse during TIME-WAIT)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 100)
// Per RFC 1122/6191, new SYN during TIME-WAIT should start a new connection
newConn := tracker.connections[key]
require.NotNil(t, newConn, "new connection should exist")
require.False(t, newConn.IsTombstone(), "new connection should not be tombstoned")
require.Equal(t, TCPStateSynSent, newConn.GetState(), "new connection should be in SYN-SENT")
// SYN-ACK for new connection should be valid
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 100)
require.True(t, valid, "SYN-ACK for new connection should be accepted")
require.Equal(t, TCPStateEstablished, newConn.GetState())
})
t.Run("Inbound SYN during TIME-WAIT falls through to normal tracking", func(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
// Establish outbound connection and close via active close → TIME-WAIT
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
conn := tracker.connections[key]
require.Equal(t, TCPStateTimeWait, conn.GetState())
// Inbound SYN on same ports during TIME-WAIT: IsValidInbound returns false
// so the filter falls through to ACL check + TrackInbound (which creates
// a new connection via track() → updateIfExists skips TIME-WAIT for SYN)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn, 0)
require.False(t, valid, "inbound SYN during TIME-WAIT should fail conntrack validation")
// Simulate what the filter does next: TrackInbound via the normal path
tracker.TrackInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn, nil, 100, 0)
// The new inbound connection uses the inverted key (dst→src becomes src→dst in track)
invertedKey := ConnKey{SrcIP: dstIP, DstIP: srcIP, SrcPort: dstPort, DstPort: srcPort}
newConn := tracker.connections[invertedKey]
require.NotNil(t, newConn, "new inbound connection should be tracked")
require.Equal(t, TCPStateSynReceived, newConn.GetState())
require.False(t, newConn.IsTombstone())
})
t.Run("Late retransmit during TIME-WAIT still allowed", func(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
// Establish and active close → TIME-WAIT
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
conn := tracker.connections[key]
require.Equal(t, TCPStateTimeWait, conn.GetState())
// Late ACK retransmits during TIME-WAIT should still be accepted
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.True(t, valid, "retransmitted ACK during TIME-WAIT should be accepted")
})
}
func TestTCPTimeoutHandling(t *testing.T) {
// Create tracker with a very short timeout for testing
shortTimeout := 100 * time.Millisecond

View File

@@ -1,7 +1,6 @@
package uspfilter
import (
"context"
"encoding/binary"
"errors"
"fmt"
@@ -13,13 +12,11 @@ import (
"strings"
"sync"
"sync/atomic"
"time"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
@@ -27,7 +24,6 @@ import (
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
"github.com/netbirdio/netbird/client/iface/netstack"
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
@@ -93,7 +89,6 @@ type Manager struct {
incomingDenyRules map[netip.Addr]RuleSet
incomingRules map[netip.Addr]RuleSet
routeRules RouteRules
routeRulesMap map[nbid.RuleID]*RouteRule
decoders sync.Pool
wgIface common.IFaceMapper
nativeFirewall firewall.Manager
@@ -234,7 +229,6 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
flowLogger: flowLogger,
netstack: netstack.IsEnabled(),
localForwarding: enableLocalForwarding,
routeRulesMap: make(map[nbid.RuleID]*RouteRule),
dnatMappings: make(map[netip.Addr]netip.Addr),
portDNATRules: []portDNATRule{},
netstackServices: make(map[serviceKey]struct{}),
@@ -486,15 +480,11 @@ func (m *Manager) addRouteFiltering(
return m.nativeFirewall.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
}
ruleKey := nbid.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
if existingRule, ok := m.routeRulesMap[ruleKey]; ok {
return existingRule, nil
}
ruleID := uuid.New().String()
rule := RouteRule{
// TODO: consolidate these IDs
id: string(ruleKey),
id: ruleID,
mgmtId: id,
sources: sources,
dstSet: destination.Set,
@@ -509,7 +499,6 @@ func (m *Manager) addRouteFiltering(
m.routeRules = append(m.routeRules, &rule)
m.routeRules.Sort()
m.routeRulesMap[ruleKey] = &rule
return &rule, nil
}
@@ -526,20 +515,15 @@ func (m *Manager) deleteRouteRule(rule firewall.Rule) error {
return m.nativeFirewall.DeleteRouteRule(rule)
}
ruleKey := nbid.RuleID(rule.ID())
if _, ok := m.routeRulesMap[ruleKey]; !ok {
return fmt.Errorf("route rule not found: %s", ruleKey)
}
ruleID := rule.ID()
idx := slices.IndexFunc(m.routeRules, func(r *RouteRule) bool {
return r.id == string(ruleKey)
return r.id == ruleID
})
if idx < 0 {
return fmt.Errorf("route rule not found in slice: %s", ruleKey)
return fmt.Errorf("route rule not found: %s", ruleID)
}
m.routeRules = slices.Delete(m.routeRules, idx, idx+1)
delete(m.routeRulesMap, ruleKey)
return nil
}
@@ -586,40 +570,6 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
// Flush doesn't need to be implemented for this manager
func (m *Manager) Flush() error { return nil }
// resetState clears all firewall rules and closes connection trackers.
// Must be called with m.mutex held.
func (m *Manager) resetState() {
maps.Clear(m.outgoingRules)
maps.Clear(m.incomingDenyRules)
maps.Clear(m.incomingRules)
maps.Clear(m.routeRulesMap)
m.routeRules = m.routeRules[:0]
if m.udpTracker != nil {
m.udpTracker.Close()
}
if m.icmpTracker != nil {
m.icmpTracker.Close()
}
if m.tcpTracker != nil {
m.tcpTracker.Close()
}
if fwder := m.forwarder.Load(); fwder != nil {
fwder.Stop()
}
if m.logger != nil {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if err := m.logger.Stop(ctx); err != nil {
log.Errorf("failed to shutdown logger: %v", err)
}
}
}
// SetupEBPFProxyNoTrack creates notrack rules for eBPF proxy loopback traffic.
func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
if m.nativeFirewall == nil {

View File

@@ -1,376 +0,0 @@
package uspfilter
import (
"net/netip"
"testing"
"github.com/golang/mock/gomock"
"github.com/google/gopacket/layers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
wgdevice "golang.zx2c4.com/wireguard/device"
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/mocks"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
// TestAddRouteFilteringReturnsExistingRule verifies that adding the same route
// filtering rule twice returns the same rule ID (idempotent behavior).
func TestAddRouteFilteringReturnsExistingRule(t *testing.T) {
manager := setupTestManager(t)
sources := []netip.Prefix{
netip.MustParsePrefix("100.64.1.0/24"),
netip.MustParsePrefix("100.64.2.0/24"),
}
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
// Add rule first time
rule1, err := manager.AddRouteFiltering(
[]byte("policy-1"),
sources,
destination,
fw.ProtocolTCP,
nil,
&fw.Port{Values: []uint16{443}},
fw.ActionAccept,
)
require.NoError(t, err)
require.NotNil(t, rule1)
// Add the same rule again
rule2, err := manager.AddRouteFiltering(
[]byte("policy-1"),
sources,
destination,
fw.ProtocolTCP,
nil,
&fw.Port{Values: []uint16{443}},
fw.ActionAccept,
)
require.NoError(t, err)
require.NotNil(t, rule2)
// These should be the same (idempotent) like nftables/iptables implementations
assert.Equal(t, rule1.ID(), rule2.ID(),
"Adding the same rule twice should return the same rule ID (idempotent)")
manager.mutex.RLock()
ruleCount := len(manager.routeRules)
manager.mutex.RUnlock()
assert.Equal(t, 2, ruleCount,
"Should have exactly 2 rules (1 user rule + 1 block rule)")
}
// TestAddRouteFilteringDifferentRulesGetDifferentIDs verifies that rules with
// different parameters get distinct IDs.
func TestAddRouteFilteringDifferentRulesGetDifferentIDs(t *testing.T) {
manager := setupTestManager(t)
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
// Add first rule
rule1, err := manager.AddRouteFiltering(
[]byte("policy-1"),
sources,
fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
fw.ProtocolTCP,
nil,
&fw.Port{Values: []uint16{443}},
fw.ActionAccept,
)
require.NoError(t, err)
// Add different rule (different destination)
rule2, err := manager.AddRouteFiltering(
[]byte("policy-2"),
sources,
fw.Network{Prefix: netip.MustParsePrefix("192.168.2.0/24")}, // Different!
fw.ProtocolTCP,
nil,
&fw.Port{Values: []uint16{443}},
fw.ActionAccept,
)
require.NoError(t, err)
assert.NotEqual(t, rule1.ID(), rule2.ID(),
"Different rules should have different IDs")
manager.mutex.RLock()
ruleCount := len(manager.routeRules)
manager.mutex.RUnlock()
assert.Equal(t, 3, ruleCount, "Should have 3 rules (2 user rules + 1 block rule)")
}
// TestRouteRuleUpdateDoesNotCauseGap verifies that re-adding the same route
// rule during a network map update does not disrupt existing traffic.
func TestRouteRuleUpdateDoesNotCauseGap(t *testing.T) {
manager := setupTestManager(t)
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
rule1, err := manager.AddRouteFiltering(
[]byte("policy-1"),
sources,
destination,
fw.ProtocolTCP,
nil,
nil,
fw.ActionAccept,
)
require.NoError(t, err)
srcIP := netip.MustParseAddr("100.64.1.5")
dstIP := netip.MustParseAddr("192.168.1.10")
_, pass := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 443)
require.True(t, pass, "Traffic should pass with rule in place")
// Re-add same rule (simulates network map update)
rule2, err := manager.AddRouteFiltering(
[]byte("policy-1"),
sources,
destination,
fw.ProtocolTCP,
nil,
nil,
fw.ActionAccept,
)
require.NoError(t, err)
// Idempotent IDs mean rule1.ID() == rule2.ID(), so the ACL manager
// won't delete rule1 during cleanup. If IDs differed, deleting rule1
// would remove the only matching rule and cause a traffic gap.
if rule1.ID() != rule2.ID() {
err = manager.DeleteRouteRule(rule1)
require.NoError(t, err)
}
_, passAfter := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 443)
assert.True(t, passAfter,
"Traffic should still pass after rule update - no gap should occur")
}
// TestBlockInvalidRoutedIdempotent verifies that blockInvalidRouted creates
// exactly one drop rule for the WireGuard network prefix, and calling it again
// returns the same rule without duplicating.
func TestBlockInvalidRoutedIdempotent(t *testing.T) {
ctrl := gomock.NewController(t)
dev := mocks.NewMockDevice(ctrl)
dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
wgNet := netip.MustParsePrefix("100.64.0.1/16")
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: wgNet.Addr(),
Network: wgNet,
}
},
GetDeviceFunc: func() *device.FilteredDevice {
return &device.FilteredDevice{Device: dev}
},
GetWGDeviceFunc: func() *wgdevice.Device {
return &wgdevice.Device{}
},
}
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
})
// Call blockInvalidRouted directly multiple times
rule1, err := manager.blockInvalidRouted(ifaceMock)
require.NoError(t, err)
require.NotNil(t, rule1)
rule2, err := manager.blockInvalidRouted(ifaceMock)
require.NoError(t, err)
require.NotNil(t, rule2)
rule3, err := manager.blockInvalidRouted(ifaceMock)
require.NoError(t, err)
require.NotNil(t, rule3)
// All should return the same rule
assert.Equal(t, rule1.ID(), rule2.ID(), "Second call should return same rule")
assert.Equal(t, rule2.ID(), rule3.ID(), "Third call should return same rule")
// Should have exactly 1 route rule
manager.mutex.RLock()
ruleCount := len(manager.routeRules)
manager.mutex.RUnlock()
assert.Equal(t, 1, ruleCount, "Should have exactly 1 block rule after 3 calls")
// Verify the rule blocks traffic to the WG network
srcIP := netip.MustParseAddr("10.0.0.1")
dstIP := netip.MustParseAddr("100.64.0.50")
_, pass := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 80)
assert.False(t, pass, "Block rule should deny traffic to WG prefix")
}
// TestBlockRuleNotAccumulatedOnRepeatedEnableRouting verifies that calling
// EnableRouting multiple times (as happens on each route update) does not
// accumulate duplicate block rules in the routeRules slice.
func TestBlockRuleNotAccumulatedOnRepeatedEnableRouting(t *testing.T) {
ctrl := gomock.NewController(t)
dev := mocks.NewMockDevice(ctrl)
dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
wgNet := netip.MustParsePrefix("100.64.0.1/16")
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: wgNet.Addr(),
Network: wgNet,
}
},
GetDeviceFunc: func() *device.FilteredDevice {
return &device.FilteredDevice{Device: dev}
},
GetWGDeviceFunc: func() *wgdevice.Device {
return &wgdevice.Device{}
},
}
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
})
// Call EnableRouting multiple times (simulating repeated route updates)
for i := 0; i < 5; i++ {
require.NoError(t, manager.EnableRouting())
}
manager.mutex.RLock()
ruleCount := len(manager.routeRules)
manager.mutex.RUnlock()
assert.Equal(t, 1, ruleCount,
"Repeated EnableRouting should not accumulate block rules")
}
// TestRouteRuleCountStableAcrossUpdates verifies that adding the same route
// rule multiple times does not create duplicate entries.
func TestRouteRuleCountStableAcrossUpdates(t *testing.T) {
manager := setupTestManager(t)
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
// Simulate 5 network map updates with the same route rule
for i := 0; i < 5; i++ {
rule, err := manager.AddRouteFiltering(
[]byte("policy-1"),
sources,
destination,
fw.ProtocolTCP,
nil,
&fw.Port{Values: []uint16{443}},
fw.ActionAccept,
)
require.NoError(t, err)
require.NotNil(t, rule)
}
manager.mutex.RLock()
ruleCount := len(manager.routeRules)
manager.mutex.RUnlock()
assert.Equal(t, 2, ruleCount,
"Should have exactly 2 rules (1 user rule + 1 block rule) after 5 updates")
}
// TestDeleteRouteRuleAfterIdempotentAdd verifies that deleting a route rule
// after adding it multiple times works correctly.
func TestDeleteRouteRuleAfterIdempotentAdd(t *testing.T) {
manager := setupTestManager(t)
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
// Add same rule twice
rule1, err := manager.AddRouteFiltering(
[]byte("policy-1"),
sources,
destination,
fw.ProtocolTCP,
nil,
nil,
fw.ActionAccept,
)
require.NoError(t, err)
rule2, err := manager.AddRouteFiltering(
[]byte("policy-1"),
sources,
destination,
fw.ProtocolTCP,
nil,
nil,
fw.ActionAccept,
)
require.NoError(t, err)
require.Equal(t, rule1.ID(), rule2.ID(), "Should return same rule ID")
// Delete using first reference
err = manager.DeleteRouteRule(rule1)
require.NoError(t, err)
// Verify traffic no longer passes
srcIP := netip.MustParseAddr("100.64.1.5")
dstIP := netip.MustParseAddr("192.168.1.10")
_, pass := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 443)
assert.False(t, pass, "Traffic should not pass after rule deletion")
}
func setupTestManager(t *testing.T) *Manager {
t.Helper()
ctrl := gomock.NewController(t)
dev := mocks.NewMockDevice(ctrl)
dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
wgNet := netip.MustParsePrefix("100.64.0.1/16")
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: wgNet.Addr(),
Network: wgNet,
}
},
GetDeviceFunc: func() *device.FilteredDevice {
return &device.FilteredDevice{Device: dev}
},
GetWGDeviceFunc: func() *wgdevice.Device {
return &wgdevice.Device{}
},
}
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
require.NoError(t, manager.EnableRouting())
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
})
return manager
}

View File

@@ -263,158 +263,6 @@ func TestAddUDPPacketHook(t *testing.T) {
}
}
// TestPeerRuleLifecycleDenyRules verifies that deny rules are correctly added
// to the deny map and can be cleanly deleted without leaving orphans.
func TestPeerRuleLifecycleDenyRules(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, m.Close(nil))
}()
ip := net.ParseIP("192.168.1.1")
addr := netip.MustParseAddr("192.168.1.1")
// Add multiple deny rules for different ports
rule1, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
require.NoError(t, err)
rule2, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
&fw.Port{Values: []uint16{80}}, fw.ActionDrop, "")
require.NoError(t, err)
m.mutex.RLock()
denyCount := len(m.incomingDenyRules[addr])
m.mutex.RUnlock()
require.Equal(t, 2, denyCount, "Should have exactly 2 deny rules")
// Delete the first deny rule
err = m.DeletePeerRule(rule1[0])
require.NoError(t, err)
m.mutex.RLock()
denyCount = len(m.incomingDenyRules[addr])
m.mutex.RUnlock()
require.Equal(t, 1, denyCount, "Should have 1 deny rule after deleting first")
// Delete the second deny rule
err = m.DeletePeerRule(rule2[0])
require.NoError(t, err)
m.mutex.RLock()
_, exists := m.incomingDenyRules[addr]
m.mutex.RUnlock()
require.False(t, exists, "Deny rules IP entry should be cleaned up when empty")
}
// TestPeerRuleAddAndDeleteDontLeak verifies that repeatedly adding and deleting
// peer rules (simulating network map updates) does not leak rules in the maps.
func TestPeerRuleAddAndDeleteDontLeak(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, m.Close(nil))
}()
ip := net.ParseIP("192.168.1.1")
addr := netip.MustParseAddr("192.168.1.1")
// Simulate 10 network map updates: add rule, delete old, add new
for i := 0; i < 10; i++ {
// Add a deny rule
rules, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
require.NoError(t, err)
// Add an allow rule
allowRules, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
&fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
require.NoError(t, err)
// Delete them (simulating ACL manager cleanup)
for _, r := range rules {
require.NoError(t, m.DeletePeerRule(r))
}
for _, r := range allowRules {
require.NoError(t, m.DeletePeerRule(r))
}
}
m.mutex.RLock()
denyCount := len(m.incomingDenyRules[addr])
allowCount := len(m.incomingRules[addr])
m.mutex.RUnlock()
require.Equal(t, 0, denyCount, "No deny rules should remain after cleanup")
require.Equal(t, 0, allowCount, "No allow rules should remain after cleanup")
}
// TestMixedAllowDenyRulesSameIP verifies that allow and deny rules for the same
// IP are stored in separate maps and don't interfere with each other.
func TestMixedAllowDenyRulesSameIP(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, m.Close(nil))
}()
ip := net.ParseIP("192.168.1.1")
// Add allow rule for port 80
allowRule, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
&fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
require.NoError(t, err)
// Add deny rule for port 22
denyRule, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
require.NoError(t, err)
addr := netip.MustParseAddr("192.168.1.1")
m.mutex.RLock()
allowCount := len(m.incomingRules[addr])
denyCount := len(m.incomingDenyRules[addr])
m.mutex.RUnlock()
require.Equal(t, 1, allowCount, "Should have 1 allow rule")
require.Equal(t, 1, denyCount, "Should have 1 deny rule")
// Delete allow rule should not affect deny rule
err = m.DeletePeerRule(allowRule[0])
require.NoError(t, err)
m.mutex.RLock()
denyCountAfter := len(m.incomingDenyRules[addr])
m.mutex.RUnlock()
require.Equal(t, 1, denyCountAfter, "Deny rule should still exist after deleting allow rule")
// Delete deny rule
err = m.DeletePeerRule(denyRule[0])
require.NoError(t, err)
m.mutex.RLock()
_, denyExists := m.incomingDenyRules[addr]
_, allowExists := m.incomingRules[addr]
m.mutex.RUnlock()
require.False(t, denyExists, "Deny rules should be empty")
require.False(t, allowExists, "Allow rules should be empty")
}
func TestManagerReset(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },

View File

@@ -5,8 +5,6 @@ import (
"context"
"fmt"
"io"
"os"
"strconv"
"sync"
"sync/atomic"
"time"
@@ -18,18 +16,9 @@ const (
maxBatchSize = 1024 * 16
maxMessageSize = 1024 * 2
defaultFlushInterval = 2 * time.Second
defaultLogChanSize = 1000
logChannelSize = 1000
)
func getLogChannelSize() int {
if v := os.Getenv("NB_USPFILTER_LOG_BUFFER"); v != "" {
if n, err := strconv.Atoi(v); err == nil && n > 0 {
return n
}
}
return defaultLogChanSize
}
type Level uint32
const (
@@ -80,7 +69,7 @@ type Logger struct {
func NewFromLogrus(logrusLogger *log.Logger) *Logger {
l := &Logger{
output: logrusLogger.Out,
msgChannel: make(chan logMessage, getLogChannelSize()),
msgChannel: make(chan logMessage, logChannelSize),
shutdown: make(chan struct{}),
bufPool: sync.Pool{
New: func() any {

View File

@@ -29,9 +29,8 @@ type PacketFilter interface {
type FilteredDevice struct {
tun.Device
filter PacketFilter
mutex sync.RWMutex
closeOnce sync.Once
filter PacketFilter
mutex sync.RWMutex
}
// newDeviceFilter constructor function
@@ -41,20 +40,6 @@ func newDeviceFilter(device tun.Device) *FilteredDevice {
}
}
// Close closes the underlying tun device exactly once.
// wireguard-go's netTun.Close() panics on double-close due to a bare close(channel),
// and multiple code paths can trigger Close on the same device.
func (d *FilteredDevice) Close() error {
var err error
d.closeOnce.Do(func() {
err = d.Device.Close()
})
if err != nil {
return err
}
return nil
}
// Read wraps read method with filtering feature
func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
if n, err = d.Device.Read(bufs, sizes, offset); err != nil {

View File

@@ -82,9 +82,7 @@ func (t *TunNetstackDevice) create() (WGConfigurer, error) {
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.bind.ActivityRecorder())
err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil {
if cErr := tunIface.Close(); cErr != nil {
log.Debugf("failed to close tun device: %v", cErr)
}
_ = tunIface.Close()
return nil, fmt.Errorf("error configuring interface: %s", err)
}

View File

@@ -18,7 +18,6 @@ import (
"github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device"
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
@@ -229,10 +228,6 @@ func (w *WGIface) Close() error {
result = multierror.Append(result, fmt.Errorf("failed to close wireguard interface %s: %w", w.Name(), err))
}
if nbnetstack.IsEnabled() {
return errors.FormatErrorOrNil(result)
}
if err := w.waitUntilRemoved(); err != nil {
log.Warnf("failed to remove WireGuard interface %s: %v", w.Name(), err)
if err := w.Destroy(); err != nil {

View File

@@ -66,7 +66,7 @@ func (t *NetStackTun) Create() (tun.Device, *netstack.Net, error) {
}
}()
return t.tundev, tunNet, nil
return nsTunDev, tunNet, nil
}
func (t *NetStackTun) Close() error {

View File

@@ -189,212 +189,6 @@ func TestDefaultManagerStateless(t *testing.T) {
})
}
// TestDenyRulesNotAccumulatedOnRepeatedApply verifies that applying the same
// deny rules repeatedly does not accumulate duplicate rules in the uspfilter.
// This tests the full ACL manager -> uspfilter integration.
func TestDenyRulesNotAccumulatedOnRepeatedApply(t *testing.T) {
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
networkMap := &mgmProto.NetworkMap{
FirewallRules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "22",
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "80",
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
},
FirewallRulesIsEmpty: false,
}
ctrl := gomock.NewController(t)
defer ctrl.Finish()
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
ifaceMock.EXPECT().SetFilter(gomock.Any())
network := netip.MustParsePrefix("172.0.0.1/32")
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
IP: network.Addr(),
Network: network,
}).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, fw.Close(nil))
}()
acl := NewDefaultManager(fw)
// Apply the same rules 5 times (simulating repeated network map updates)
for i := 0; i < 5; i++ {
acl.ApplyFiltering(networkMap, false)
}
// The ACL manager should track exactly 3 rule pairs (2 deny + 1 accept inbound)
assert.Equal(t, 3, len(acl.peerRulesPairs),
"Should have exactly 3 rule pairs after 5 identical updates")
}
// TestDenyRulesCleanedUpOnRemoval verifies that deny rules are properly cleaned
// up when they're removed from the network map in a subsequent update.
func TestDenyRulesCleanedUpOnRemoval(t *testing.T) {
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
ctrl := gomock.NewController(t)
defer ctrl.Finish()
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
ifaceMock.EXPECT().SetFilter(gomock.Any())
network := netip.MustParsePrefix("172.0.0.1/32")
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
IP: network.Addr(),
Network: network,
}).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, fw.Close(nil))
}()
acl := NewDefaultManager(fw)
// First update: add deny and accept rules
networkMap1 := &mgmProto.NetworkMap{
FirewallRules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "22",
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
},
FirewallRulesIsEmpty: false,
}
acl.ApplyFiltering(networkMap1, false)
assert.Equal(t, 2, len(acl.peerRulesPairs), "Should have 2 rules after first update")
// Second update: remove the deny rule, keep only accept
networkMap2 := &mgmProto.NetworkMap{
FirewallRules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
},
FirewallRulesIsEmpty: false,
}
acl.ApplyFiltering(networkMap2, false)
assert.Equal(t, 1, len(acl.peerRulesPairs),
"Should have 1 rule after removing deny rule")
// Third update: remove all rules
networkMap3 := &mgmProto.NetworkMap{
FirewallRules: []*mgmProto.FirewallRule{},
FirewallRulesIsEmpty: true,
}
acl.ApplyFiltering(networkMap3, false)
assert.Equal(t, 0, len(acl.peerRulesPairs),
"Should have 0 rules after removing all rules")
}
// TestRuleUpdateChangingAction verifies that when a rule's action changes from
// accept to deny (or vice versa), the old rule is properly removed and the new
// one added without leaking.
func TestRuleUpdateChangingAction(t *testing.T) {
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
ctrl := gomock.NewController(t)
defer ctrl.Finish()
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
ifaceMock.EXPECT().SetFilter(gomock.Any())
network := netip.MustParsePrefix("172.0.0.1/32")
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
IP: network.Addr(),
Network: network,
}).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, fw.Close(nil))
}()
acl := NewDefaultManager(fw)
// First update: accept rule
networkMap := &mgmProto.NetworkMap{
FirewallRules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "22",
},
},
FirewallRulesIsEmpty: false,
}
acl.ApplyFiltering(networkMap, false)
assert.Equal(t, 1, len(acl.peerRulesPairs))
// Second update: change to deny (same IP/port/proto, different action)
networkMap.FirewallRules = []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "22",
},
}
acl.ApplyFiltering(networkMap, false)
// Should still have exactly 1 rule (the old accept removed, new deny added)
assert.Equal(t, 1, len(acl.peerRulesPairs),
"Changing action should result in exactly 1 rule, not 2")
}
func TestPortInfoEmpty(t *testing.T) {
tests := []struct {
name string

View File

@@ -20,7 +20,6 @@ import (
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer"
@@ -245,7 +244,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
localPeerState := peer.LocalPeerState{
IP: loginResp.GetPeerConfig().GetAddress(),
PubKey: myPrivateKey.PublicKey().String(),
KernelInterface: device.WireGuardModuleIsLoaded() && !netstack.IsEnabled(),
KernelInterface: device.WireGuardModuleIsLoaded(),
FQDN: loginResp.GetPeerConfig().GetFqdn(),
}
c.statusRecorder.UpdateLocalPeerState(localPeerState)

View File

@@ -42,6 +42,10 @@ const (
dnsPolicyConfigConfigOptionsKey = "ConfigOptions"
dnsPolicyConfigConfigOptionsValue = 0x8
// NRPT rules cannot handle more than 50 domains per rule.
// This is an undocumented Windows limitation.
nrptMaxDomainsPerRule = 50
interfaceConfigPath = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces`
interfaceConfigNameServerKey = "NameServer"
interfaceConfigSearchListKey = "SearchList"
@@ -239,23 +243,32 @@ func (r *registryConfigurator) addDNSSetupForAll(ip netip.Addr) error {
func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr) (int, error) {
// if the gpo key is present, we need to put our DNS settings there, otherwise our config might be ignored
// see https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-gpnrpt/8cc31cb9-20cb-4140-9e85-3e08703b4745
for i, domain := range domains {
localPath := fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)
gpoPath := fmt.Sprintf("%s-%d", gpoDnsPolicyConfigMatchPath, i)
singleDomain := []string{domain}
// NRPT rules have an undocumented restriction: each rule can only handle up to 50 domains.
// We need to batch domains into chunks and create one NRPT rule per batch.
ruleIndex := 0
for i := 0; i < len(domains); i += nrptMaxDomainsPerRule {
end := i + nrptMaxDomainsPerRule
if end > len(domains) {
end = len(domains)
}
batchDomains := domains[i:end]
if err := r.configureDNSPolicy(localPath, singleDomain, ip); err != nil {
return i, fmt.Errorf("configure DNS Local policy for domain %s: %w", domain, err)
localPath := fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, ruleIndex)
gpoPath := fmt.Sprintf("%s-%d", gpoDnsPolicyConfigMatchPath, ruleIndex)
if err := r.configureDNSPolicy(localPath, batchDomains, ip); err != nil {
return ruleIndex, fmt.Errorf("configure DNS Local policy for rule %d: %w", ruleIndex, err)
}
if r.gpo {
if err := r.configureDNSPolicy(gpoPath, singleDomain, ip); err != nil {
return i, fmt.Errorf("configure gpo DNS policy: %w", err)
if err := r.configureDNSPolicy(gpoPath, batchDomains, ip); err != nil {
return ruleIndex, fmt.Errorf("configure gpo DNS policy for rule %d: %w", ruleIndex, err)
}
}
log.Debugf("added NRPT entry for domain: %s", domain)
log.Debugf("added NRPT rule %d with %d domains", ruleIndex, len(batchDomains))
ruleIndex++
}
if r.gpo {
@@ -264,8 +277,8 @@ func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr
}
}
log.Infof("added %d separate NRPT entries. Domain list: %s", len(domains), domains)
return len(domains), nil
log.Infof("added %d NRPT rules for %d domains. Domain list: %s", ruleIndex, len(domains), domains)
return ruleIndex, nil
}
func (r *registryConfigurator) configureDNSPolicy(policyPath string, domains []string, ip netip.Addr) error {

View File

@@ -97,6 +97,107 @@ func registryKeyExists(path string) (bool, error) {
}
func cleanupRegistryKeys(*testing.T) {
cfg := &registryConfigurator{nrptEntryCount: 10}
// Clean up more entries to account for batching tests with many domains
cfg := &registryConfigurator{nrptEntryCount: 20}
_ = cfg.removeDNSMatchPolicies()
}
// TestNRPTDomainBatching verifies that domains are correctly batched into NRPT rules
// with a maximum of 50 domains per rule (Windows limitation).
func TestNRPTDomainBatching(t *testing.T) {
if testing.Short() {
t.Skip("skipping registry integration test in short mode")
}
defer cleanupRegistryKeys(t)
cleanupRegistryKeys(t)
testIP := netip.MustParseAddr("100.64.0.1")
// Create a test interface registry key so updateSearchDomains doesn't fail
testGUID := "{12345678-1234-1234-1234-123456789ABC}"
interfacePath := `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\` + testGUID
testKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, interfacePath, registry.SET_VALUE)
require.NoError(t, err, "Should create test interface registry key")
testKey.Close()
defer func() {
_ = registry.DeleteKey(registry.LOCAL_MACHINE, interfacePath)
}()
cfg := &registryConfigurator{
guid: testGUID,
gpo: false,
}
testCases := []struct {
name string
domainCount int
expectedRuleCount int
}{
{
name: "Less than 50 domains (single rule)",
domainCount: 30,
expectedRuleCount: 1,
},
{
name: "Exactly 50 domains (single rule)",
domainCount: 50,
expectedRuleCount: 1,
},
{
name: "51 domains (two rules)",
domainCount: 51,
expectedRuleCount: 2,
},
{
name: "100 domains (two rules)",
domainCount: 100,
expectedRuleCount: 2,
},
{
name: "125 domains (three rules: 50+50+25)",
domainCount: 125,
expectedRuleCount: 3,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Clean up before each subtest
cleanupRegistryKeys(t)
// Generate domains
domains := make([]DomainConfig, tc.domainCount)
for i := 0; i < tc.domainCount; i++ {
domains[i] = DomainConfig{
Domain: fmt.Sprintf("domain%d.com", i+1),
MatchOnly: true,
}
}
config := HostDNSConfig{
ServerIP: testIP,
Domains: domains,
}
err := cfg.applyDNSConfig(config, nil)
require.NoError(t, err)
// Verify that exactly expectedRuleCount rules were created
assert.Equal(t, tc.expectedRuleCount, cfg.nrptEntryCount,
"Should create %d NRPT rules for %d domains", tc.expectedRuleCount, tc.domainCount)
// Verify all expected rules exist
for i := 0; i < tc.expectedRuleCount; i++ {
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
require.NoError(t, err)
assert.True(t, exists, "NRPT rule %d should exist", i)
}
// Verify no extra rules were created
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, tc.expectedRuleCount))
require.NoError(t, err)
assert.False(t, exists, "No NRPT rule should exist at index %d", tc.expectedRuleCount)
})
}
}

View File

@@ -84,3 +84,13 @@ func (m *MockServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error {
func (m *MockServer) PopulateManagementDomain(mgmtURL *url.URL) error {
return nil
}
// BeginBatch mock implementation of BeginBatch from Server interface
func (m *MockServer) BeginBatch() {
// Mock implementation - no-op
}
// EndBatch mock implementation of EndBatch from Server interface
func (m *MockServer) EndBatch() {
// Mock implementation - no-op
}

View File

@@ -6,9 +6,7 @@ import (
"fmt"
"net/netip"
"net/url"
"os"
"runtime"
"strconv"
"strings"
"sync"
@@ -29,8 +27,6 @@ import (
"github.com/netbirdio/netbird/shared/management/domain"
)
const envSkipDNSProbe = "NB_SKIP_DNS_PROBE"
// ReadyListener is a notification mechanism what indicate the server is ready to handle host dns address changes
type ReadyListener interface {
OnReady()
@@ -45,6 +41,8 @@ type IosDnsManager interface {
type Server interface {
RegisterHandler(domains domain.List, handler dns.Handler, priority int)
DeregisterHandler(domains domain.List, priority int)
BeginBatch()
EndBatch()
Initialize() error
Stop()
DnsIP() netip.Addr
@@ -87,6 +85,7 @@ type DefaultServer struct {
currentConfigHash uint64
handlerChain *HandlerChain
extraDomains map[domain.Domain]int
batchMode bool
mgmtCacheResolver *mgmt.Resolver
@@ -234,7 +233,9 @@ func (s *DefaultServer) RegisterHandler(domains domain.List, handler dns.Handler
// convert to zone with simple ref counter
s.extraDomains[toZone(domain)]++
}
s.applyHostConfig()
if !s.batchMode {
s.applyHostConfig()
}
}
func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, priority int) {
@@ -263,6 +264,28 @@ func (s *DefaultServer) DeregisterHandler(domains domain.List, priority int) {
delete(s.extraDomains, zone)
}
}
if !s.batchMode {
s.applyHostConfig()
}
}
// BeginBatch starts batch mode for DNS handler registration/deregistration.
// In batch mode, applyHostConfig() is not called after each handler operation,
// allowing multiple handlers to be registered/deregistered efficiently.
// Must be followed by EndBatch() to apply the accumulated changes.
func (s *DefaultServer) BeginBatch() {
s.mux.Lock()
defer s.mux.Unlock()
log.Infof("DNS batch mode enabled")
s.batchMode = true
}
// EndBatch ends batch mode and applies all accumulated DNS configuration changes.
func (s *DefaultServer) EndBatch() {
s.mux.Lock()
defer s.mux.Unlock()
log.Infof("DNS batch mode disabled, applying accumulated changes")
s.batchMode = false
s.applyHostConfig()
}
@@ -443,17 +466,6 @@ func (s *DefaultServer) SearchDomains() []string {
// ProbeAvailability tests each upstream group's servers for availability
// and deactivates the group if no server responds
func (s *DefaultServer) ProbeAvailability() {
if val := os.Getenv(envSkipDNSProbe); val != "" {
skipProbe, err := strconv.ParseBool(val)
if err != nil {
log.Warnf("failed to parse %s: %v", envSkipDNSProbe, err)
}
if skipProbe {
log.Infof("skipping DNS probe due to %s", envSkipDNSProbe)
return
}
}
var wg sync.WaitGroup
for _, mux := range s.dnsMuxMap {
wg.Add(1)
@@ -523,7 +535,9 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
s.currentConfig.RouteAll = false
}
s.applyHostConfig()
if !s.batchMode {
s.applyHostConfig()
}
s.shutdownWg.Add(1)
go func() {
@@ -887,7 +901,9 @@ func (s *DefaultServer) upstreamCallbacks(
}
}
s.applyHostConfig()
if !s.batchMode {
s.applyHostConfig()
}
go func() {
if err := s.stateManager.PersistState(s.ctx); err != nil {
@@ -922,7 +938,9 @@ func (s *DefaultServer) upstreamCallbacks(
s.registerHandler([]string{nbdns.RootZone}, handler, priority)
}
s.applyHostConfig()
if !s.batchMode {
s.applyHostConfig()
}
s.updateNSState(nsGroup, nil, true)
}

View File

@@ -18,7 +18,12 @@ func TestGetServerDns(t *testing.T) {
t.Errorf("invalid dns server instance: %s", err)
}
if srvB != srv {
mockSrvB, ok := srvB.(*MockServer)
if !ok {
t.Errorf("returned server is not a MockServer")
}
if mockSrvB != srv {
t.Errorf("mismatch dns instances")
}
}

View File

@@ -190,75 +190,50 @@ func (f *DNSForwarder) Close(ctx context.Context) error {
return nberrors.FormatErrorOrNil(result)
}
func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, query *dns.Msg, startTime time.Time) {
func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, query *dns.Msg) *dns.Msg {
if len(query.Question) == 0 {
return
return nil
}
question := query.Question[0]
qname := strings.ToLower(question.Name)
logger.Tracef("received DNS request for DNS forwarder: domain=%s type=%s class=%s",
question.Name, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass])
logger.Tracef("question: domain=%s type=%s class=%s",
qname, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass])
domain := strings.ToLower(question.Name)
resp := query.SetReply(query)
network := resutil.NetworkForQtype(question.Qtype)
if network == "" {
resp.Rcode = dns.RcodeNotImplemented
f.writeResponse(logger, w, resp, qname, startTime)
return
if err := w.WriteMsg(resp); err != nil {
logger.Errorf("failed to write DNS response: %v", err)
}
return nil
}
mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(qname, "."))
mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(domain, "."))
// query doesn't match any configured domain
if mostSpecificResId == "" {
resp.Rcode = dns.RcodeRefused
f.writeResponse(logger, w, resp, qname, startTime)
return
if err := w.WriteMsg(resp); err != nil {
logger.Errorf("failed to write DNS response: %v", err)
}
return nil
}
ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout)
defer cancel()
result := resutil.LookupIP(ctx, f.resolver, network, qname, question.Qtype)
result := resutil.LookupIP(ctx, f.resolver, network, domain, question.Qtype)
if result.Err != nil {
f.handleDNSError(ctx, logger, w, question, resp, qname, result, startTime)
return
f.handleDNSError(ctx, logger, w, question, resp, domain, result)
return nil
}
f.updateInternalState(result.IPs, mostSpecificResId, matchingEntries)
resp.Answer = append(resp.Answer, resutil.IPsToRRs(qname, result.IPs, f.ttl)...)
f.cache.set(qname, question.Qtype, result.IPs)
resp.Answer = append(resp.Answer, resutil.IPsToRRs(domain, result.IPs, f.ttl)...)
f.cache.set(domain, question.Qtype, result.IPs)
f.writeResponse(logger, w, resp, qname, startTime)
}
func (f *DNSForwarder) writeResponse(logger *log.Entry, w dns.ResponseWriter, resp *dns.Msg, qname string, startTime time.Time) {
if err := w.WriteMsg(resp); err != nil {
logger.Errorf("failed to write DNS response: %v", err)
return
}
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
qname, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
}
// udpResponseWriter wraps a dns.ResponseWriter to handle UDP-specific truncation.
type udpResponseWriter struct {
dns.ResponseWriter
query *dns.Msg
}
func (u *udpResponseWriter) WriteMsg(resp *dns.Msg) error {
opt := u.query.IsEdns0()
maxSize := dns.MinMsgSize
if opt != nil {
maxSize = int(opt.UDPSize())
}
if resp.Len() > maxSize {
resp.Truncate(maxSize)
}
return u.ResponseWriter.WriteMsg(resp)
return resp
}
func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
@@ -268,7 +243,30 @@ func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
"dns_id": fmt.Sprintf("%04x", query.Id),
})
f.handleDNSQuery(logger, &udpResponseWriter{ResponseWriter: w, query: query}, query, startTime)
resp := f.handleDNSQuery(logger, w, query)
if resp == nil {
return
}
opt := query.IsEdns0()
maxSize := dns.MinMsgSize
if opt != nil {
// client advertised a larger EDNS0 buffer
maxSize = int(opt.UDPSize())
}
// if our response is too big, truncate and set the TC bit
if resp.Len() > maxSize {
resp.Truncate(maxSize)
}
if err := w.WriteMsg(resp); err != nil {
logger.Errorf("failed to write DNS response: %v", err)
return
}
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
query.Question[0].Name, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
}
func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
@@ -278,7 +276,18 @@ func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
"dns_id": fmt.Sprintf("%04x", query.Id),
})
f.handleDNSQuery(logger, w, query, startTime)
resp := f.handleDNSQuery(logger, w, query)
if resp == nil {
return
}
if err := w.WriteMsg(resp); err != nil {
logger.Errorf("failed to write DNS response: %v", err)
return
}
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
query.Question[0].Name, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
}
func (f *DNSForwarder) updateInternalState(ips []netip.Addr, mostSpecificResId route.ResID, matchingEntries []*ForwarderEntry) {
@@ -325,7 +334,6 @@ func (f *DNSForwarder) handleDNSError(
resp *dns.Msg,
domain string,
result resutil.LookupResult,
startTime time.Time,
) {
qType := question.Qtype
qTypeName := dns.TypeToString[qType]
@@ -335,7 +343,9 @@ func (f *DNSForwarder) handleDNSError(
// NotFound: cache negative result and respond
if result.Rcode == dns.RcodeNameError || result.Rcode == dns.RcodeSuccess {
f.cache.set(domain, question.Qtype, nil)
f.writeResponse(logger, w, resp, domain, startTime)
if writeErr := w.WriteMsg(resp); writeErr != nil {
logger.Errorf("failed to write failure DNS response: %v", writeErr)
}
return
}
@@ -345,7 +355,9 @@ func (f *DNSForwarder) handleDNSError(
logger.Debugf("serving cached DNS response after upstream failure: domain=%s type=%s", domain, qTypeName)
resp.Answer = append(resp.Answer, resutil.IPsToRRs(domain, ips, f.ttl)...)
resp.Rcode = dns.RcodeSuccess
f.writeResponse(logger, w, resp, domain, startTime)
if writeErr := w.WriteMsg(resp); writeErr != nil {
logger.Errorf("failed to write cached DNS response: %v", writeErr)
}
return
}
@@ -353,7 +365,9 @@ func (f *DNSForwarder) handleDNSError(
verifyResult := resutil.LookupIP(ctx, f.resolver, resutil.NetworkForQtype(qType), domain, qType)
if verifyResult.Rcode == dns.RcodeNameError || verifyResult.Rcode == dns.RcodeSuccess {
resp.Rcode = verifyResult.Rcode
f.writeResponse(logger, w, resp, domain, startTime)
if writeErr := w.WriteMsg(resp); writeErr != nil {
logger.Errorf("failed to write failure DNS response: %v", writeErr)
}
return
}
}
@@ -361,12 +375,15 @@ func (f *DNSForwarder) handleDNSError(
// No cache or verification failed. Log with or without the server field for more context.
var dnsErr *net.DNSError
if errors.As(result.Err, &dnsErr) && dnsErr.Server != "" {
logger.Warnf("upstream failure: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, result.Err)
logger.Warnf("failed to resolve: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, result.Err)
} else {
logger.Warnf(errResolveFailed, domain, result.Err)
}
f.writeResponse(logger, w, resp, domain, startTime)
// Write final failure response.
if writeErr := w.WriteMsg(resp); writeErr != nil {
logger.Errorf("failed to write failure DNS response: %v", writeErr)
}
}
// getMatchingEntries retrieves the resource IDs for a given domain.

View File

@@ -318,9 +318,8 @@ func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) {
query.SetQuestion(dns.Fqdn(tt.queryDomain), dns.TypeA)
mockWriter := &test.MockResponseWriter{}
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
resp := mockWriter.GetLastResponse()
if tt.shouldResolve {
require.NotNil(t, resp, "Expected response for authorized domain")
require.Equal(t, dns.RcodeSuccess, resp.Rcode, "Expected successful response")
@@ -330,9 +329,10 @@ func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) {
mockFirewall.AssertExpectations(t)
mockResolver.AssertExpectations(t)
} else {
require.NotNil(t, resp, "Expected response")
assert.True(t, len(resp.Answer) == 0 || resp.Rcode != dns.RcodeSuccess,
"Unauthorized domain should not return successful answers")
if resp != nil {
assert.True(t, len(resp.Answer) == 0 || resp.Rcode != dns.RcodeSuccess,
"Unauthorized domain should not return successful answers")
}
mockFirewall.AssertNotCalled(t, "UpdateSet")
mockResolver.AssertNotCalled(t, "LookupNetIP")
}
@@ -466,16 +466,14 @@ func TestDNSForwarder_FirewallSetUpdates(t *testing.T) {
dnsQuery.SetQuestion(dns.Fqdn(tt.query), dns.TypeA)
mockWriter := &test.MockResponseWriter{}
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, dnsQuery, time.Now())
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, dnsQuery)
// Verify response
resp := mockWriter.GetLastResponse()
if tt.shouldResolve {
require.NotNil(t, resp, "Expected response for authorized domain")
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.NotEmpty(t, resp.Answer)
} else {
require.NotNil(t, resp, "Expected response")
} else if resp != nil {
assert.True(t, resp.Rcode == dns.RcodeRefused || len(resp.Answer) == 0,
"Unauthorized domain should be refused or have no answers")
}
@@ -530,10 +528,9 @@ func TestDNSForwarder_MultipleIPsInSingleUpdate(t *testing.T) {
query.SetQuestion("example.com.", dns.TypeA)
mockWriter := &test.MockResponseWriter{}
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
// Verify response contains all IPs
resp := mockWriter.GetLastResponse()
require.NotNil(t, resp)
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 3, "Should have 3 answer records")
@@ -608,7 +605,7 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
},
}
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
_ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
// Check the response written to the writer
require.NotNil(t, writtenResp, "Expected response to be written")
@@ -678,8 +675,7 @@ func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
q1 := &dns.Msg{}
q1.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
w1 := &test.MockResponseWriter{}
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1, time.Now())
resp1 := w1.GetLastResponse()
resp1 := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1)
require.NotNil(t, resp1)
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
require.Len(t, resp1.Answer, 1)
@@ -687,13 +683,13 @@ func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
// Second query: serve from cache after upstream failure
q2 := &dns.Msg{}
q2.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
w2 := &test.MockResponseWriter{}
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2, time.Now())
var writtenResp *dns.Msg
w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
_ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2)
resp2 := w2.GetLastResponse()
require.NotNil(t, resp2, "expected response to be written")
require.Equal(t, dns.RcodeSuccess, resp2.Rcode)
require.Len(t, resp2.Answer, 1)
require.NotNil(t, writtenResp, "expected response to be written")
require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
require.Len(t, writtenResp.Answer, 1)
mockResolver.AssertExpectations(t)
}
@@ -719,8 +715,7 @@ func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) {
q1 := &dns.Msg{}
q1.SetQuestion(mixedQuery+".", dns.TypeA)
w1 := &test.MockResponseWriter{}
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1, time.Now())
resp1 := w1.GetLastResponse()
resp1 := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1)
require.NotNil(t, resp1)
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
require.Len(t, resp1.Answer, 1)
@@ -732,13 +727,13 @@ func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) {
q2 := &dns.Msg{}
q2.SetQuestion("EXAMPLE.COM", dns.TypeA)
w2 := &test.MockResponseWriter{}
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2, time.Now())
var writtenResp *dns.Msg
w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
_ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2)
resp2 := w2.GetLastResponse()
require.NotNil(t, resp2)
require.Equal(t, dns.RcodeSuccess, resp2.Rcode)
require.Len(t, resp2.Answer, 1)
require.NotNil(t, writtenResp)
require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
require.Len(t, writtenResp.Answer, 1)
mockResolver.AssertExpectations(t)
}
@@ -789,9 +784,8 @@ func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) {
query.SetQuestion("smtp.mail.example.com.", dns.TypeA)
mockWriter := &test.MockResponseWriter{}
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
resp := mockWriter.GetLastResponse()
require.NotNil(t, resp)
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
@@ -903,15 +897,26 @@ func TestDNSForwarder_NodataVsNxdomain(t *testing.T) {
query := &dns.Msg{}
query.SetQuestion(dns.Fqdn("example.com"), tt.queryType)
mockWriter := &test.MockResponseWriter{}
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
var writtenResp *dns.Msg
mockWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
writtenResp = m
return nil
},
}
resp := mockWriter.GetLastResponse()
require.NotNil(t, resp, "Expected response to be written")
assert.Equal(t, tt.expectedCode, resp.Rcode, tt.description)
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
// If a response was returned, it means it should be written (happens in wrapper functions)
if resp != nil && writtenResp == nil {
writtenResp = resp
}
require.NotNil(t, writtenResp, "Expected response to be written")
assert.Equal(t, tt.expectedCode, writtenResp.Rcode, tt.description)
if tt.expectNoAnswer {
assert.Empty(t, resp.Answer, "Response should have no answer records")
assert.Empty(t, writtenResp.Answer, "Response should have no answer records")
}
mockResolver.AssertExpectations(t)
@@ -926,8 +931,15 @@ func TestDNSForwarder_EmptyQuery(t *testing.T) {
query := &dns.Msg{}
// Don't set any question
mockWriter := &test.MockResponseWriter{}
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
writeCalled := false
mockWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
writeCalled = true
return nil
},
}
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
assert.Nil(t, mockWriter.GetLastResponse(), "Should not write response for empty query")
assert.Nil(t, resp, "Should return nil for empty query")
assert.False(t, writeCalled, "Should not write response for empty query")
}

View File

@@ -28,7 +28,6 @@ import (
"github.com/netbirdio/netbird/client/firewall"
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/internal/acl"
@@ -544,12 +543,11 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
// monitor WireGuard interface lifecycle and restart engine on changes
e.wgIfaceMonitor = NewWGIfaceMonitor()
e.shutdownWg.Add(1)
wgIfaceName := e.wgInterface.Name()
go func() {
defer e.shutdownWg.Done()
if shouldRestart, err := e.wgIfaceMonitor.Start(e.ctx, wgIfaceName); shouldRestart {
if shouldRestart, err := e.wgIfaceMonitor.Start(e.ctx, e.wgInterface.Name()); shouldRestart {
log.Infof("WireGuard interface monitor: %s, restarting engine", err)
e.triggerClientRestart()
} else if err != nil {
@@ -1023,7 +1021,7 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
state := e.statusRecorder.GetLocalPeerState()
state.IP = e.wgInterface.Address().String()
state.PubKey = e.config.WgPrivateKey.PublicKey().String()
state.KernelInterface = !e.wgInterface.IsUserspaceBind()
state.KernelInterface = device.WireGuardModuleIsLoaded()
state.FQDN = conf.GetFqdn()
e.statusRecorder.UpdateLocalPeerState(state)
@@ -1924,7 +1922,7 @@ func (e *Engine) triggerClientRestart() {
}
func (e *Engine) startNetworkMonitor() {
if !e.config.NetworkMonitor || nbnetstack.IsEnabled() {
if !e.config.NetworkMonitor {
log.Infof("Network monitor is disabled, not starting")
return
}

View File

@@ -10,7 +10,6 @@ import (
log "github.com/sirupsen/logrus"
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/netstack"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
sshconfig "github.com/netbirdio/netbird/client/ssh/config"
@@ -95,10 +94,6 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
// updateSSHClientConfig updates the SSH client configuration with peer information
func (e *Engine) updateSSHClientConfig(remotePeers []*mgmProto.RemotePeerConfig) error {
if netstack.IsEnabled() {
return nil
}
peerInfo := e.extractPeerSSHInfo(remotePeers)
if len(peerInfo) == 0 {
log.Debug("no SSH-enabled peers found, skipping SSH config update")
@@ -221,10 +216,6 @@ func (e *Engine) GetPeerSSHKey(peerAddress string) ([]byte, bool) {
// cleanupSSHConfig removes NetBird SSH client configuration on shutdown
func (e *Engine) cleanupSSHConfig() {
if netstack.IsEnabled() {
return
}
configMgr := sshconfig.New()
if err := configMgr.RemoveSSHClientConfig(); err != nil {

View File

@@ -11,7 +11,6 @@ import (
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/lazyconn"
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
@@ -75,13 +74,12 @@ func (m *Manager) createListener(peerCfg lazyconn.PeerConfig) (listener, error)
return NewUDPListener(m.wgIface, peerCfg)
}
// BindListener is used on Windows, JS, and netstack platforms:
// BindListener is only used on Windows and JS platforms:
// - JS: Cannot listen to UDP sockets
// - Windows: IP_UNICAST_IF socket option forces packets out the interface the default
// gateway points to, preventing them from reaching the loopback interface.
// - Netstack: Allows multiple instances on the same host without port conflicts.
// BindListener bypasses these issues by passing data directly through the bind.
if runtime.GOOS != "windows" && runtime.GOOS != "js" && !netstack.IsEnabled() {
// BindListener bypasses this by passing data directly through the bind.
if runtime.GOOS != "windows" && runtime.GOOS != "js" {
return NewUDPListener(m.wgIface, peerCfg)
}

View File

@@ -14,6 +14,7 @@ import (
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
)
@@ -37,6 +38,11 @@ func New() *NetworkMonitor {
// Listen begins monitoring network changes. When a change is detected, this function will return without error.
func (nw *NetworkMonitor) Listen(ctx context.Context) (err error) {
if netstack.IsEnabled() {
log.Debugf("Network monitor: skipping in netstack mode")
return nil
}
nw.mu.Lock()
if nw.cancel != nil {
nw.mu.Unlock()

View File

@@ -2,7 +2,6 @@ package ice
import (
"context"
"fmt"
"sync"
"time"
@@ -33,6 +32,24 @@ type ThreadSafeAgent struct {
once sync.Once
}
func (a *ThreadSafeAgent) Close() error {
var err error
a.once.Do(func() {
done := make(chan error, 1)
go func() {
done <- a.Agent.Close()
}()
select {
case err = <-done:
case <-time.After(iceAgentCloseTimeout):
log.Warnf("ICE agent close timed out after %v, proceeding with cleanup", iceAgentCloseTimeout)
err = nil
}
})
return err
}
func NewAgent(ctx context.Context, iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ThreadSafeAgent, error) {
iceKeepAlive := iceKeepAlive()
iceDisconnectedTimeout := iceDisconnectedTimeout()
@@ -76,41 +93,9 @@ func NewAgent(ctx context.Context, iFaceDiscover stdnet.ExternalIFaceDiscover, c
return nil, err
}
if agent == nil {
return nil, fmt.Errorf("ice.NewAgent returned nil agent without error")
}
return &ThreadSafeAgent{Agent: agent}, nil
}
func (a *ThreadSafeAgent) Close() error {
var err error
a.once.Do(func() {
// Defensive check to prevent nil pointer dereference
// This can happen during sleep/wake transitions or memory corruption scenarios
// github.com/netbirdio/netbird/client/internal/peer/ice.(*ThreadSafeAgent).Close(0x40006883f0?)
// [signal 0xc0000005 code=0x0 addr=0x0 pc=0x7ff7e73af83c]
agent := a.Agent
if agent == nil {
log.Warnf("ICE agent is nil during close, skipping")
return
}
done := make(chan error, 1)
go func() {
done <- agent.Close()
}()
select {
case err = <-done:
case <-time.After(iceAgentCloseTimeout):
log.Warnf("ICE agent close timed out after %v, proceeding with cleanup", iceAgentCloseTimeout)
err = nil
}
})
return err
}
func GenerateICECredentials() (string, string, error) {
ufrag, err := randutil.GenerateCryptoRandomString(lenUFrag, runesAlpha)
if err != nil {

View File

@@ -107,10 +107,8 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
}
w.log.Debugf("agent already exists, recreate the connection")
w.agentDialerCancel()
if w.agent != nil {
if err := w.agent.Close(); err != nil {
w.log.Warnf("failed to close ICE agent: %s", err)
}
if err := w.agent.Close(); err != nil {
w.log.Warnf("failed to close ICE agent: %s", err)
}
sessionID, err := NewICESessionID()

View File

@@ -252,7 +252,7 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
}
if config.AdminURL == nil {
log.Infof("using default Admin URL %s", DefaultAdminURL)
log.Infof("using default Admin URL %s", DefaultManagementURL)
config.AdminURL, err = parseURL("Admin URL", DefaultAdminURL)
if err != nil {
return false, err

View File

@@ -346,6 +346,13 @@ func (m *DefaultManager) updateSystemRoutes(newRoutes route.HAMap) error {
}
var merr *multierror.Error
// Begin batch mode to avoid calling applyHostConfig() after each DNS handler operation
if m.dnsServer != nil {
m.dnsServer.BeginBatch()
defer m.dnsServer.EndBatch()
}
for id, handler := range toRemove {
if err := handler.RemoveRoute(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove route %s: %w", handler.String(), err))

View File

@@ -4,17 +4,16 @@ package systemops
import (
"strings"
"golang.org/x/sys/unix"
"syscall"
)
// filterRoutesByFlags returns true if the route message should be ignored based on its flags.
func filterRoutesByFlags(routeMessageFlags int) bool {
if routeMessageFlags&unix.RTF_UP == 0 {
if routeMessageFlags&syscall.RTF_UP == 0 {
return true
}
if routeMessageFlags&(unix.RTF_REJECT|unix.RTF_BLACKHOLE|unix.RTF_WASCLONED) != 0 {
if routeMessageFlags&(syscall.RTF_REJECT|syscall.RTF_BLACKHOLE|syscall.RTF_WASCLONED) != 0 {
return true
}
@@ -25,51 +24,42 @@ func filterRoutesByFlags(routeMessageFlags int) bool {
func formatBSDFlags(flags int) string {
var flagStrs []string
if flags&unix.RTF_UP != 0 {
if flags&syscall.RTF_UP != 0 {
flagStrs = append(flagStrs, "U")
}
if flags&unix.RTF_GATEWAY != 0 {
if flags&syscall.RTF_GATEWAY != 0 {
flagStrs = append(flagStrs, "G")
}
if flags&unix.RTF_HOST != 0 {
if flags&syscall.RTF_HOST != 0 {
flagStrs = append(flagStrs, "H")
}
if flags&unix.RTF_REJECT != 0 {
if flags&syscall.RTF_REJECT != 0 {
flagStrs = append(flagStrs, "R")
}
if flags&unix.RTF_DYNAMIC != 0 {
if flags&syscall.RTF_DYNAMIC != 0 {
flagStrs = append(flagStrs, "D")
}
if flags&unix.RTF_MODIFIED != 0 {
if flags&syscall.RTF_MODIFIED != 0 {
flagStrs = append(flagStrs, "M")
}
if flags&unix.RTF_STATIC != 0 {
if flags&syscall.RTF_STATIC != 0 {
flagStrs = append(flagStrs, "S")
}
if flags&unix.RTF_LLINFO != 0 {
if flags&syscall.RTF_LLINFO != 0 {
flagStrs = append(flagStrs, "L")
}
if flags&unix.RTF_LOCAL != 0 {
if flags&syscall.RTF_LOCAL != 0 {
flagStrs = append(flagStrs, "l")
}
if flags&unix.RTF_BLACKHOLE != 0 {
if flags&syscall.RTF_BLACKHOLE != 0 {
flagStrs = append(flagStrs, "B")
}
if flags&unix.RTF_CLONING != 0 {
if flags&syscall.RTF_CLONING != 0 {
flagStrs = append(flagStrs, "C")
}
if flags&unix.RTF_WASCLONED != 0 {
if flags&syscall.RTF_WASCLONED != 0 {
flagStrs = append(flagStrs, "W")
}
if flags&unix.RTF_PROTO1 != 0 {
flagStrs = append(flagStrs, "1")
}
if flags&unix.RTF_PROTO2 != 0 {
flagStrs = append(flagStrs, "2")
}
if flags&unix.RTF_PROTO3 != 0 {
flagStrs = append(flagStrs, "3")
}
if len(flagStrs) == 0 {
return "-"

View File

@@ -4,18 +4,17 @@ package systemops
import (
"strings"
"golang.org/x/sys/unix"
"syscall"
)
// filterRoutesByFlags returns true if the route message should be ignored based on its flags.
func filterRoutesByFlags(routeMessageFlags int) bool {
if routeMessageFlags&unix.RTF_UP == 0 {
if routeMessageFlags&syscall.RTF_UP == 0 {
return true
}
// NOTE: RTF_WASCLONED deprecated in FreeBSD 8.0
if routeMessageFlags&(unix.RTF_REJECT|unix.RTF_BLACKHOLE) != 0 {
// NOTE: syscall.RTF_WASCLONED deprecated in FreeBSD 8.0
if routeMessageFlags&(syscall.RTF_REJECT|syscall.RTF_BLACKHOLE) != 0 {
return true
}
@@ -26,46 +25,37 @@ func filterRoutesByFlags(routeMessageFlags int) bool {
func formatBSDFlags(flags int) string {
var flagStrs []string
if flags&unix.RTF_UP != 0 {
if flags&syscall.RTF_UP != 0 {
flagStrs = append(flagStrs, "U")
}
if flags&unix.RTF_GATEWAY != 0 {
if flags&syscall.RTF_GATEWAY != 0 {
flagStrs = append(flagStrs, "G")
}
if flags&unix.RTF_HOST != 0 {
if flags&syscall.RTF_HOST != 0 {
flagStrs = append(flagStrs, "H")
}
if flags&unix.RTF_REJECT != 0 {
if flags&syscall.RTF_REJECT != 0 {
flagStrs = append(flagStrs, "R")
}
if flags&unix.RTF_DYNAMIC != 0 {
if flags&syscall.RTF_DYNAMIC != 0 {
flagStrs = append(flagStrs, "D")
}
if flags&unix.RTF_MODIFIED != 0 {
if flags&syscall.RTF_MODIFIED != 0 {
flagStrs = append(flagStrs, "M")
}
if flags&unix.RTF_STATIC != 0 {
if flags&syscall.RTF_STATIC != 0 {
flagStrs = append(flagStrs, "S")
}
if flags&unix.RTF_LLINFO != 0 {
if flags&syscall.RTF_LLINFO != 0 {
flagStrs = append(flagStrs, "L")
}
if flags&unix.RTF_LOCAL != 0 {
if flags&syscall.RTF_LOCAL != 0 {
flagStrs = append(flagStrs, "l")
}
if flags&unix.RTF_BLACKHOLE != 0 {
if flags&syscall.RTF_BLACKHOLE != 0 {
flagStrs = append(flagStrs, "B")
}
// Note: RTF_CLONING and RTF_WASCLONED deprecated in FreeBSD 8.0
if flags&unix.RTF_PROTO1 != 0 {
flagStrs = append(flagStrs, "1")
}
if flags&unix.RTF_PROTO2 != 0 {
flagStrs = append(flagStrs, "2")
}
if flags&unix.RTF_PROTO3 != 0 {
flagStrs = append(flagStrs, "3")
}
if len(flagStrs) == 0 {
return "-"

View File

@@ -1,5 +0,0 @@
FROM ubuntu:24.04
RUN apt update && apt install -y ca-certificates && rm -fr /var/cache/apt
ENTRYPOINT [ "/go/bin/netbird-server" ]
CMD ["--config", "/etc/netbird/config.yaml"]
COPY netbird-server /go/bin/netbird-server

View File

@@ -1,715 +0,0 @@
package cmd
import (
"context"
"fmt"
"net"
"net/netip"
"os"
"path"
"strings"
"time"
log "github.com/sirupsen/logrus"
"gopkg.in/yaml.v3"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/util/crypt"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
)
// CombinedConfig is the root configuration for the combined server.
// The combined server is primarily a Management server with optional embedded
// Signal, Relay, and STUN services.
//
// Architecture:
// - Management: Always runs locally (this IS the management server)
// - Signal: Runs locally by default; disabled if server.signalUri is set
// - Relay: Runs locally by default; disabled if server.relays is set
// - STUN: Runs locally on port 3478 by default; disabled if server.stuns is set
//
// All user-facing settings are under "server". The relay/signal/management
// fields are internal and populated automatically from server settings.
type CombinedConfig struct {
Server ServerConfig `yaml:"server"`
// Internal configs - populated from Server settings, not user-configurable
Relay RelayConfig `yaml:"-"`
Signal SignalConfig `yaml:"-"`
Management ManagementConfig `yaml:"-"`
}
// ServerConfig contains server-wide settings
// In simplified mode, this contains all configuration
type ServerConfig struct {
ListenAddress string `yaml:"listenAddress"`
MetricsPort int `yaml:"metricsPort"`
HealthcheckAddress string `yaml:"healthcheckAddress"`
LogLevel string `yaml:"logLevel"`
LogFile string `yaml:"logFile"`
TLS TLSConfig `yaml:"tls"`
// Simplified config fields (used when relay/signal/management sections are omitted)
ExposedAddress string `yaml:"exposedAddress"` // Public address with protocol (e.g., "https://example.com:443")
StunPorts []int `yaml:"stunPorts"` // STUN ports (empty to disable local STUN)
AuthSecret string `yaml:"authSecret"` // Shared secret for relay authentication
DataDir string `yaml:"dataDir"` // Data directory for all services
// External service overrides (simplified mode)
// When these are set, the corresponding local service is NOT started
// and these values are used for client configuration instead
Stuns []HostConfig `yaml:"stuns"` // External STUN servers (disables local STUN)
Relays RelaysConfig `yaml:"relays"` // External relay servers (disables local relay)
SignalURI string `yaml:"signalUri"` // External signal server (disables local signal)
// Management settings (simplified mode)
DisableAnonymousMetrics bool `yaml:"disableAnonymousMetrics"`
DisableGeoliteUpdate bool `yaml:"disableGeoliteUpdate"`
Auth AuthConfig `yaml:"auth"`
Store StoreConfig `yaml:"store"`
ReverseProxy ReverseProxyConfig `yaml:"reverseProxy"`
}
// TLSConfig contains TLS/HTTPS settings
type TLSConfig struct {
CertFile string `yaml:"certFile"`
KeyFile string `yaml:"keyFile"`
LetsEncrypt LetsEncryptConfig `yaml:"letsencrypt"`
}
// LetsEncryptConfig contains Let's Encrypt settings
type LetsEncryptConfig struct {
Enabled bool `yaml:"enabled"`
DataDir string `yaml:"dataDir"`
Domains []string `yaml:"domains"`
Email string `yaml:"email"`
AWSRoute53 bool `yaml:"awsRoute53"`
}
// RelayConfig contains relay service settings
type RelayConfig struct {
Enabled bool `yaml:"enabled"`
ExposedAddress string `yaml:"exposedAddress"`
AuthSecret string `yaml:"authSecret"`
LogLevel string `yaml:"logLevel"`
Stun StunConfig `yaml:"stun"`
}
// StunConfig contains embedded STUN service settings
type StunConfig struct {
Enabled bool `yaml:"enabled"`
Ports []int `yaml:"ports"`
LogLevel string `yaml:"logLevel"`
}
// SignalConfig contains signal service settings
type SignalConfig struct {
Enabled bool `yaml:"enabled"`
LogLevel string `yaml:"logLevel"`
}
// ManagementConfig contains management service settings
type ManagementConfig struct {
Enabled bool `yaml:"enabled"`
LogLevel string `yaml:"logLevel"`
DataDir string `yaml:"dataDir"`
DnsDomain string `yaml:"dnsDomain"`
DisableAnonymousMetrics bool `yaml:"disableAnonymousMetrics"`
DisableGeoliteUpdate bool `yaml:"disableGeoliteUpdate"`
DisableDefaultPolicy bool `yaml:"disableDefaultPolicy"`
Auth AuthConfig `yaml:"auth"`
Stuns []HostConfig `yaml:"stuns"`
Relays RelaysConfig `yaml:"relays"`
SignalURI string `yaml:"signalUri"`
Store StoreConfig `yaml:"store"`
ReverseProxy ReverseProxyConfig `yaml:"reverseProxy"`
}
// AuthConfig contains authentication/identity provider settings
type AuthConfig struct {
Issuer string `yaml:"issuer"`
LocalAuthDisabled bool `yaml:"localAuthDisabled"`
SignKeyRefreshEnabled bool `yaml:"signKeyRefreshEnabled"`
Storage AuthStorageConfig `yaml:"storage"`
DashboardRedirectURIs []string `yaml:"dashboardRedirectURIs"`
CLIRedirectURIs []string `yaml:"cliRedirectURIs"`
Owner *AuthOwnerConfig `yaml:"owner,omitempty"`
}
// AuthStorageConfig contains auth storage settings
type AuthStorageConfig struct {
Type string `yaml:"type"`
File string `yaml:"file"`
}
// AuthOwnerConfig contains initial admin user settings
type AuthOwnerConfig struct {
Email string `yaml:"email"`
Password string `yaml:"password"`
}
// HostConfig represents a STUN/TURN/Signal host
type HostConfig struct {
URI string `yaml:"uri"`
Proto string `yaml:"proto,omitempty"` // udp, dtls, tcp, http, https - defaults based on URI scheme
Username string `yaml:"username,omitempty"`
Password string `yaml:"password,omitempty"`
}
// RelaysConfig contains external relay server settings for clients
type RelaysConfig struct {
Addresses []string `yaml:"addresses"`
CredentialsTTL string `yaml:"credentialsTTL"`
Secret string `yaml:"secret"`
}
// StoreConfig contains database settings
type StoreConfig struct {
Engine string `yaml:"engine"`
EncryptionKey string `yaml:"encryptionKey"`
DSN string `yaml:"dsn"` // Connection string for postgres or mysql engines
}
// ReverseProxyConfig contains reverse proxy settings
type ReverseProxyConfig struct {
TrustedHTTPProxies []string `yaml:"trustedHTTPProxies"`
TrustedHTTPProxiesCount uint `yaml:"trustedHTTPProxiesCount"`
TrustedPeers []string `yaml:"trustedPeers"`
}
// DefaultConfig returns a CombinedConfig with default values
func DefaultConfig() *CombinedConfig {
return &CombinedConfig{
Server: ServerConfig{
ListenAddress: ":443",
MetricsPort: 9090,
HealthcheckAddress: ":9000",
LogLevel: "info",
LogFile: "console",
StunPorts: []int{3478},
DataDir: "/var/lib/netbird/",
Auth: AuthConfig{
Storage: AuthStorageConfig{
Type: "sqlite3",
},
},
Store: StoreConfig{
Engine: "sqlite",
},
},
Relay: RelayConfig{
// LogLevel inherited from Server.LogLevel via ApplySimplifiedDefaults
Stun: StunConfig{
Enabled: false,
Ports: []int{3478},
// LogLevel inherited from Server.LogLevel via ApplySimplifiedDefaults
},
},
Signal: SignalConfig{
// LogLevel inherited from Server.LogLevel via ApplySimplifiedDefaults
},
Management: ManagementConfig{
DataDir: "/var/lib/netbird/",
Auth: AuthConfig{
Storage: AuthStorageConfig{
Type: "sqlite3",
},
},
Relays: RelaysConfig{
CredentialsTTL: "12h",
},
Store: StoreConfig{
Engine: "sqlite",
},
},
}
}
// hasRequiredSettings returns true if the configuration has the required server settings
func (c *CombinedConfig) hasRequiredSettings() bool {
return c.Server.ExposedAddress != ""
}
// parseExposedAddress extracts protocol, host, and host:port from the exposed address
// Input format: "https://example.com:443" or "http://example.com:8080" or "example.com:443"
// Returns: protocol ("https" or "http"), hostname only, and host:port
func parseExposedAddress(exposedAddress string) (protocol, hostname, hostPort string) {
// Default to https if no protocol specified
protocol = "https"
hostPort = exposedAddress
// Check for protocol prefix
if strings.HasPrefix(exposedAddress, "https://") {
protocol = "https"
hostPort = strings.TrimPrefix(exposedAddress, "https://")
} else if strings.HasPrefix(exposedAddress, "http://") {
protocol = "http"
hostPort = strings.TrimPrefix(exposedAddress, "http://")
}
// Extract hostname (without port)
hostname = hostPort
if host, _, err := net.SplitHostPort(hostPort); err == nil {
hostname = host
}
return protocol, hostname, hostPort
}
// ApplySimplifiedDefaults populates internal relay/signal/management configs from server settings.
// Management is always enabled. Signal, Relay, and STUN are enabled unless external
// overrides are configured (server.signalUri, server.relays, server.stuns).
func (c *CombinedConfig) ApplySimplifiedDefaults() {
if !c.hasRequiredSettings() {
return
}
// Parse exposed address to extract protocol and hostname
exposedProto, exposedHost, exposedHostPort := parseExposedAddress(c.Server.ExposedAddress)
// Check for external service overrides
hasExternalRelay := len(c.Server.Relays.Addresses) > 0
hasExternalSignal := c.Server.SignalURI != ""
hasExternalStuns := len(c.Server.Stuns) > 0
// Default stunPorts to [3478] if not specified and no external STUN
if len(c.Server.StunPorts) == 0 && !hasExternalStuns {
c.Server.StunPorts = []int{3478}
}
c.applyRelayDefaults(exposedProto, exposedHostPort, hasExternalRelay, hasExternalStuns)
c.applySignalDefaults(hasExternalSignal)
c.applyManagementDefaults(exposedHost)
// Auto-configure client settings (stuns, relays, signalUri)
c.autoConfigureClientSettings(exposedProto, exposedHost, exposedHostPort, hasExternalStuns, hasExternalRelay, hasExternalSignal)
}
// applyRelayDefaults configures the relay service if no external relay is configured.
func (c *CombinedConfig) applyRelayDefaults(exposedProto, exposedHostPort string, hasExternalRelay, hasExternalStuns bool) {
if hasExternalRelay {
return
}
c.Relay.Enabled = true
relayProto := "rel"
if exposedProto == "https" {
relayProto = "rels"
}
c.Relay.ExposedAddress = fmt.Sprintf("%s://%s", relayProto, exposedHostPort)
c.Relay.AuthSecret = c.Server.AuthSecret
if c.Relay.LogLevel == "" {
c.Relay.LogLevel = c.Server.LogLevel
}
// Enable local STUN only if no external STUN servers and stunPorts are configured
if !hasExternalStuns && len(c.Server.StunPorts) > 0 {
c.Relay.Stun.Enabled = true
c.Relay.Stun.Ports = c.Server.StunPorts
if c.Relay.Stun.LogLevel == "" {
c.Relay.Stun.LogLevel = c.Server.LogLevel
}
}
}
// applySignalDefaults configures the signal service if no external signal is configured.
func (c *CombinedConfig) applySignalDefaults(hasExternalSignal bool) {
if hasExternalSignal {
return
}
c.Signal.Enabled = true
if c.Signal.LogLevel == "" {
c.Signal.LogLevel = c.Server.LogLevel
}
}
// applyManagementDefaults configures the management service (always enabled).
func (c *CombinedConfig) applyManagementDefaults(exposedHost string) {
c.Management.Enabled = true
if c.Management.LogLevel == "" {
c.Management.LogLevel = c.Server.LogLevel
}
if c.Management.DataDir == "" || c.Management.DataDir == "/var/lib/netbird/" {
c.Management.DataDir = c.Server.DataDir
}
c.Management.DnsDomain = exposedHost
c.Management.DisableAnonymousMetrics = c.Server.DisableAnonymousMetrics
c.Management.DisableGeoliteUpdate = c.Server.DisableGeoliteUpdate
// Copy auth config from server if management auth issuer is not set
if c.Management.Auth.Issuer == "" && c.Server.Auth.Issuer != "" {
c.Management.Auth = c.Server.Auth
}
// Copy store config from server if not set
if c.Management.Store.Engine == "" || c.Management.Store.Engine == "sqlite" {
if c.Server.Store.Engine != "" {
c.Management.Store = c.Server.Store
}
}
// Copy reverse proxy config from server
if len(c.Server.ReverseProxy.TrustedHTTPProxies) > 0 || c.Server.ReverseProxy.TrustedHTTPProxiesCount > 0 || len(c.Server.ReverseProxy.TrustedPeers) > 0 {
c.Management.ReverseProxy = c.Server.ReverseProxy
}
}
// autoConfigureClientSettings sets up STUN/relay/signal URIs for clients
// External overrides from server config take precedence over auto-generated values
func (c *CombinedConfig) autoConfigureClientSettings(exposedProto, exposedHost, exposedHostPort string, hasExternalStuns, hasExternalRelay, hasExternalSignal bool) {
// Determine relay protocol from exposed protocol
relayProto := "rel"
if exposedProto == "https" {
relayProto = "rels"
}
// Configure STUN servers for clients
if hasExternalStuns {
// Use external STUN servers from server config
c.Management.Stuns = c.Server.Stuns
} else if len(c.Server.StunPorts) > 0 && len(c.Management.Stuns) == 0 {
// Auto-configure local STUN servers for all ports
for _, port := range c.Server.StunPorts {
c.Management.Stuns = append(c.Management.Stuns, HostConfig{
URI: fmt.Sprintf("stun:%s:%d", exposedHost, port),
})
}
}
// Configure relay for clients
if hasExternalRelay {
// Use external relay config from server
c.Management.Relays = c.Server.Relays
} else if len(c.Management.Relays.Addresses) == 0 {
// Auto-configure local relay
c.Management.Relays.Addresses = []string{
fmt.Sprintf("%s://%s", relayProto, exposedHostPort),
}
}
if c.Management.Relays.Secret == "" {
c.Management.Relays.Secret = c.Server.AuthSecret
}
if c.Management.Relays.CredentialsTTL == "" {
c.Management.Relays.CredentialsTTL = "12h"
}
// Configure signal for clients
if hasExternalSignal {
// Use external signal URI from server config
c.Management.SignalURI = c.Server.SignalURI
} else if c.Management.SignalURI == "" {
// Auto-configure local signal
c.Management.SignalURI = fmt.Sprintf("%s://%s", exposedProto, exposedHostPort)
}
}
// LoadConfig loads configuration from a YAML file
func LoadConfig(configPath string) (*CombinedConfig, error) {
cfg := DefaultConfig()
if configPath == "" {
return cfg, nil
}
data, err := os.ReadFile(configPath)
if err != nil {
return nil, fmt.Errorf("failed to read config file: %w", err)
}
if err := yaml.Unmarshal(data, cfg); err != nil {
return nil, fmt.Errorf("failed to parse config file: %w", err)
}
// Populate internal configs from server settings
cfg.ApplySimplifiedDefaults()
return cfg, nil
}
// Validate validates the configuration
func (c *CombinedConfig) Validate() error {
if c.Server.ExposedAddress == "" {
return fmt.Errorf("server.exposedAddress is required")
}
if c.Server.DataDir == "" {
return fmt.Errorf("server.dataDir is required")
}
// Validate STUN ports
seen := make(map[int]bool)
for _, port := range c.Server.StunPorts {
if port <= 0 || port > 65535 {
return fmt.Errorf("invalid server.stunPorts value %d: must be between 1 and 65535", port)
}
if seen[port] {
return fmt.Errorf("duplicate STUN port %d in server.stunPorts", port)
}
seen[port] = true
}
// authSecret is required only if running local relay (no external relay configured)
hasExternalRelay := len(c.Server.Relays.Addresses) > 0
if !hasExternalRelay && c.Server.AuthSecret == "" {
return fmt.Errorf("server.authSecret is required when running local relay")
}
return nil
}
// HasTLSCert returns true if TLS certificate files are configured
func (c *CombinedConfig) HasTLSCert() bool {
return c.Server.TLS.CertFile != "" && c.Server.TLS.KeyFile != ""
}
// HasLetsEncrypt returns true if Let's Encrypt is configured
func (c *CombinedConfig) HasLetsEncrypt() bool {
return c.Server.TLS.LetsEncrypt.Enabled &&
c.Server.TLS.LetsEncrypt.DataDir != "" &&
len(c.Server.TLS.LetsEncrypt.Domains) > 0
}
// parseExplicitProtocol parses an explicit protocol string to nbconfig.Protocol
func parseExplicitProtocol(proto string) (nbconfig.Protocol, bool) {
switch strings.ToLower(proto) {
case "udp":
return nbconfig.UDP, true
case "dtls":
return nbconfig.DTLS, true
case "tcp":
return nbconfig.TCP, true
case "http":
return nbconfig.HTTP, true
case "https":
return nbconfig.HTTPS, true
default:
return "", false
}
}
// parseStunProtocol determines protocol for STUN/TURN servers.
// stun: → UDP, stuns: → DTLS, turn: → UDP, turns: → DTLS
// Explicit proto overrides URI scheme. Defaults to UDP.
func parseStunProtocol(uri, proto string) nbconfig.Protocol {
if proto != "" {
if p, ok := parseExplicitProtocol(proto); ok {
return p
}
}
uri = strings.ToLower(uri)
switch {
case strings.HasPrefix(uri, "stuns:"):
return nbconfig.DTLS
case strings.HasPrefix(uri, "turns:"):
return nbconfig.DTLS
default:
// stun:, turn:, or no scheme - default to UDP
return nbconfig.UDP
}
}
// parseSignalProtocol determines protocol for Signal servers.
// https:// → HTTPS, http:// → HTTP. Defaults to HTTPS.
func parseSignalProtocol(uri string) nbconfig.Protocol {
uri = strings.ToLower(uri)
switch {
case strings.HasPrefix(uri, "http://"):
return nbconfig.HTTP
default:
// https:// or no scheme - default to HTTPS
return nbconfig.HTTPS
}
}
// stripSignalProtocol removes the protocol prefix from a signal URI.
// Returns just the host:port (e.g., "selfhosted2.demo.netbird.io:443").
func stripSignalProtocol(uri string) string {
uri = strings.TrimPrefix(uri, "https://")
uri = strings.TrimPrefix(uri, "http://")
return uri
}
// ToManagementConfig converts CombinedConfig to management server config
func (c *CombinedConfig) ToManagementConfig() (*nbconfig.Config, error) {
mgmt := c.Management
// Build STUN hosts
var stuns []*nbconfig.Host
for _, s := range mgmt.Stuns {
stuns = append(stuns, &nbconfig.Host{
URI: s.URI,
Proto: parseStunProtocol(s.URI, s.Proto),
Username: s.Username,
Password: s.Password,
})
}
// Build relay config
var relayConfig *nbconfig.Relay
if len(mgmt.Relays.Addresses) > 0 || mgmt.Relays.Secret != "" {
var ttl time.Duration
if mgmt.Relays.CredentialsTTL != "" {
var err error
ttl, err = time.ParseDuration(mgmt.Relays.CredentialsTTL)
if err != nil {
return nil, fmt.Errorf("invalid relay credentials TTL %q: %w", mgmt.Relays.CredentialsTTL, err)
}
}
relayConfig = &nbconfig.Relay{
Addresses: mgmt.Relays.Addresses,
CredentialsTTL: util.Duration{Duration: ttl},
Secret: mgmt.Relays.Secret,
}
}
// Build signal config
var signalConfig *nbconfig.Host
if mgmt.SignalURI != "" {
signalConfig = &nbconfig.Host{
URI: stripSignalProtocol(mgmt.SignalURI),
Proto: parseSignalProtocol(mgmt.SignalURI),
}
}
// Build store config
storeConfig := nbconfig.StoreConfig{
Engine: types.Engine(mgmt.Store.Engine),
}
// Build reverse proxy config
reverseProxy := nbconfig.ReverseProxy{
TrustedHTTPProxiesCount: mgmt.ReverseProxy.TrustedHTTPProxiesCount,
}
for _, p := range mgmt.ReverseProxy.TrustedHTTPProxies {
if prefix, err := netip.ParsePrefix(p); err == nil {
reverseProxy.TrustedHTTPProxies = append(reverseProxy.TrustedHTTPProxies, prefix)
}
}
for _, p := range mgmt.ReverseProxy.TrustedPeers {
if prefix, err := netip.ParsePrefix(p); err == nil {
reverseProxy.TrustedPeers = append(reverseProxy.TrustedPeers, prefix)
}
}
// Build HTTP config (required, even if empty)
httpConfig := &nbconfig.HttpServerConfig{}
// Build embedded IDP config (always enabled in combined server)
storageFile := mgmt.Auth.Storage.File
if storageFile == "" {
storageFile = path.Join(mgmt.DataDir, "idp.db")
}
embeddedIdP := &idp.EmbeddedIdPConfig{
Enabled: true,
Issuer: mgmt.Auth.Issuer,
LocalAuthDisabled: mgmt.Auth.LocalAuthDisabled,
SignKeyRefreshEnabled: mgmt.Auth.SignKeyRefreshEnabled,
Storage: idp.EmbeddedStorageConfig{
Type: mgmt.Auth.Storage.Type,
Config: idp.EmbeddedStorageTypeConfig{
File: storageFile,
},
},
DashboardRedirectURIs: mgmt.Auth.DashboardRedirectURIs,
CLIRedirectURIs: mgmt.Auth.CLIRedirectURIs,
}
if mgmt.Auth.Owner != nil && mgmt.Auth.Owner.Email != "" {
embeddedIdP.Owner = &idp.OwnerConfig{
Email: mgmt.Auth.Owner.Email,
Hash: mgmt.Auth.Owner.Password, // Will be hashed if plain text
}
}
// Set HTTP config fields for embedded IDP
httpConfig.AuthIssuer = mgmt.Auth.Issuer
httpConfig.IdpSignKeyRefreshEnabled = mgmt.Auth.SignKeyRefreshEnabled
return &nbconfig.Config{
Stuns: stuns,
Relay: relayConfig,
Signal: signalConfig,
Datadir: mgmt.DataDir,
DataStoreEncryptionKey: mgmt.Store.EncryptionKey,
HttpConfig: httpConfig,
StoreConfig: storeConfig,
ReverseProxy: reverseProxy,
DisableDefaultPolicy: mgmt.DisableDefaultPolicy,
EmbeddedIdP: embeddedIdP,
}, nil
}
// ApplyEmbeddedIdPConfig applies embedded IdP configuration to the management config.
// This mirrors the logic in management/cmd/management.go ApplyEmbeddedIdPConfig.
func ApplyEmbeddedIdPConfig(ctx context.Context, cfg *nbconfig.Config, mgmtPort int, disableSingleAccMode bool) error {
if cfg.EmbeddedIdP == nil || !cfg.EmbeddedIdP.Enabled {
return nil
}
// Embedded IdP requires single account mode
if disableSingleAccMode {
return fmt.Errorf("embedded IdP requires single account mode; multiple account mode is not supported with embedded IdP")
}
// Set LocalAddress for embedded IdP, used for internal JWT validation
cfg.EmbeddedIdP.LocalAddress = fmt.Sprintf("localhost:%d", mgmtPort)
// Set storage defaults based on Datadir
if cfg.EmbeddedIdP.Storage.Type == "" {
cfg.EmbeddedIdP.Storage.Type = "sqlite3"
}
if cfg.EmbeddedIdP.Storage.Config.File == "" && cfg.Datadir != "" {
cfg.EmbeddedIdP.Storage.Config.File = path.Join(cfg.Datadir, "idp.db")
}
issuer := cfg.EmbeddedIdP.Issuer
// Ensure HttpConfig exists
if cfg.HttpConfig == nil {
cfg.HttpConfig = &nbconfig.HttpServerConfig{}
}
// Set HttpConfig values from EmbeddedIdP
cfg.HttpConfig.AuthIssuer = issuer
cfg.HttpConfig.AuthAudience = "netbird-dashboard"
cfg.HttpConfig.CLIAuthAudience = "netbird-cli"
cfg.HttpConfig.AuthUserIDClaim = "sub"
cfg.HttpConfig.AuthKeysLocation = issuer + "/keys"
cfg.HttpConfig.OIDCConfigEndpoint = issuer + "/.well-known/openid-configuration"
cfg.HttpConfig.IdpSignKeyRefreshEnabled = true
return nil
}
// EnsureEncryptionKey generates an encryption key if not set.
// Unlike management server, we don't write back to the config file.
func EnsureEncryptionKey(ctx context.Context, cfg *nbconfig.Config) error {
if cfg.DataStoreEncryptionKey != "" {
return nil
}
log.WithContext(ctx).Infof("DataStoreEncryptionKey is not set, generating a new key")
key, err := crypt.GenerateKey()
if err != nil {
return fmt.Errorf("failed to generate datastore encryption key: %v", err)
}
cfg.DataStoreEncryptionKey = key
keyPreview := key[:8] + "..."
log.WithContext(ctx).Warnf("DataStoreEncryptionKey generated (%s); add it to your config file under 'server.store.encryptionKey' to persist across restarts", keyPreview)
return nil
}
// LogConfigInfo logs informational messages about the loaded configuration
func LogConfigInfo(cfg *nbconfig.Config) {
if cfg.EmbeddedIdP != nil && cfg.EmbeddedIdP.Enabled {
log.Infof("running with the embedded IdP: %v", cfg.EmbeddedIdP.Issuer)
}
if cfg.Relay != nil {
log.Infof("Relay addresses: %v", cfg.Relay.Addresses)
}
}

View File

@@ -1,33 +0,0 @@
//go:build pprof
// +build pprof
package cmd
import (
"net/http"
_ "net/http/pprof"
"os"
log "github.com/sirupsen/logrus"
)
func init() {
addr := pprofAddr()
go pprof(addr)
}
func pprofAddr() string {
listenAddr := os.Getenv("NB_PPROF_ADDR")
if listenAddr == "" {
return "localhost:6969"
}
return listenAddr
}
func pprof(listenAddr string) {
log.Infof("listening pprof on: %s\n", listenAddr)
if err := http.ListenAndServe(listenAddr, nil); err != nil {
log.Fatalf("Failed to start pprof: %v", err)
}
}

View File

@@ -1,711 +0,0 @@
package cmd
import (
"context"
"crypto/sha256"
"crypto/tls"
"errors"
"fmt"
"net"
"net/http"
"os"
"os/signal"
"strconv"
"strings"
"sync"
"syscall"
"time"
"github.com/coder/websocket"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"go.opentelemetry.io/otel/metric"
"google.golang.org/grpc"
"github.com/netbirdio/netbird/encryption"
mgmtServer "github.com/netbirdio/netbird/management/internals/server"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/relay/healthcheck"
relayServer "github.com/netbirdio/netbird/relay/server"
"github.com/netbirdio/netbird/relay/server/listener/ws"
sharedMetrics "github.com/netbirdio/netbird/shared/metrics"
"github.com/netbirdio/netbird/shared/relay/auth"
"github.com/netbirdio/netbird/shared/signal/proto"
signalServer "github.com/netbirdio/netbird/signal/server"
"github.com/netbirdio/netbird/stun"
"github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/util/wsproxy"
wsproxyserver "github.com/netbirdio/netbird/util/wsproxy/server"
)
var (
configPath string
config *CombinedConfig
rootCmd = &cobra.Command{
Use: "combined",
Short: "Combined Netbird server (Management + Signal + Relay + STUN)",
Long: `Combined Netbird server for self-hosted deployments.
All services (Management, Signal, Relay) are multiplexed on a single port.
Optional STUN server runs on separate UDP ports.
Configuration is loaded from a YAML file specified with --config.`,
SilenceUsage: true,
SilenceErrors: true,
RunE: execute,
}
)
func init() {
rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "path to YAML configuration file (required)")
_ = rootCmd.MarkPersistentFlagRequired("config")
}
func Execute() error {
return rootCmd.Execute()
}
func waitForExitSignal() {
osSigs := make(chan os.Signal, 1)
signal.Notify(osSigs, syscall.SIGINT, syscall.SIGTERM)
<-osSigs
}
func execute(cmd *cobra.Command, _ []string) error {
if err := initializeConfig(); err != nil {
return err
}
// Management is required as the base server when signal or relay are enabled
if (config.Signal.Enabled || config.Relay.Enabled) && !config.Management.Enabled {
return fmt.Errorf("management must be enabled when signal or relay are enabled (provides the base HTTP server)")
}
servers, err := createAllServers(cmd.Context(), config)
if err != nil {
return err
}
// Register services with management's gRPC server using AfterInit hook
setupServerHooks(servers, config)
// Start management server (this also starts the HTTP listener)
if servers.mgmtSrv != nil {
if err := servers.mgmtSrv.Start(cmd.Context()); err != nil {
cleanupSTUNListeners(servers.stunListeners)
return fmt.Errorf("failed to start management server: %w", err)
}
}
// Start all other servers
wg := sync.WaitGroup{}
startServers(&wg, servers.relaySrv, servers.healthcheck, servers.stunServer, servers.metricsServer)
waitForExitSignal()
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
err = shutdownServers(ctx, servers.relaySrv, servers.healthcheck, servers.stunServer, servers.mgmtSrv, servers.metricsServer)
wg.Wait()
return err
}
// initializeConfig loads and validates the configuration, then initializes logging.
func initializeConfig() error {
var err error
config, err = LoadConfig(configPath)
if err != nil {
return fmt.Errorf("failed to load config: %w", err)
}
if err := config.Validate(); err != nil {
return fmt.Errorf("invalid config: %w", err)
}
if err := util.InitLog(config.Server.LogLevel, config.Server.LogFile); err != nil {
return fmt.Errorf("failed to initialize log: %w", err)
}
if dsn := config.Server.Store.DSN; dsn != "" {
switch strings.ToLower(config.Server.Store.Engine) {
case "postgres":
os.Setenv("NB_STORE_ENGINE_POSTGRES_DSN", dsn)
case "mysql":
os.Setenv("NB_STORE_ENGINE_MYSQL_DSN", dsn)
}
}
log.Infof("Starting combined NetBird server")
logConfig(config)
logEnvVars()
return nil
}
// serverInstances holds all server instances created during startup.
type serverInstances struct {
relaySrv *relayServer.Server
mgmtSrv *mgmtServer.BaseServer
signalSrv *signalServer.Server
healthcheck *healthcheck.Server
stunServer *stun.Server
stunListeners []*net.UDPConn
metricsServer *sharedMetrics.Metrics
}
// createAllServers creates all server instances based on configuration.
func createAllServers(ctx context.Context, cfg *CombinedConfig) (*serverInstances, error) {
metricsServer, err := sharedMetrics.NewServer(cfg.Server.MetricsPort, "")
if err != nil {
return nil, fmt.Errorf("failed to create metrics server: %w", err)
}
servers := &serverInstances{
metricsServer: metricsServer,
}
_, tlsSupport, err := handleTLSConfig(cfg)
if err != nil {
return nil, fmt.Errorf("failed to setup TLS config: %w", err)
}
if err := servers.createRelayServer(cfg, tlsSupport); err != nil {
return nil, err
}
if err := servers.createManagementServer(ctx, cfg); err != nil {
return nil, err
}
if err := servers.createSignalServer(ctx, cfg); err != nil {
return nil, err
}
if err := servers.createHealthcheckServer(cfg); err != nil {
return nil, err
}
return servers, nil
}
func (s *serverInstances) createRelayServer(cfg *CombinedConfig, tlsSupport bool) error {
if !cfg.Relay.Enabled {
return nil
}
var err error
s.stunListeners, err = createSTUNListeners(cfg)
if err != nil {
return err
}
hashedSecret := sha256.Sum256([]byte(cfg.Relay.AuthSecret))
authenticator := auth.NewTimedHMACValidator(hashedSecret[:], 24*time.Hour)
relayCfg := relayServer.Config{
Meter: s.metricsServer.Meter,
ExposedAddress: cfg.Relay.ExposedAddress,
AuthValidator: authenticator,
TLSSupport: tlsSupport,
}
s.relaySrv, err = createRelayServer(relayCfg, s.stunListeners)
if err != nil {
return err
}
log.Infof("Relay server created")
if len(s.stunListeners) > 0 {
s.stunServer = stun.NewServer(s.stunListeners, cfg.Relay.Stun.LogLevel)
}
return nil
}
func (s *serverInstances) createManagementServer(ctx context.Context, cfg *CombinedConfig) error {
if !cfg.Management.Enabled {
return nil
}
mgmtConfig, err := cfg.ToManagementConfig()
if err != nil {
return fmt.Errorf("failed to create management config: %w", err)
}
_, portStr, portErr := net.SplitHostPort(cfg.Server.ListenAddress)
if portErr != nil {
portStr = "443"
}
mgmtPort, _ := strconv.Atoi(portStr)
if err := ApplyEmbeddedIdPConfig(ctx, mgmtConfig, mgmtPort, false); err != nil {
cleanupSTUNListeners(s.stunListeners)
return fmt.Errorf("failed to apply embedded IdP config: %w", err)
}
if err := EnsureEncryptionKey(ctx, mgmtConfig); err != nil {
cleanupSTUNListeners(s.stunListeners)
return fmt.Errorf("failed to ensure encryption key: %w", err)
}
LogConfigInfo(mgmtConfig)
s.mgmtSrv, err = createManagementServer(cfg, mgmtConfig)
if err != nil {
cleanupSTUNListeners(s.stunListeners)
return fmt.Errorf("failed to create management server: %w", err)
}
// Inject externally-managed AppMetrics so management uses the shared metrics server
appMetrics, err := telemetry.NewAppMetricsWithMeter(ctx, s.metricsServer.Meter)
if err != nil {
cleanupSTUNListeners(s.stunListeners)
return fmt.Errorf("failed to create management app metrics: %w", err)
}
mgmtServer.Inject[telemetry.AppMetrics](s.mgmtSrv, appMetrics)
log.Infof("Management server created")
return nil
}
func (s *serverInstances) createSignalServer(ctx context.Context, cfg *CombinedConfig) error {
if !cfg.Signal.Enabled {
return nil
}
var err error
s.signalSrv, err = signalServer.NewServer(ctx, s.metricsServer.Meter, "signal_")
if err != nil {
cleanupSTUNListeners(s.stunListeners)
return fmt.Errorf("failed to create signal server: %w", err)
}
log.Infof("Signal server created")
return nil
}
func (s *serverInstances) createHealthcheckServer(cfg *CombinedConfig) error {
hCfg := healthcheck.Config{
ListenAddress: cfg.Server.HealthcheckAddress,
ServiceChecker: s.relaySrv,
}
var err error
s.healthcheck, err = createHealthCheck(hCfg, s.stunListeners)
return err
}
// setupServerHooks registers services with management's gRPC server.
func setupServerHooks(servers *serverInstances, cfg *CombinedConfig) {
if servers.mgmtSrv == nil {
return
}
servers.mgmtSrv.AfterInit(func(s *mgmtServer.BaseServer) {
grpcSrv := s.GRPCServer()
if servers.signalSrv != nil {
proto.RegisterSignalExchangeServer(grpcSrv, servers.signalSrv)
log.Infof("Signal server registered on port %s", cfg.Server.ListenAddress)
}
s.SetHandlerFunc(createCombinedHandler(grpcSrv, s.APIHandler(), servers.relaySrv, servers.metricsServer.Meter, cfg))
if servers.relaySrv != nil {
log.Infof("Relay WebSocket handler added (path: /relay)")
}
})
}
func startServers(wg *sync.WaitGroup, srv *relayServer.Server, httpHealthcheck *healthcheck.Server, stunServer *stun.Server, metricsServer *sharedMetrics.Metrics) {
if srv != nil {
instanceURL := srv.InstanceURL()
log.Infof("Relay server instance URL: %s", instanceURL.String())
log.Infof("Relay WebSocket multiplexed on management port (no separate relay listener)")
}
wg.Add(1)
go func() {
defer wg.Done()
log.Infof("running metrics server: %s%s", metricsServer.Addr, metricsServer.Endpoint)
if err := metricsServer.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
log.Fatalf("failed to start metrics server: %v", err)
}
}()
wg.Add(1)
go func() {
defer wg.Done()
if err := httpHealthcheck.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
log.Fatalf("failed to start healthcheck server: %v", err)
}
}()
if stunServer != nil {
wg.Add(1)
go func() {
defer wg.Done()
if err := stunServer.Listen(); err != nil {
if errors.Is(err, stun.ErrServerClosed) {
return
}
log.Errorf("STUN server error: %v", err)
}
}()
}
}
func shutdownServers(ctx context.Context, srv *relayServer.Server, httpHealthcheck *healthcheck.Server, stunServer *stun.Server, mgmtSrv *mgmtServer.BaseServer, metricsServer *sharedMetrics.Metrics) error {
var errs error
if err := httpHealthcheck.Shutdown(ctx); err != nil {
errs = multierror.Append(errs, fmt.Errorf("failed to close healthcheck server: %w", err))
}
if stunServer != nil {
if err := stunServer.Shutdown(); err != nil {
errs = multierror.Append(errs, fmt.Errorf("failed to close STUN server: %w", err))
}
}
if srv != nil {
if err := srv.Shutdown(ctx); err != nil {
errs = multierror.Append(errs, fmt.Errorf("failed to close relay server: %w", err))
}
}
if mgmtSrv != nil {
log.Infof("shutting down management and signal servers")
if err := mgmtSrv.Stop(); err != nil {
errs = multierror.Append(errs, fmt.Errorf("failed to close management server: %w", err))
}
}
if metricsServer != nil {
log.Infof("shutting down metrics server")
if err := metricsServer.Shutdown(ctx); err != nil {
errs = multierror.Append(errs, fmt.Errorf("failed to close metrics server: %w", err))
}
}
return errs
}
func createHealthCheck(hCfg healthcheck.Config, stunListeners []*net.UDPConn) (*healthcheck.Server, error) {
httpHealthcheck, err := healthcheck.NewServer(hCfg)
if err != nil {
cleanupSTUNListeners(stunListeners)
return nil, fmt.Errorf("failed to create healthcheck server: %w", err)
}
return httpHealthcheck, nil
}
func createRelayServer(cfg relayServer.Config, stunListeners []*net.UDPConn) (*relayServer.Server, error) {
srv, err := relayServer.NewServer(cfg)
if err != nil {
cleanupSTUNListeners(stunListeners)
return nil, fmt.Errorf("failed to create relay server: %w", err)
}
return srv, nil
}
func cleanupSTUNListeners(stunListeners []*net.UDPConn) {
for _, l := range stunListeners {
_ = l.Close()
}
}
func createSTUNListeners(cfg *CombinedConfig) ([]*net.UDPConn, error) {
var stunListeners []*net.UDPConn
if cfg.Relay.Stun.Enabled {
for _, port := range cfg.Relay.Stun.Ports {
listener, err := net.ListenUDP("udp", &net.UDPAddr{Port: port})
if err != nil {
cleanupSTUNListeners(stunListeners)
return nil, fmt.Errorf("failed to create STUN listener on port %d: %w", port, err)
}
stunListeners = append(stunListeners, listener)
log.Infof("STUN server listening on UDP port %d", port)
}
}
return stunListeners, nil
}
func handleTLSConfig(cfg *CombinedConfig) (*tls.Config, bool, error) {
tlsCfg := cfg.Server.TLS
if tlsCfg.LetsEncrypt.AWSRoute53 {
log.Debugf("using Let's Encrypt DNS resolver with Route 53 support")
r53 := encryption.Route53TLS{
DataDir: tlsCfg.LetsEncrypt.DataDir,
Email: tlsCfg.LetsEncrypt.Email,
Domains: tlsCfg.LetsEncrypt.Domains,
}
tc, err := r53.GetCertificate()
if err != nil {
return nil, false, err
}
return tc, true, nil
}
if cfg.HasLetsEncrypt() {
log.Infof("setting up TLS with Let's Encrypt")
certManager, err := encryption.CreateCertManager(tlsCfg.LetsEncrypt.DataDir, tlsCfg.LetsEncrypt.Domains...)
if err != nil {
return nil, false, fmt.Errorf("failed creating LetsEncrypt cert manager: %w", err)
}
return certManager.TLSConfig(), true, nil
}
if cfg.HasTLSCert() {
log.Debugf("using file based TLS config")
tc, err := encryption.LoadTLSConfig(tlsCfg.CertFile, tlsCfg.KeyFile)
if err != nil {
return nil, false, err
}
return tc, true, nil
}
return nil, false, nil
}
func createManagementServer(cfg *CombinedConfig, mgmtConfig *nbconfig.Config) (*mgmtServer.BaseServer, error) {
mgmt := cfg.Management
dnsDomain := mgmt.DnsDomain
singleAccModeDomain := dnsDomain
// Extract port from listen address
_, portStr, err := net.SplitHostPort(cfg.Server.ListenAddress)
if err != nil {
// If no port specified, assume default
portStr = "443"
}
mgmtPort, _ := strconv.Atoi(portStr)
mgmtSrv := mgmtServer.NewServer(
mgmtConfig,
dnsDomain,
singleAccModeDomain,
mgmtPort,
cfg.Server.MetricsPort,
mgmt.DisableAnonymousMetrics,
mgmt.DisableGeoliteUpdate,
// Always enable user deletion from IDP in combined server (embedded IdP is always enabled)
true,
)
return mgmtSrv, nil
}
// createCombinedHandler creates an HTTP handler that multiplexes Management, Signal (via wsproxy), and Relay WebSocket traffic
func createCombinedHandler(grpcServer *grpc.Server, httpHandler http.Handler, relaySrv *relayServer.Server, meter metric.Meter, cfg *CombinedConfig) http.Handler {
wsProxy := wsproxyserver.New(grpcServer, wsproxyserver.WithOTelMeter(meter))
var relayAcceptFn func(conn net.Conn)
if relaySrv != nil {
relayAcceptFn = relaySrv.RelayAccept()
}
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
// Native gRPC traffic (HTTP/2 with gRPC content-type)
case r.ProtoMajor == 2 && (strings.HasPrefix(r.Header.Get("Content-Type"), "application/grpc") ||
strings.HasPrefix(r.Header.Get("Content-Type"), "application/grpc+proto")):
grpcServer.ServeHTTP(w, r)
// WebSocket proxy for Management gRPC
case r.URL.Path == wsproxy.ProxyPath+wsproxy.ManagementComponent:
wsProxy.Handler().ServeHTTP(w, r)
// WebSocket proxy for Signal gRPC
case r.URL.Path == wsproxy.ProxyPath+wsproxy.SignalComponent:
if cfg.Signal.Enabled {
wsProxy.Handler().ServeHTTP(w, r)
} else {
http.Error(w, "Signal service not enabled", http.StatusNotFound)
}
// Relay WebSocket
case r.URL.Path == "/relay":
if relayAcceptFn != nil {
handleRelayWebSocket(w, r, relayAcceptFn, cfg)
} else {
http.Error(w, "Relay service not enabled", http.StatusNotFound)
}
// Management HTTP API (default)
default:
httpHandler.ServeHTTP(w, r)
}
})
}
// handleRelayWebSocket handles incoming WebSocket connections for the relay service
func handleRelayWebSocket(w http.ResponseWriter, r *http.Request, acceptFn func(conn net.Conn), cfg *CombinedConfig) {
acceptOptions := &websocket.AcceptOptions{
OriginPatterns: []string{"*"},
}
wsConn, err := websocket.Accept(w, r, acceptOptions)
if err != nil {
log.Errorf("failed to accept relay ws connection: %s", err)
return
}
connRemoteAddr := r.RemoteAddr
if r.Header.Get("X-Real-Ip") != "" && r.Header.Get("X-Real-Port") != "" {
connRemoteAddr = net.JoinHostPort(r.Header.Get("X-Real-Ip"), r.Header.Get("X-Real-Port"))
}
rAddr, err := net.ResolveTCPAddr("tcp", connRemoteAddr)
if err != nil {
_ = wsConn.Close(websocket.StatusInternalError, "internal error")
return
}
lAddr, err := net.ResolveTCPAddr("tcp", cfg.Server.ListenAddress)
if err != nil {
_ = wsConn.Close(websocket.StatusInternalError, "internal error")
return
}
log.Debugf("Relay WS client connected from: %s", rAddr)
conn := ws.NewConn(wsConn, lAddr, rAddr)
acceptFn(conn)
}
// logConfig prints all configuration parameters for debugging
func logConfig(cfg *CombinedConfig) {
log.Info("=== Configuration ===")
logServerConfig(cfg)
logComponentsConfig(cfg)
logRelayConfig(cfg)
logManagementConfig(cfg)
log.Info("=== End Configuration ===")
}
func logServerConfig(cfg *CombinedConfig) {
log.Info("--- Server ---")
log.Infof(" Listen address: %s", cfg.Server.ListenAddress)
log.Infof(" Exposed address: %s", cfg.Server.ExposedAddress)
log.Infof(" Healthcheck address: %s", cfg.Server.HealthcheckAddress)
log.Infof(" Metrics port: %d", cfg.Server.MetricsPort)
log.Infof(" Log level: %s", cfg.Server.LogLevel)
log.Infof(" Data dir: %s", cfg.Server.DataDir)
switch {
case cfg.HasTLSCert():
log.Infof(" TLS: cert=%s, key=%s", cfg.Server.TLS.CertFile, cfg.Server.TLS.KeyFile)
case cfg.HasLetsEncrypt():
log.Infof(" TLS: Let's Encrypt (domains=%v)", cfg.Server.TLS.LetsEncrypt.Domains)
default:
log.Info(" TLS: disabled (using reverse proxy)")
}
}
func logComponentsConfig(cfg *CombinedConfig) {
log.Info("--- Components ---")
log.Infof(" Management: %v (log level: %s)", cfg.Management.Enabled, cfg.Management.LogLevel)
log.Infof(" Signal: %v (log level: %s)", cfg.Signal.Enabled, cfg.Signal.LogLevel)
log.Infof(" Relay: %v (log level: %s)", cfg.Relay.Enabled, cfg.Relay.LogLevel)
}
func logRelayConfig(cfg *CombinedConfig) {
if !cfg.Relay.Enabled {
return
}
log.Info("--- Relay ---")
log.Infof(" Exposed address: %s", cfg.Relay.ExposedAddress)
log.Infof(" Auth secret: %s...", maskSecret(cfg.Relay.AuthSecret))
if cfg.Relay.Stun.Enabled {
log.Infof(" STUN ports: %v (log level: %s)", cfg.Relay.Stun.Ports, cfg.Relay.Stun.LogLevel)
} else {
log.Info(" STUN: disabled")
}
}
func logManagementConfig(cfg *CombinedConfig) {
if !cfg.Management.Enabled {
return
}
log.Info("--- Management ---")
log.Infof(" Data dir: %s", cfg.Management.DataDir)
log.Infof(" DNS domain: %s", cfg.Management.DnsDomain)
log.Infof(" Store engine: %s", cfg.Management.Store.Engine)
if cfg.Server.Store.DSN != "" {
log.Infof(" Store DSN: %s", maskDSNPassword(cfg.Server.Store.DSN))
}
log.Info(" Auth (embedded IdP):")
log.Infof(" Issuer: %s", cfg.Management.Auth.Issuer)
log.Infof(" Dashboard redirect URIs: %v", cfg.Management.Auth.DashboardRedirectURIs)
log.Infof(" CLI redirect URIs: %v", cfg.Management.Auth.CLIRedirectURIs)
log.Info(" Client settings:")
log.Infof(" Signal URI: %s", cfg.Management.SignalURI)
for _, s := range cfg.Management.Stuns {
log.Infof(" STUN: %s", s.URI)
}
if len(cfg.Management.Relays.Addresses) > 0 {
log.Infof(" Relay addresses: %v", cfg.Management.Relays.Addresses)
log.Infof(" Relay credentials TTL: %s", cfg.Management.Relays.CredentialsTTL)
}
}
// logEnvVars logs all NB_ environment variables that are currently set
func logEnvVars() {
log.Info("=== Environment Variables ===")
found := false
for _, env := range os.Environ() {
if strings.HasPrefix(env, "NB_") {
key, _, _ := strings.Cut(env, "=")
value := os.Getenv(key)
if strings.Contains(strings.ToLower(key), "secret") || strings.Contains(strings.ToLower(key), "key") || strings.Contains(strings.ToLower(key), "password") {
value = maskSecret(value)
}
log.Infof(" %s=%s", key, value)
found = true
}
}
if !found {
log.Info(" (none set)")
}
log.Info("=== End Environment Variables ===")
}
// maskDSNPassword masks the password in a DSN string.
// Handles both key=value format ("password=secret") and URI format ("user:secret@host").
func maskDSNPassword(dsn string) string {
// Key=value format: "host=localhost user=nb password=secret dbname=nb"
if strings.Contains(dsn, "password=") {
parts := strings.Fields(dsn)
for i, p := range parts {
if strings.HasPrefix(p, "password=") {
parts[i] = "password=****"
}
}
return strings.Join(parts, " ")
}
// URI format: "user:password@host..."
if atIdx := strings.Index(dsn, "@"); atIdx != -1 {
prefix := dsn[:atIdx]
if colonIdx := strings.Index(prefix, ":"); colonIdx != -1 {
return prefix[:colonIdx+1] + "****" + dsn[atIdx:]
}
}
return dsn
}
// maskSecret returns first 4 chars of secret followed by "..."
func maskSecret(secret string) string {
if len(secret) <= 4 {
return "****"
}
return secret[:4] + "..."
}

View File

@@ -1,219 +0,0 @@
package cmd
import (
"context"
"fmt"
"os"
"strconv"
"text/tabwriter"
"time"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/formatter/hook"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/util"
)
var (
tokenName string
tokenExpireIn string
tokenDatadir string
tokenCmd = &cobra.Command{
Use: "token",
Short: "Manage proxy access tokens",
Long: "Commands for creating, listing, and revoking proxy access tokens used by reverse proxy instances to authenticate with the management server.",
}
tokenCreateCmd = &cobra.Command{
Use: "create",
Short: "Create a new proxy access token",
Long: "Creates a new proxy access token. The plain text token is displayed only once at creation time.",
RunE: tokenCreateRun,
}
tokenListCmd = &cobra.Command{
Use: "list",
Aliases: []string{"ls"},
Short: "List all proxy access tokens",
Long: "Lists all proxy access tokens with their IDs, names, creation dates, expiration, and revocation status.",
RunE: tokenListRun,
}
tokenRevokeCmd = &cobra.Command{
Use: "revoke [token-id]",
Short: "Revoke a proxy access token",
Long: "Revokes a proxy access token by its ID. Revoked tokens can no longer be used for authentication.",
Args: cobra.ExactArgs(1),
RunE: tokenRevokeRun,
}
)
func init() {
tokenCmd.PersistentFlags().StringVar(&tokenDatadir, "datadir", "", "Override the data directory from config (where store.db is located)")
tokenCreateCmd.Flags().StringVar(&tokenName, "name", "", "Name for the token (required)")
tokenCreateCmd.Flags().StringVar(&tokenExpireIn, "expires-in", "", "Token expiration duration (e.g., 365d, 24h, 30d). Empty means no expiration")
tokenCreateCmd.MarkFlagRequired("name") //nolint
tokenCmd.AddCommand(tokenCreateCmd, tokenListCmd, tokenRevokeCmd)
rootCmd.AddCommand(tokenCmd)
}
// withTokenStore initializes logging, loads config, opens the store, and calls fn.
func withTokenStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Store) error) error {
if err := util.InitLog("error", "console"); err != nil {
return fmt.Errorf("init log: %w", err)
}
//nolint
ctx := context.WithValue(cmd.Context(), hook.ExecutionContextKey, hook.SystemSource)
// Load combined server YAML config
cfg, err := LoadConfig(configPath)
if err != nil {
return fmt.Errorf("load config: %w", err)
}
// Get datadir from config or override
datadir := cfg.Server.DataDir
if tokenDatadir != "" {
datadir = tokenDatadir
}
// Get store engine from config
storeEngine := types.Engine(cfg.Server.Store.Engine)
if storeEngine == "" {
storeEngine = "sqlite"
}
s, err := store.NewStore(ctx, storeEngine, datadir, nil, true)
if err != nil {
return fmt.Errorf("create store: %w", err)
}
defer func() {
if err := s.Close(ctx); err != nil {
log.Debugf("close store: %v", err)
}
}()
return fn(ctx, s)
}
func tokenCreateRun(cmd *cobra.Command, _ []string) error {
return withTokenStore(cmd, func(ctx context.Context, s store.Store) error {
expiresIn, err := parseDuration(tokenExpireIn)
if err != nil {
return fmt.Errorf("parse expiration: %w", err)
}
generated, err := types.CreateNewProxyAccessToken(tokenName, expiresIn, nil, "CLI")
if err != nil {
return fmt.Errorf("generate token: %w", err)
}
if err := s.SaveProxyAccessToken(ctx, &generated.ProxyAccessToken); err != nil {
return fmt.Errorf("save token: %w", err)
}
fmt.Println("Token created successfully!") //nolint:forbidigo
fmt.Printf("Token: %s\n", generated.PlainToken) //nolint:forbidigo
fmt.Println() //nolint:forbidigo
fmt.Println("IMPORTANT: Save this token now. It will not be shown again.") //nolint:forbidigo
fmt.Printf("Token ID: %s\n", generated.ID) //nolint:forbidigo
return nil
})
}
func tokenListRun(cmd *cobra.Command, _ []string) error {
return withTokenStore(cmd, func(ctx context.Context, s store.Store) error {
tokens, err := s.GetAllProxyAccessTokens(ctx, store.LockingStrengthNone)
if err != nil {
return fmt.Errorf("list tokens: %w", err)
}
if len(tokens) == 0 {
fmt.Println("No proxy access tokens found.") //nolint:forbidigo
return nil
}
w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0)
fmt.Fprintln(w, "ID\tNAME\tCREATED\tEXPIRES\tLAST USED\tREVOKED")
fmt.Fprintln(w, "--\t----\t-------\t-------\t---------\t-------")
for _, t := range tokens {
expires := "never"
if t.ExpiresAt != nil {
expires = t.ExpiresAt.Format("2006-01-02")
}
lastUsed := "never"
if t.LastUsed != nil {
lastUsed = t.LastUsed.Format("2006-01-02 15:04")
}
revoked := "no"
if t.Revoked {
revoked = "yes"
}
fmt.Fprintf(w, "%s\t%s\t%s\t%s\t%s\t%s\n",
t.ID,
t.Name,
t.CreatedAt.Format("2006-01-02"),
expires,
lastUsed,
revoked,
)
}
w.Flush()
return nil
})
}
func tokenRevokeRun(cmd *cobra.Command, args []string) error {
return withTokenStore(cmd, func(ctx context.Context, s store.Store) error {
tokenID := args[0]
if err := s.RevokeProxyAccessToken(ctx, tokenID); err != nil {
return fmt.Errorf("revoke token: %w", err)
}
fmt.Printf("Token %s revoked successfully.\n", tokenID) //nolint:forbidigo
return nil
})
}
// parseDuration parses a duration string with support for days (e.g., "30d", "365d").
// An empty string returns zero duration (no expiration).
func parseDuration(s string) (time.Duration, error) {
if len(s) == 0 {
return 0, nil
}
if s[len(s)-1] == 'd' {
d, err := strconv.Atoi(s[:len(s)-1])
if err != nil {
return 0, fmt.Errorf("invalid day format: %s", s)
}
if d <= 0 {
return 0, fmt.Errorf("duration must be positive: %s", s)
}
return time.Duration(d) * 24 * time.Hour, nil
}
d, err := time.ParseDuration(s)
if err != nil {
return 0, err
}
if d <= 0 {
return 0, fmt.Errorf("duration must be positive: %s", s)
}
return d, nil
}

View File

@@ -1,111 +0,0 @@
# NetBird Combined Server Configuration
# Copy this file to config.yaml and customize for your deployment
#
# This is a Management server with optional embedded Signal, Relay, and STUN services.
# By default, all services run locally. You can use external services instead by
# setting the corresponding override fields.
#
# Architecture:
# - Management: Always runs locally (this IS the management server)
# - Signal: Local by default; set 'signalUri' to use external (disables local)
# - Relay: Local by default; set 'relays' to use external (disables local)
# - STUN: Local on port 3478 by default; set 'stuns' to use external instead
server:
# Main HTTP/gRPC port for all services (Management, Signal, Relay)
listenAddress: ":443"
# Public address that peers will use to connect to this server
# Used for relay connections and management DNS domain
# Format: protocol://hostname:port (e.g., https://server.mycompany.com:443)
exposedAddress: "https://server.mycompany.com:443"
# STUN server ports (defaults to [3478] if not specified; set 'stuns' to use external)
# stunPorts:
# - 3478
# Metrics endpoint port
metricsPort: 9090
# Healthcheck endpoint address
healthcheckAddress: ":9000"
# Logging configuration
logLevel: "info" # Default log level for all components: panic, fatal, error, warn, info, debug, trace
logFile: "console" # "console" or path to log file
# TLS configuration (optional)
tls:
certFile: ""
keyFile: ""
letsencrypt:
enabled: false
dataDir: ""
domains: []
email: ""
awsRoute53: false
# Shared secret for relay authentication (required when running local relay)
authSecret: "your-secret-key-here"
# Data directory for all services
dataDir: "/var/lib/netbird/"
# ============================================================================
# External Service Overrides (optional)
# Use these to point to external Signal, Relay, or STUN servers instead of
# running them locally. When set, the corresponding local service is disabled.
# ============================================================================
# External STUN servers - disables local STUN server
# stuns:
# - uri: "stun:stun.example.com:3478"
# - uri: "stun:stun.example.com:3479"
# External relay servers - disables local relay server
# relays:
# addresses:
# - "rels://relay.example.com:443"
# credentialsTTL: "12h"
# secret: "relay-shared-secret"
# External signal server - disables local signal server
# signalUri: "https://signal.example.com:443"
# ============================================================================
# Management Settings
# ============================================================================
# Metrics and updates
disableAnonymousMetrics: false
disableGeoliteUpdate: false
# Embedded authentication/identity provider (Dex) configuration (always enabled)
auth:
# OIDC issuer URL - must be publicly accessible
issuer: "https://server.mycompany.com/oauth2"
localAuthDisabled: false
signKeyRefreshEnabled: false
# OAuth2 redirect URIs for dashboard
dashboardRedirectURIs:
- "https://app.netbird.io/nb-auth"
- "https://app.netbird.io/nb-silent-auth"
# OAuth2 redirect URIs for CLI
cliRedirectURIs:
- "http://localhost:53000/"
# Optional initial admin user
# owner:
# email: "admin@example.com"
# password: "initial-password"
# Store configuration
store:
engine: "sqlite" # sqlite, postgres, or mysql
dsn: "" # Connection string for postgres or mysql
encryptionKey: ""
# Reverse proxy settings (optional)
# reverseProxy:
# trustedHTTPProxies: []
# trustedHTTPProxiesCount: 0
# trustedPeers: []

View File

@@ -1,115 +0,0 @@
# Simplified Combined NetBird Server Configuration
# Copy this file to config.yaml and customize for your deployment
# Server-wide settings
server:
# Main HTTP/gRPC port for all services (Management, Signal, Relay)
listenAddress: ":443"
# Metrics endpoint port
metricsPort: 9090
# Healthcheck endpoint address
healthcheckAddress: ":9000"
# Logging configuration
logLevel: "info" # panic, fatal, error, warn, info, debug, trace
logFile: "console" # "console" or path to log file
# TLS configuration (optional)
tls:
certFile: ""
keyFile: ""
letsencrypt:
enabled: false
dataDir: ""
domains: []
email: ""
awsRoute53: false
# Relay service configuration
relay:
# Enable/disable the relay service
enabled: true
# Public address that peers will use to connect to this relay
# Format: hostname:port or ip:port
exposedAddress: "relay.example.com:443"
# Shared secret for relay authentication (required when enabled)
authSecret: "your-secret-key-here"
# Log level for relay (reserved for future use, currently uses global log level)
logLevel: "info"
# Embedded STUN server (optional)
stun:
enabled: false
ports: [3478]
logLevel: "info"
# Signal service configuration
signal:
# Enable/disable the signal service
enabled: true
# Log level for signal (reserved for future use, currently uses global log level)
logLevel: "info"
# Management service configuration
management:
# Enable/disable the management service
enabled: true
# Data directory for management service
dataDir: "/var/lib/netbird/"
# DNS domain for the management server
dnsDomain: ""
# Metrics and updates
disableAnonymousMetrics: false
disableGeoliteUpdate: false
auth:
# OIDC issuer URL - must be publicly accessible
issuer: "https://management.example.com/oauth2"
localAuthDisabled: false
signKeyRefreshEnabled: false
# OAuth2 redirect URIs for dashboard
dashboardRedirectURIs:
- "https://app.example.com/nb-auth"
- "https://app.example.com/nb-silent-auth"
# OAuth2 redirect URIs for CLI
cliRedirectURIs:
- "http://localhost:53000/"
# Optional initial admin user
# owner:
# email: "admin@example.com"
# password: "initial-password"
# External STUN servers (for client config)
stuns: []
# - uri: "stun:stun.example.com:3478"
# External relay servers (for client config)
relays:
addresses: []
# - "rels://relay.example.com:443"
credentialsTTL: "12h"
secret: ""
# External signal server URI (for client config)
signalUri: ""
# Store configuration
store:
engine: "sqlite" # sqlite, postgres, or mysql
dsn: "" # Connection string for postgres or mysql
encryptionKey: ""
# Reverse proxy settings
reverseProxy:
trustedHTTPProxies: []
trustedHTTPProxiesCount: 0
trustedPeers: []

View File

@@ -1,13 +0,0 @@
package main
import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/combined/cmd"
)
func main() {
if err := cmd.Execute(); err != nil {
log.Fatalf("failed to execute command: %v", err)
}
}

4
go.mod
View File

@@ -42,7 +42,6 @@ require (
github.com/cilium/ebpf v0.15.0
github.com/coder/websocket v1.8.13
github.com/coreos/go-iptables v0.7.0
github.com/coreos/go-oidc/v3 v3.14.1
github.com/creack/pty v1.1.24
github.com/dexidp/dex v0.0.0-00010101000000-000000000000
github.com/dexidp/dex/api/v2 v2.4.0
@@ -69,7 +68,7 @@ require (
github.com/mdlayher/socket v0.5.1
github.com/miekg/dns v1.1.59
github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/netbirdio/management-integrations/integrations v0.0.0-20260210160626-df4b180c7b25
github.com/netbirdio/management-integrations/integrations v0.0.0-20260122111742-a6f99668844f
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45
github.com/oapi-codegen/runtime v1.1.2
github.com/okta/okta-sdk-golang/v2 v2.18.0
@@ -168,6 +167,7 @@ require (
github.com/containerd/containerd v1.7.29 // indirect
github.com/containerd/log v0.1.0 // indirect
github.com/containerd/platforms v0.2.1 // indirect
github.com/coreos/go-oidc/v3 v3.14.1 // indirect
github.com/cpuguy83/dockercfg v0.3.2 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect

4
go.sum
View File

@@ -406,8 +406,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI=
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8=
github.com/netbirdio/management-integrations/integrations v0.0.0-20260210160626-df4b180c7b25 h1:iwAq/Ncaq0etl4uAlVsbNBzC1yY52o0AmY7uCm2AMTs=
github.com/netbirdio/management-integrations/integrations v0.0.0-20260210160626-df4b180c7b25/go.mod h1:y7CxagMYzg9dgu+masRqYM7BQlOGA5Y8US85MCNFPlY=
github.com/netbirdio/management-integrations/integrations v0.0.0-20260122111742-a6f99668844f h1:CTBf0je/FpKr2lVSMZLak7m8aaWcS6ur4SOfhSSazFI=
github.com/netbirdio/management-integrations/integrations v0.0.0-20260122111742-a6f99668844f/go.mod h1:y7CxagMYzg9dgu+masRqYM7BQlOGA5Y8US85MCNFPlY=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=

View File

@@ -327,60 +327,6 @@ func ensureLocalConnector(ctx context.Context, stor storage.Storage) error {
return nil
}
// HasNonLocalConnectors checks if there are any connectors other than the local connector.
func (p *Provider) HasNonLocalConnectors(ctx context.Context) (bool, error) {
connectors, err := p.storage.ListConnectors(ctx)
if err != nil {
return false, fmt.Errorf("failed to list connectors: %w", err)
}
p.logger.Info("checking for non-local connectors", "total_connectors", len(connectors))
for _, conn := range connectors {
p.logger.Info("found connector in storage", "id", conn.ID, "type", conn.Type, "name", conn.Name)
if conn.ID != "local" || conn.Type != "local" {
p.logger.Info("found non-local connector", "id", conn.ID)
return true, nil
}
}
p.logger.Info("no non-local connectors found")
return false, nil
}
// DisableLocalAuth removes the local (password) connector.
// Returns an error if no other connectors are configured.
func (p *Provider) DisableLocalAuth(ctx context.Context) error {
hasOthers, err := p.HasNonLocalConnectors(ctx)
if err != nil {
return err
}
if !hasOthers {
return fmt.Errorf("cannot disable local authentication: no other identity providers configured")
}
// Check if local connector exists
_, err = p.storage.GetConnector(ctx, "local")
if errors.Is(err, storage.ErrNotFound) {
// Already disabled
return nil
}
if err != nil {
return fmt.Errorf("failed to check local connector: %w", err)
}
// Delete the local connector
if err := p.storage.DeleteConnector(ctx, "local"); err != nil {
return fmt.Errorf("failed to delete local connector: %w", err)
}
p.logger.Info("local authentication disabled")
return nil
}
// EnableLocalAuth creates the local (password) connector if it doesn't exist.
func (p *Provider) EnableLocalAuth(ctx context.Context) error {
return ensureLocalConnector(ctx, p.storage)
}
// ensureStaticConnectors creates or updates static connectors in storage
func ensureStaticConnectors(ctx context.Context, stor storage.Storage, connectors []Connector) error {
for _, conn := range connectors {

File diff suppressed because it is too large Load Diff

View File

@@ -1,17 +0,0 @@
FROM golang:1.25-bookworm AS builder
WORKDIR /app
# Install build dependencies
RUN apt-get update && apt-get install -y gcc libc6-dev && rm -rf /var/lib/apt/lists/*
COPY go.mod go.sum ./
RUN go mod download
COPY . .
RUN CGO_ENABLED=1 GOOS=linux go build -ldflags="-s -w" -o netbird-mgmt ./management
FROM ubuntu:24.04
RUN apt update && apt install -y ca-certificates && rm -fr /var/cache/apt
ENTRYPOINT [ "/go/bin/netbird-mgmt","management"]
CMD ["--log-file", "console"]
COPY --from=builder /app/netbird-mgmt /go/bin/netbird-mgmt

View File

@@ -19,8 +19,6 @@ import (
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/formatter/hook"
"github.com/netbirdio/netbird/management/internals/server"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
@@ -57,7 +55,7 @@ var (
// detect whether user specified a port
userPort := cmd.Flag("port").Changed
config, err = LoadMgmtConfig(ctx, nbconfig.MgmtConfigPath)
config, err = loadMgmtConfig(ctx, nbconfig.MgmtConfigPath)
if err != nil {
return fmt.Errorf("failed reading provided config file: %s: %v", nbconfig.MgmtConfigPath, err)
}
@@ -135,35 +133,35 @@ var (
}
)
func LoadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*nbconfig.Config, error) {
func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*nbconfig.Config, error) {
loadedConfig := &nbconfig.Config{}
if _, err := util.ReadJsonWithEnvSub(mgmtConfigPath, loadedConfig); err != nil {
return nil, err
}
ApplyCommandLineOverrides(loadedConfig)
applyCommandLineOverrides(loadedConfig)
// Apply EmbeddedIdP config to HttpConfig if embedded IdP is enabled
err := ApplyEmbeddedIdPConfig(ctx, loadedConfig)
err := applyEmbeddedIdPConfig(ctx, loadedConfig)
if err != nil {
return nil, err
}
if err := ApplyOIDCConfig(ctx, loadedConfig); err != nil {
if err := applyOIDCConfig(ctx, loadedConfig); err != nil {
return nil, err
}
LogConfigInfo(loadedConfig)
logConfigInfo(loadedConfig)
if err := EnsureEncryptionKey(ctx, mgmtConfigPath, loadedConfig); err != nil {
if err := ensureEncryptionKey(ctx, mgmtConfigPath, loadedConfig); err != nil {
return nil, err
}
return loadedConfig, nil
}
// ApplyCommandLineOverrides applies command-line flag overrides to the config
func ApplyCommandLineOverrides(cfg *nbconfig.Config) {
// applyCommandLineOverrides applies command-line flag overrides to the config
func applyCommandLineOverrides(cfg *nbconfig.Config) {
if mgmtLetsencryptDomain != "" {
cfg.HttpConfig.LetsEncryptDomain = mgmtLetsencryptDomain
}
@@ -176,9 +174,9 @@ func ApplyCommandLineOverrides(cfg *nbconfig.Config) {
}
}
// ApplyEmbeddedIdPConfig populates HttpConfig and EmbeddedIdP storage from config when embedded IdP is enabled.
// applyEmbeddedIdPConfig populates HttpConfig and EmbeddedIdP storage from config when embedded IdP is enabled.
// This allows users to only specify EmbeddedIdP config without duplicating values in HttpConfig.
func ApplyEmbeddedIdPConfig(ctx context.Context, cfg *nbconfig.Config) error {
func applyEmbeddedIdPConfig(ctx context.Context, cfg *nbconfig.Config) error {
if cfg.EmbeddedIdP == nil || !cfg.EmbeddedIdP.Enabled {
return nil
}
@@ -215,20 +213,17 @@ func ApplyEmbeddedIdPConfig(ctx context.Context, cfg *nbconfig.Config) error {
// Set HttpConfig values from EmbeddedIdP
cfg.HttpConfig.AuthIssuer = issuer
cfg.HttpConfig.AuthAudience = "netbird-dashboard"
cfg.HttpConfig.AuthClientID = cfg.HttpConfig.AuthAudience
cfg.HttpConfig.CLIAuthAudience = "netbird-cli"
cfg.HttpConfig.AuthUserIDClaim = "sub"
cfg.HttpConfig.AuthKeysLocation = issuer + "/keys"
cfg.HttpConfig.OIDCConfigEndpoint = issuer + "/.well-known/openid-configuration"
cfg.HttpConfig.IdpSignKeyRefreshEnabled = true
callbackURL := strings.TrimSuffix(cfg.HttpConfig.AuthIssuer, "/oauth2")
cfg.HttpConfig.AuthCallbackURL = callbackURL + types.ProxyCallbackEndpointFull
return nil
}
// ApplyOIDCConfig fetches and applies OIDC configuration if endpoint is specified
func ApplyOIDCConfig(ctx context.Context, cfg *nbconfig.Config) error {
// applyOIDCConfig fetches and applies OIDC configuration if endpoint is specified
func applyOIDCConfig(ctx context.Context, cfg *nbconfig.Config) error {
oidcEndpoint := cfg.HttpConfig.OIDCConfigEndpoint
if oidcEndpoint == "" {
return nil
@@ -254,16 +249,16 @@ func ApplyOIDCConfig(ctx context.Context, cfg *nbconfig.Config) error {
oidcConfig.JwksURI, cfg.HttpConfig.AuthKeysLocation)
cfg.HttpConfig.AuthKeysLocation = oidcConfig.JwksURI
if err := ApplyDeviceAuthFlowConfig(ctx, cfg, &oidcConfig, oidcEndpoint); err != nil {
if err := applyDeviceAuthFlowConfig(ctx, cfg, &oidcConfig, oidcEndpoint); err != nil {
return err
}
ApplyPKCEFlowConfig(ctx, cfg, &oidcConfig)
applyPKCEFlowConfig(ctx, cfg, &oidcConfig)
return nil
}
// ApplyDeviceAuthFlowConfig applies OIDC config to DeviceAuthorizationFlow if enabled
func ApplyDeviceAuthFlowConfig(ctx context.Context, cfg *nbconfig.Config, oidcConfig *OIDCConfigResponse, oidcEndpoint string) error {
// applyDeviceAuthFlowConfig applies OIDC config to DeviceAuthorizationFlow if enabled
func applyDeviceAuthFlowConfig(ctx context.Context, cfg *nbconfig.Config, oidcConfig *OIDCConfigResponse, oidcEndpoint string) error {
if cfg.DeviceAuthorizationFlow == nil || strings.ToLower(cfg.DeviceAuthorizationFlow.Provider) == string(nbconfig.NONE) {
return nil
}
@@ -290,8 +285,8 @@ func ApplyDeviceAuthFlowConfig(ctx context.Context, cfg *nbconfig.Config, oidcCo
return nil
}
// ApplyPKCEFlowConfig applies OIDC config to PKCEAuthorizationFlow if configured
func ApplyPKCEFlowConfig(ctx context.Context, cfg *nbconfig.Config, oidcConfig *OIDCConfigResponse) {
// applyPKCEFlowConfig applies OIDC config to PKCEAuthorizationFlow if configured
func applyPKCEFlowConfig(ctx context.Context, cfg *nbconfig.Config, oidcConfig *OIDCConfigResponse) {
if cfg.PKCEAuthorizationFlow == nil {
return
}
@@ -304,8 +299,8 @@ func ApplyPKCEFlowConfig(ctx context.Context, cfg *nbconfig.Config, oidcConfig *
cfg.PKCEAuthorizationFlow.ProviderConfig.AuthorizationEndpoint = oidcConfig.AuthorizationEndpoint
}
// LogConfigInfo logs informational messages about the loaded configuration
func LogConfigInfo(cfg *nbconfig.Config) {
// logConfigInfo logs informational messages about the loaded configuration
func logConfigInfo(cfg *nbconfig.Config) {
if cfg.EmbeddedIdP != nil {
log.Infof("running with the embedded IdP: %v", cfg.EmbeddedIdP.Issuer)
}
@@ -314,8 +309,8 @@ func LogConfigInfo(cfg *nbconfig.Config) {
}
}
// EnsureEncryptionKey generates and saves a DataStoreEncryptionKey if not set
func EnsureEncryptionKey(ctx context.Context, configPath string, cfg *nbconfig.Config) error {
// ensureEncryptionKey generates and saves a DataStoreEncryptionKey if not set
func ensureEncryptionKey(ctx context.Context, configPath string, cfg *nbconfig.Config) error {
if cfg.DataStoreEncryptionKey != "" {
return nil
}

View File

@@ -30,7 +30,7 @@ func Test_loadMgmtConfig(t *testing.T) {
t.Fatalf("failed to create config: %s", err)
}
cfg, err := LoadMgmtConfig(context.Background(), tmpFile)
cfg, err := loadMgmtConfig(context.Background(), tmpFile)
if err != nil {
t.Fatalf("failed to load management config: %s", err)
}

View File

@@ -80,10 +80,4 @@ func init() {
migrationCmd.AddCommand(upCmd)
rootCmd.AddCommand(migrationCmd)
tokenCmd.PersistentFlags().StringVar(&nbconfig.MgmtConfigPath, "config", defaultMgmtConfig, "Netbird config file location")
tokenCmd.AddCommand(tokenCreateCmd)
tokenCmd.AddCommand(tokenListCmd)
tokenCmd.AddCommand(tokenRevokeCmd)
rootCmd.AddCommand(tokenCmd)
}

View File

@@ -1,209 +0,0 @@
package cmd
import (
"context"
"fmt"
"os"
"strconv"
"text/tabwriter"
"time"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/formatter/hook"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/util"
)
var (
tokenName string
tokenExpireIn string
tokenDatadir string
tokenCmd = &cobra.Command{
Use: "token",
Short: "Manage proxy access tokens",
Long: "Commands for creating, listing, and revoking proxy access tokens used by reverse proxy instances to authenticate with the management server.",
}
tokenCreateCmd = &cobra.Command{
Use: "create",
Short: "Create a new proxy access token",
Long: "Creates a new proxy access token. The plain text token is displayed only once at creation time.",
RunE: tokenCreateRun,
}
tokenListCmd = &cobra.Command{
Use: "list",
Aliases: []string{"ls"},
Short: "List all proxy access tokens",
Long: "Lists all proxy access tokens with their IDs, names, creation dates, expiration, and revocation status.",
RunE: tokenListRun,
}
tokenRevokeCmd = &cobra.Command{
Use: "revoke [token-id]",
Short: "Revoke a proxy access token",
Long: "Revokes a proxy access token by its ID. Revoked tokens can no longer be used for authentication.",
Args: cobra.ExactArgs(1),
RunE: tokenRevokeRun,
}
)
func init() {
tokenCmd.PersistentFlags().StringVar(&tokenDatadir, "datadir", "", "Override the data directory from config (where store.db is located)")
tokenCreateCmd.Flags().StringVar(&tokenName, "name", "", "Name for the token (required)")
tokenCreateCmd.Flags().StringVar(&tokenExpireIn, "expires-in", "", "Token expiration duration (e.g., 365d, 24h, 30d). Empty means no expiration")
tokenCreateCmd.MarkFlagRequired("name") //nolint
}
// withTokenStore initializes logging, loads config, opens the store, and calls fn.
func withTokenStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Store) error) error {
if err := util.InitLog("error", "console"); err != nil {
return fmt.Errorf("init log: %w", err)
}
//nolint
ctx := context.WithValue(cmd.Context(), hook.ExecutionContextKey, hook.SystemSource)
config, err := loadMgmtConfig(ctx, nbconfig.MgmtConfigPath)
if err != nil {
return fmt.Errorf("load config: %w", err)
}
datadir := config.Datadir
if tokenDatadir != "" {
datadir = tokenDatadir
}
s, err := store.NewStore(ctx, config.StoreConfig.Engine, datadir, nil, true)
if err != nil {
return fmt.Errorf("create store: %w", err)
}
defer func() {
if err := s.Close(ctx); err != nil {
log.Debugf("close store: %v", err)
}
}()
return fn(ctx, s)
}
func tokenCreateRun(cmd *cobra.Command, _ []string) error {
return withTokenStore(cmd, func(ctx context.Context, s store.Store) error {
expiresIn, err := parseDuration(tokenExpireIn)
if err != nil {
return fmt.Errorf("parse expiration: %w", err)
}
generated, err := types.CreateNewProxyAccessToken(tokenName, expiresIn, nil, "CLI")
if err != nil {
return fmt.Errorf("generate token: %w", err)
}
if err := s.SaveProxyAccessToken(ctx, &generated.ProxyAccessToken); err != nil {
return fmt.Errorf("save token: %w", err)
}
fmt.Println("Token created successfully!") //nolint:forbidigo
fmt.Printf("Token: %s\n", generated.PlainToken) //nolint:forbidigo
fmt.Println() //nolint:forbidigo
fmt.Println("IMPORTANT: Save this token now. It will not be shown again.") //nolint:forbidigo
fmt.Printf("Token ID: %s\n", generated.ID) //nolint:forbidigo
return nil
})
}
func tokenListRun(cmd *cobra.Command, _ []string) error {
return withTokenStore(cmd, func(ctx context.Context, s store.Store) error {
tokens, err := s.GetAllProxyAccessTokens(ctx, store.LockingStrengthNone)
if err != nil {
return fmt.Errorf("list tokens: %w", err)
}
if len(tokens) == 0 {
fmt.Println("No proxy access tokens found.") //nolint:forbidigo
return nil
}
w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0)
fmt.Fprintln(w, "ID\tNAME\tCREATED\tEXPIRES\tLAST USED\tREVOKED")
fmt.Fprintln(w, "--\t----\t-------\t-------\t---------\t-------")
for _, t := range tokens {
expires := "never"
if t.ExpiresAt != nil {
expires = t.ExpiresAt.Format("2006-01-02")
}
lastUsed := "never"
if t.LastUsed != nil {
lastUsed = t.LastUsed.Format("2006-01-02 15:04")
}
revoked := "no"
if t.Revoked {
revoked = "yes"
}
fmt.Fprintf(w, "%s\t%s\t%s\t%s\t%s\t%s\n",
t.ID,
t.Name,
t.CreatedAt.Format("2006-01-02"),
expires,
lastUsed,
revoked,
)
}
w.Flush()
return nil
})
}
func tokenRevokeRun(cmd *cobra.Command, args []string) error {
return withTokenStore(cmd, func(ctx context.Context, s store.Store) error {
tokenID := args[0]
if err := s.RevokeProxyAccessToken(ctx, tokenID); err != nil {
return fmt.Errorf("revoke token: %w", err)
}
fmt.Printf("Token %s revoked successfully.\n", tokenID) //nolint:forbidigo
return nil
})
}
// parseDuration parses a duration string with support for days (e.g., "30d", "365d").
// An empty string returns zero duration (no expiration).
func parseDuration(s string) (time.Duration, error) {
if len(s) == 0 {
return 0, nil
}
if s[len(s)-1] == 'd' {
d, err := strconv.Atoi(s[:len(s)-1])
if err != nil {
return 0, fmt.Errorf("invalid day format: %s", s)
}
if d <= 0 {
return 0, fmt.Errorf("duration must be positive: %s", s)
}
return time.Duration(d) * 24 * time.Hour, nil
}
d, err := time.ParseDuration(s)
if err != nil {
return 0, err
}
if d <= 0 {
return 0, fmt.Errorf("duration must be positive: %s", s)
}
return d, nil
}

View File

@@ -1,101 +0,0 @@
package cmd
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestParseDuration(t *testing.T) {
tests := []struct {
name string
input string
expected time.Duration
wantErr bool
}{
{
name: "empty string returns zero",
input: "",
expected: 0,
},
{
name: "days suffix",
input: "30d",
expected: 30 * 24 * time.Hour,
},
{
name: "one day",
input: "1d",
expected: 24 * time.Hour,
},
{
name: "365 days",
input: "365d",
expected: 365 * 24 * time.Hour,
},
{
name: "hours via Go duration",
input: "24h",
expected: 24 * time.Hour,
},
{
name: "minutes via Go duration",
input: "30m",
expected: 30 * time.Minute,
},
{
name: "complex Go duration",
input: "1h30m",
expected: 90 * time.Minute,
},
{
name: "invalid day format",
input: "abcd",
wantErr: true,
},
{
name: "negative days",
input: "-1d",
wantErr: true,
},
{
name: "zero days",
input: "0d",
wantErr: true,
},
{
name: "non-numeric days",
input: "xyzd",
wantErr: true,
},
{
name: "negative Go duration",
input: "-24h",
wantErr: true,
},
{
name: "zero Go duration",
input: "0s",
wantErr: true,
},
{
name: "invalid Go duration",
input: "notaduration",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := parseDuration(tt.input)
if tt.wantErr {
assert.Error(t, err)
return
}
require.NoError(t, err)
assert.Equal(t, tt.expected, result)
})
}
}

View File

@@ -174,7 +174,6 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
var wg sync.WaitGroup
semaphore := make(chan struct{}, 10)
account.InjectProxyPolicies(ctx)
dnsCache := &cache.DNSConfigCache{}
dnsDomain := c.GetDNSDomain(account.Settings)
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
@@ -248,10 +247,7 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort)
c.metrics.CountToSyncResponseDuration(time.Since(start))
c.peersUpdateManager.SendUpdate(ctx, p.ID, &network_map.UpdateMessage{
Update: update,
MessageType: network_map.MessageTypeNetworkMap,
})
c.peersUpdateManager.SendUpdate(ctx, p.ID, &network_map.UpdateMessage{Update: update})
}(peer)
}
@@ -327,7 +323,6 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
return fmt.Errorf("failed to get validated peers: %v", err)
}
account.InjectProxyPolicies(ctx)
dnsCache := &cache.DNSConfigCache{}
dnsDomain := c.GetDNSDomain(account.Settings)
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
@@ -375,10 +370,7 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort)
c.peersUpdateManager.SendUpdate(ctx, peer.ID, &network_map.UpdateMessage{
Update: update,
MessageType: network_map.MessageTypeNetworkMap,
})
c.peersUpdateManager.SendUpdate(ctx, peer.ID, &network_map.UpdateMessage{Update: update})
return nil
}
@@ -443,8 +435,6 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
}
}
account.InjectProxyPolicies(ctx)
approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
if err != nil {
return nil, nil, nil, 0, err
@@ -788,7 +778,6 @@ func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerI
},
},
},
MessageType: network_map.MessageTypeNetworkMap,
})
c.peersUpdateManager.CloseChannel(ctx, peerID)
@@ -851,7 +840,6 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N
if c.experimentalNetworkMap(peer.AccountID) {
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, peersCustomZone, accountZones, nil)
} else {
account.InjectProxyPolicies(ctx)
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, account.GetActiveGroupUsers())

View File

@@ -25,14 +25,11 @@ func TestCreateChannel(t *testing.T) {
func TestSendUpdate(t *testing.T) {
peer := "test-sendupdate"
peersUpdater := NewPeersUpdateManager(nil)
update1 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{
NetworkMap: &proto.NetworkMap{
Serial: 0,
},
update1 := &network_map.UpdateMessage{Update: &proto.SyncResponse{
NetworkMap: &proto.NetworkMap{
Serial: 0,
},
MessageType: network_map.MessageTypeNetworkMap,
}
}}
_ = peersUpdater.CreateChannel(context.Background(), peer)
if _, ok := peersUpdater.peerChannels[peer]; !ok {
t.Error("Error creating the channel")
@@ -48,14 +45,11 @@ func TestSendUpdate(t *testing.T) {
peersUpdater.SendUpdate(context.Background(), peer, update1)
}
update2 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{
NetworkMap: &proto.NetworkMap{
Serial: 10,
},
update2 := &network_map.UpdateMessage{Update: &proto.SyncResponse{
NetworkMap: &proto.NetworkMap{
Serial: 10,
},
MessageType: network_map.MessageTypeNetworkMap,
}
}}
peersUpdater.SendUpdate(context.Background(), peer, update2)
timeout := time.After(5 * time.Second)

View File

@@ -4,19 +4,6 @@ import (
"github.com/netbirdio/netbird/shared/management/proto"
)
// MessageType indicates the type of update message for debouncing strategy
type MessageType int
const (
// MessageTypeNetworkMap represents network map updates (peers, routes, DNS, firewall)
// These updates can be safely debounced - only the latest state matters
MessageTypeNetworkMap MessageType = iota
// MessageTypeControlConfig represents control/config updates (tokens, peer expiration)
// These updates should not be dropped as they contain time-sensitive information
MessageTypeControlConfig
)
type UpdateMessage struct {
Update *proto.SyncResponse
MessageType MessageType
Update *proto.SyncResponse
}

View File

@@ -7,7 +7,6 @@ import (
"fmt"
"time"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
@@ -33,7 +32,6 @@ type Manager interface {
SetIntegratedPeerValidator(integratedPeerValidator integrated_validator.IntegratedValidator)
SetAccountManager(accountManager account.Manager)
GetPeerID(ctx context.Context, peerKey string) (string, error)
CreateProxyPeer(ctx context.Context, accountID string, peerKey string, cluster string) error
}
type managerImpl struct {
@@ -184,36 +182,3 @@ func (m *managerImpl) DeletePeers(ctx context.Context, accountID string, peerIDs
func (m *managerImpl) GetPeerID(ctx context.Context, peerKey string) (string, error) {
return m.store.GetPeerIDByKey(ctx, store.LockingStrengthNone, peerKey)
}
func (m *managerImpl) CreateProxyPeer(ctx context.Context, accountID string, peerKey string, cluster string) error {
existingPeerID, err := m.store.GetPeerIDByKey(ctx, store.LockingStrengthNone, peerKey)
if err == nil && existingPeerID != "" {
// Peer already exists
return nil
}
name := fmt.Sprintf("proxy-%s", xid.New().String())
peer := &peer.Peer{
Ephemeral: true,
ProxyMeta: peer.ProxyMeta{
Cluster: cluster,
Embedded: true,
},
Name: name,
Key: peerKey,
LoginExpirationEnabled: false,
InactivityExpirationEnabled: false,
Meta: peer.PeerSystemMeta{
Hostname: name,
GoOS: "proxy",
OS: "proxy",
},
}
_, _, _, err = m.accountManager.AddPeer(ctx, accountID, "", "", peer, false)
if err != nil {
return fmt.Errorf("failed to create proxy peer: %w", err)
}
return nil
}

View File

@@ -162,17 +162,3 @@ func (mr *MockManagerMockRecorder) SetNetworkMapController(networkMapController
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetworkMapController", reflect.TypeOf((*MockManager)(nil).SetNetworkMapController), networkMapController)
}
// CreateProxyPeer mocks base method.
func (m *MockManager) CreateProxyPeer(ctx context.Context, accountID string, peerKey string, cluster string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateProxyPeer", ctx, accountID, peerKey, cluster)
ret0, _ := ret[0].(error)
return ret0
}
// CreateProxyPeer indicates an expected call of CreateProxyPeer.
func (mr *MockManagerMockRecorder) CreateProxyPeer(ctx, accountID, peerKey, cluster interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateProxyPeer", reflect.TypeOf((*MockManager)(nil).CreateProxyPeer), ctx, accountID, peerKey, cluster)
}

View File

@@ -1,105 +0,0 @@
package accesslogs
import (
"net"
"net/netip"
"time"
"github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/proto"
)
type AccessLogEntry struct {
ID string `gorm:"primaryKey"`
AccountID string `gorm:"index"`
ServiceID string `gorm:"index"`
Timestamp time.Time `gorm:"index"`
GeoLocation peer.Location `gorm:"embedded;embeddedPrefix:location_"`
Method string `gorm:"index"`
Host string `gorm:"index"`
Path string `gorm:"index"`
Duration time.Duration `gorm:"index"`
StatusCode int `gorm:"index"`
Reason string
UserId string `gorm:"index"`
AuthMethodUsed string `gorm:"index"`
}
// FromProto creates an AccessLogEntry from a proto.AccessLog
func (a *AccessLogEntry) FromProto(serviceLog *proto.AccessLog) {
a.ID = serviceLog.GetLogId()
a.ServiceID = serviceLog.GetServiceId()
a.Timestamp = serviceLog.GetTimestamp().AsTime()
a.Method = serviceLog.GetMethod()
a.Host = serviceLog.GetHost()
a.Path = serviceLog.GetPath()
a.Duration = time.Duration(serviceLog.GetDurationMs()) * time.Millisecond
a.StatusCode = int(serviceLog.GetResponseCode())
a.UserId = serviceLog.GetUserId()
a.AuthMethodUsed = serviceLog.GetAuthMechanism()
a.AccountID = serviceLog.GetAccountId()
if sourceIP := serviceLog.GetSourceIp(); sourceIP != "" {
if ip, err := netip.ParseAddr(sourceIP); err == nil {
a.GeoLocation.ConnectionIP = net.IP(ip.AsSlice())
}
}
if !serviceLog.GetAuthSuccess() {
a.Reason = "Authentication failed"
} else if serviceLog.GetResponseCode() >= 400 {
a.Reason = "Request failed"
}
}
// ToAPIResponse converts an AccessLogEntry to the API ProxyAccessLog type
func (a *AccessLogEntry) ToAPIResponse() *api.ProxyAccessLog {
var sourceIP *string
if a.GeoLocation.ConnectionIP != nil {
ip := a.GeoLocation.ConnectionIP.String()
sourceIP = &ip
}
var reason *string
if a.Reason != "" {
reason = &a.Reason
}
var userID *string
if a.UserId != "" {
userID = &a.UserId
}
var authMethod *string
if a.AuthMethodUsed != "" {
authMethod = &a.AuthMethodUsed
}
var countryCode *string
if a.GeoLocation.CountryCode != "" {
countryCode = &a.GeoLocation.CountryCode
}
var cityName *string
if a.GeoLocation.CityName != "" {
cityName = &a.GeoLocation.CityName
}
return &api.ProxyAccessLog{
Id: a.ID,
ServiceId: a.ServiceID,
Timestamp: a.Timestamp,
Method: a.Method,
Host: a.Host,
Path: a.Path,
DurationMs: int(a.Duration.Milliseconds()),
StatusCode: a.StatusCode,
SourceIp: sourceIP,
Reason: reason,
UserId: userID,
AuthMethodUsed: authMethod,
CountryCode: countryCode,
CityName: cityName,
}
}

View File

@@ -1,124 +0,0 @@
package accesslogs
import (
"net/http"
"strconv"
"time"
)
const (
// DefaultPageSize is the default number of records per page
DefaultPageSize = 50
// MaxPageSize is the maximum number of records allowed per page
MaxPageSize = 100
)
// AccessLogFilter holds pagination and filtering parameters for access logs
type AccessLogFilter struct {
// Page is the current page number (1-indexed)
Page int
// PageSize is the number of records per page
PageSize int
// Filtering parameters
Search *string // General search across log ID, host, path, source IP, and user fields
SourceIP *string // Filter by source IP address
Host *string // Filter by host header
Path *string // Filter by request path (supports LIKE pattern)
UserID *string // Filter by authenticated user ID
UserEmail *string // Filter by user email (requires user lookup)
UserName *string // Filter by user name (requires user lookup)
Method *string // Filter by HTTP method
Status *string // Filter by status: "success" (2xx/3xx) or "failed" (1xx/4xx/5xx)
StatusCode *int // Filter by HTTP status code
StartDate *time.Time // Filter by timestamp >= start_date
EndDate *time.Time // Filter by timestamp <= end_date
}
// ParseFromRequest parses pagination and filter parameters from HTTP request query parameters
func (f *AccessLogFilter) ParseFromRequest(r *http.Request) {
queryParams := r.URL.Query()
f.Page = 1
if pageStr := queryParams.Get("page"); pageStr != "" {
if page, err := strconv.Atoi(pageStr); err == nil && page > 0 {
f.Page = page
}
}
f.PageSize = DefaultPageSize
if pageSizeStr := queryParams.Get("page_size"); pageSizeStr != "" {
if pageSize, err := strconv.Atoi(pageSizeStr); err == nil && pageSize > 0 {
f.PageSize = pageSize
if f.PageSize > MaxPageSize {
f.PageSize = MaxPageSize
}
}
}
if search := queryParams.Get("search"); search != "" {
f.Search = &search
}
if sourceIP := queryParams.Get("source_ip"); sourceIP != "" {
f.SourceIP = &sourceIP
}
if host := queryParams.Get("host"); host != "" {
f.Host = &host
}
if path := queryParams.Get("path"); path != "" {
f.Path = &path
}
if userID := queryParams.Get("user_id"); userID != "" {
f.UserID = &userID
}
if userEmail := queryParams.Get("user_email"); userEmail != "" {
f.UserEmail = &userEmail
}
if userName := queryParams.Get("user_name"); userName != "" {
f.UserName = &userName
}
if method := queryParams.Get("method"); method != "" {
f.Method = &method
}
if status := queryParams.Get("status"); status != "" {
f.Status = &status
}
if statusCodeStr := queryParams.Get("status_code"); statusCodeStr != "" {
if statusCode, err := strconv.Atoi(statusCodeStr); err == nil && statusCode > 0 {
f.StatusCode = &statusCode
}
}
if startDate := queryParams.Get("start_date"); startDate != "" {
parsedStartDate, err := time.Parse(time.RFC3339, startDate)
if err == nil {
f.StartDate = &parsedStartDate
}
}
if endDate := queryParams.Get("end_date"); endDate != "" {
parsedEndDate, err := time.Parse(time.RFC3339, endDate)
if err == nil {
f.EndDate = &parsedEndDate
}
}
}
// GetOffset calculates the database offset for pagination
func (f *AccessLogFilter) GetOffset() int {
return (f.Page - 1) * f.PageSize
}
// GetLimit returns the page size for database queries
func (f *AccessLogFilter) GetLimit() int {
return f.PageSize
}

View File

@@ -1,161 +0,0 @@
package accesslogs
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
)
func TestAccessLogFilter_ParseFromRequest(t *testing.T) {
tests := []struct {
name string
queryParams map[string]string
expectedPage int
expectedPageSize int
}{
{
name: "default values when no params provided",
queryParams: map[string]string{},
expectedPage: 1,
expectedPageSize: DefaultPageSize,
},
{
name: "valid page and page_size",
queryParams: map[string]string{
"page": "2",
"page_size": "25",
},
expectedPage: 2,
expectedPageSize: 25,
},
{
name: "page_size exceeds max, should cap at MaxPageSize",
queryParams: map[string]string{
"page": "1",
"page_size": "200",
},
expectedPage: 1,
expectedPageSize: MaxPageSize,
},
{
name: "invalid page number, should use default",
queryParams: map[string]string{
"page": "invalid",
"page_size": "10",
},
expectedPage: 1,
expectedPageSize: 10,
},
{
name: "invalid page_size, should use default",
queryParams: map[string]string{
"page": "2",
"page_size": "invalid",
},
expectedPage: 2,
expectedPageSize: DefaultPageSize,
},
{
name: "zero page number, should use default",
queryParams: map[string]string{
"page": "0",
"page_size": "10",
},
expectedPage: 1,
expectedPageSize: 10,
},
{
name: "negative page number, should use default",
queryParams: map[string]string{
"page": "-1",
"page_size": "10",
},
expectedPage: 1,
expectedPageSize: 10,
},
{
name: "zero page_size, should use default",
queryParams: map[string]string{
"page": "1",
"page_size": "0",
},
expectedPage: 1,
expectedPageSize: DefaultPageSize,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/test", nil)
q := req.URL.Query()
for key, value := range tt.queryParams {
q.Set(key, value)
}
req.URL.RawQuery = q.Encode()
filter := &AccessLogFilter{}
filter.ParseFromRequest(req)
assert.Equal(t, tt.expectedPage, filter.Page, "Page mismatch")
assert.Equal(t, tt.expectedPageSize, filter.PageSize, "PageSize mismatch")
})
}
}
func TestAccessLogFilter_GetOffset(t *testing.T) {
tests := []struct {
name string
page int
pageSize int
expectedOffset int
}{
{
name: "first page",
page: 1,
pageSize: 50,
expectedOffset: 0,
},
{
name: "second page",
page: 2,
pageSize: 50,
expectedOffset: 50,
},
{
name: "third page with page size 25",
page: 3,
pageSize: 25,
expectedOffset: 50,
},
{
name: "page 10 with page size 10",
page: 10,
pageSize: 10,
expectedOffset: 90,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
filter := &AccessLogFilter{
Page: tt.page,
PageSize: tt.pageSize,
}
offset := filter.GetOffset()
assert.Equal(t, tt.expectedOffset, offset)
})
}
}
func TestAccessLogFilter_GetLimit(t *testing.T) {
filter := &AccessLogFilter{
Page: 2,
PageSize: 25,
}
limit := filter.GetLimit()
assert.Equal(t, 25, limit, "GetLimit should return PageSize")
}

View File

@@ -1,10 +0,0 @@
package accesslogs
import (
"context"
)
type Manager interface {
SaveAccessLog(ctx context.Context, proxyLog *AccessLogEntry) error
GetAllAccessLogs(ctx context.Context, accountID, userID string, filter *AccessLogFilter) ([]*AccessLogEntry, int64, error)
}

View File

@@ -1,64 +0,0 @@
package manager
import (
"net/http"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
)
type handler struct {
manager accesslogs.Manager
}
func RegisterEndpoints(router *mux.Router, manager accesslogs.Manager) {
h := &handler{
manager: manager,
}
router.HandleFunc("/events/proxy", h.getAccessLogs).Methods("GET", "OPTIONS")
}
func (h *handler) getAccessLogs(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
var filter accesslogs.AccessLogFilter
filter.ParseFromRequest(r)
logs, totalCount, err := h.manager.GetAllAccessLogs(r.Context(), userAuth.AccountId, userAuth.UserId, &filter)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
apiLogs := make([]api.ProxyAccessLog, 0, len(logs))
for _, log := range logs {
apiLogs = append(apiLogs, *log.ToAPIResponse())
}
response := &api.ProxyAccessLogsResponse{
Data: apiLogs,
Page: filter.Page,
PageSize: filter.PageSize,
TotalRecords: int(totalCount),
TotalPages: getTotalPageCount(int(totalCount), filter.PageSize),
}
util.WriteJSONObject(r.Context(), w, response)
}
// getTotalPageCount calculates the total number of pages
func getTotalPageCount(totalCount, pageSize int) int {
if pageSize <= 0 {
return 0
}
return (totalCount + pageSize - 1) / pageSize
}

View File

@@ -1,108 +0,0 @@
package manager
import (
"context"
"strings"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/shared/management/status"
)
type managerImpl struct {
store store.Store
permissionsManager permissions.Manager
geo geolocation.Geolocation
}
func NewManager(store store.Store, permissionsManager permissions.Manager, geo geolocation.Geolocation) accesslogs.Manager {
return &managerImpl{
store: store,
permissionsManager: permissionsManager,
geo: geo,
}
}
// SaveAccessLog saves an access log entry to the database after enriching it
func (m *managerImpl) SaveAccessLog(ctx context.Context, logEntry *accesslogs.AccessLogEntry) error {
if m.geo != nil && logEntry.GeoLocation.ConnectionIP != nil {
location, err := m.geo.Lookup(logEntry.GeoLocation.ConnectionIP)
if err != nil {
log.WithContext(ctx).Warnf("failed to get location for access log source IP [%s]: %v", logEntry.GeoLocation.ConnectionIP.String(), err)
} else {
logEntry.GeoLocation.CountryCode = location.Country.ISOCode
logEntry.GeoLocation.CityName = location.City.Names.En
logEntry.GeoLocation.GeoNameID = location.City.GeonameID
}
}
if err := m.store.CreateAccessLog(ctx, logEntry); err != nil {
log.WithContext(ctx).WithFields(log.Fields{
"service_id": logEntry.ServiceID,
"method": logEntry.Method,
"host": logEntry.Host,
"path": logEntry.Path,
"status": logEntry.StatusCode,
}).Errorf("failed to save access log: %v", err)
return err
}
return nil
}
// GetAllAccessLogs retrieves access logs for an account with pagination and filtering
func (m *managerImpl) GetAllAccessLogs(ctx context.Context, accountID, userID string, filter *accesslogs.AccessLogFilter) ([]*accesslogs.AccessLogEntry, int64, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
if err != nil {
return nil, 0, status.NewPermissionValidationError(err)
}
if !ok {
return nil, 0, status.NewPermissionDeniedError()
}
if err := m.resolveUserFilters(ctx, accountID, filter); err != nil {
log.WithContext(ctx).Warnf("failed to resolve user filters: %v", err)
}
logs, totalCount, err := m.store.GetAccountAccessLogs(ctx, store.LockingStrengthNone, accountID, *filter)
if err != nil {
return nil, 0, err
}
return logs, totalCount, nil
}
// resolveUserFilters converts user email/name filters to user ID filter
func (m *managerImpl) resolveUserFilters(ctx context.Context, accountID string, filter *accesslogs.AccessLogFilter) error {
if filter.UserEmail == nil && filter.UserName == nil {
return nil
}
users, err := m.store.GetAccountUsers(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return err
}
var matchingUserIDs []string
for _, user := range users {
if filter.UserEmail != nil && strings.Contains(strings.ToLower(user.Email), strings.ToLower(*filter.UserEmail)) {
matchingUserIDs = append(matchingUserIDs, user.Id)
continue
}
if filter.UserName != nil && strings.Contains(strings.ToLower(user.Name), strings.ToLower(*filter.UserName)) {
matchingUserIDs = append(matchingUserIDs, user.Id)
}
}
if len(matchingUserIDs) > 0 {
filter.UserID = &matchingUserIDs[0]
}
return nil
}

View File

@@ -1,17 +0,0 @@
package domain
type Type string
const (
TypeFree Type = "free"
TypeCustom Type = "custom"
)
type Domain struct {
ID string `gorm:"unique;primaryKey;autoIncrement"`
Domain string `gorm:"unique"` // Domain records must be unique, this avoids domain reuse across accounts.
AccountID string `gorm:"index"`
TargetCluster string // The proxy cluster this domain should be validated against
Type Type `gorm:"-"`
Validated bool
}

View File

@@ -1,12 +0,0 @@
package domain
import (
"context"
)
type Manager interface {
GetDomains(ctx context.Context, accountID, userID string) ([]*Domain, error)
CreateDomain(ctx context.Context, accountID, userID, domainName, targetCluster string) (*Domain, error)
DeleteDomain(ctx context.Context, accountID, userID, domainID string) error
ValidateDomain(ctx context.Context, accountID, userID, domainID string)
}

View File

@@ -1,136 +0,0 @@
package manager
import (
"encoding/json"
"net/http"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
)
type handler struct {
manager Manager
}
func RegisterEndpoints(router *mux.Router, manager Manager) {
h := &handler{
manager: manager,
}
router.HandleFunc("/domains", h.getAllDomains).Methods("GET", "OPTIONS")
router.HandleFunc("/domains", h.createCustomDomain).Methods("POST", "OPTIONS")
router.HandleFunc("/domains/{domainId}", h.deleteCustomDomain).Methods("DELETE", "OPTIONS")
router.HandleFunc("/domains/{domainId}/validate", h.triggerCustomDomainValidation).Methods("GET", "OPTIONS")
}
func domainTypeToApi(t domain.Type) api.ReverseProxyDomainType {
switch t {
case domain.TypeCustom:
return api.ReverseProxyDomainTypeCustom
case domain.TypeFree:
return api.ReverseProxyDomainTypeFree
}
// By default return as a "free" domain as that is more restrictive.
// TODO: is this correct?
return api.ReverseProxyDomainTypeFree
}
func domainToApi(d *domain.Domain) api.ReverseProxyDomain {
resp := api.ReverseProxyDomain{
Domain: d.Domain,
Id: d.ID,
Type: domainTypeToApi(d.Type),
Validated: d.Validated,
}
if d.TargetCluster != "" {
resp.TargetCluster = &d.TargetCluster
}
return resp
}
func (h *handler) getAllDomains(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
domains, err := h.manager.GetDomains(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
ret := make([]api.ReverseProxyDomain, 0)
for _, d := range domains {
ret = append(ret, domainToApi(d))
}
util.WriteJSONObject(r.Context(), w, ret)
}
func (h *handler) createCustomDomain(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
var req api.PostApiReverseProxiesDomainsJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
domain, err := h.manager.CreateDomain(r.Context(), userAuth.AccountId, userAuth.UserId, req.Domain, req.TargetCluster)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, domainToApi(domain))
}
func (h *handler) deleteCustomDomain(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
domainID := mux.Vars(r)["domainId"]
if domainID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "domain ID is required"), w)
return
}
if err := h.manager.DeleteDomain(r.Context(), userAuth.AccountId, userAuth.UserId, domainID); err != nil {
util.WriteError(r.Context(), err, w)
return
}
w.WriteHeader(http.StatusNoContent)
}
func (h *handler) triggerCustomDomainValidation(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
domainID := mux.Vars(r)["domainId"]
if domainID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "domain ID is required"), w)
return
}
go h.manager.ValidateDomain(r.Context(), userAuth.AccountId, userAuth.UserId, domainID)
w.WriteHeader(http.StatusAccepted)
}

View File

@@ -1,279 +0,0 @@
package manager
import (
"context"
"fmt"
"net"
"strings"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/status"
)
type store interface {
GetAccount(ctx context.Context, accountID string) (*types.Account, error)
GetCustomDomain(ctx context.Context, accountID string, domainID string) (*domain.Domain, error)
ListFreeDomains(ctx context.Context, accountID string) ([]string, error)
ListCustomDomains(ctx context.Context, accountID string) ([]*domain.Domain, error)
CreateCustomDomain(ctx context.Context, accountID string, domainName string, targetCluster string, validated bool) (*domain.Domain, error)
UpdateCustomDomain(ctx context.Context, accountID string, d *domain.Domain) (*domain.Domain, error)
DeleteCustomDomain(ctx context.Context, accountID string, domainID string) error
}
type proxyURLProvider interface {
GetConnectedProxyURLs() []string
}
type Manager struct {
store store
validator domain.Validator
proxyURLProvider proxyURLProvider
permissionsManager permissions.Manager
}
func NewManager(store store, proxyURLProvider proxyURLProvider, permissionsManager permissions.Manager) Manager {
return Manager{
store: store,
proxyURLProvider: proxyURLProvider,
validator: domain.Validator{
Resolver: net.DefaultResolver,
},
permissionsManager: permissionsManager,
}
}
func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*domain.Domain, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
domains, err := m.store.ListCustomDomains(ctx, accountID)
if err != nil {
return nil, fmt.Errorf("list custom domains: %w", err)
}
var ret []*domain.Domain
// Add connected proxy clusters as free domains.
// The cluster address itself is the free domain base (e.g., "eu.proxy.netbird.io").
allowList := m.proxyURLAllowList()
log.WithFields(log.Fields{
"accountID": accountID,
"proxyAllowList": allowList,
}).Debug("getting domains with proxy allow list")
for _, cluster := range allowList {
ret = append(ret, &domain.Domain{
Domain: cluster,
AccountID: accountID,
Type: domain.TypeFree,
Validated: true,
})
}
// Add custom domains.
for _, d := range domains {
ret = append(ret, &domain.Domain{
ID: d.ID,
Domain: d.Domain,
AccountID: accountID,
TargetCluster: d.TargetCluster,
Type: domain.TypeCustom,
Validated: d.Validated,
})
}
return ret, nil
}
func (m Manager) CreateDomain(ctx context.Context, accountID, userID, domainName, targetCluster string) (*domain.Domain, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
// Verify the target cluster is in the available clusters
allowList := m.proxyURLAllowList()
clusterValid := false
for _, cluster := range allowList {
if cluster == targetCluster {
clusterValid = true
break
}
}
if !clusterValid {
return nil, fmt.Errorf("target cluster %s is not available", targetCluster)
}
// Attempt an initial validation against the specified cluster only
var validated bool
if m.validator.IsValid(ctx, domainName, []string{targetCluster}) {
validated = true
}
d, err := m.store.CreateCustomDomain(ctx, accountID, domainName, targetCluster, validated)
if err != nil {
return d, fmt.Errorf("create domain in store: %w", err)
}
return d, nil
}
func (m Manager) DeleteDomain(ctx context.Context, accountID, userID, domainID string) error {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !ok {
return status.NewPermissionDeniedError()
}
if err := m.store.DeleteCustomDomain(ctx, accountID, domainID); err != nil {
// TODO: check for "no records" type error. Because that is a success condition.
return fmt.Errorf("delete domain from store: %w", err)
}
return nil
}
func (m Manager) ValidateDomain(ctx context.Context, accountID, userID, domainID string) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
if err != nil {
log.WithFields(log.Fields{
"accountID": accountID,
"domainID": domainID,
}).WithError(err).Error("validate domain")
return
}
if !ok {
log.WithFields(log.Fields{
"accountID": accountID,
"domainID": domainID,
}).WithError(err).Error("validate domain")
}
log.WithFields(log.Fields{
"accountID": accountID,
"domainID": domainID,
}).Info("starting domain validation")
d, err := m.store.GetCustomDomain(context.Background(), accountID, domainID)
if err != nil {
log.WithFields(log.Fields{
"accountID": accountID,
"domainID": domainID,
}).WithError(err).Error("get custom domain from store")
return
}
// Validate only against the domain's target cluster
targetCluster := d.TargetCluster
if targetCluster == "" {
log.WithFields(log.Fields{
"accountID": accountID,
"domainID": domainID,
"domain": d.Domain,
}).Warn("domain has no target cluster set, skipping validation")
return
}
log.WithFields(log.Fields{
"accountID": accountID,
"domainID": domainID,
"domain": d.Domain,
"targetCluster": targetCluster,
}).Info("validating domain against target cluster")
if m.validator.IsValid(context.Background(), d.Domain, []string{targetCluster}) {
log.WithFields(log.Fields{
"accountID": accountID,
"domainID": domainID,
"domain": d.Domain,
}).Info("domain validated successfully")
d.Validated = true
if _, err := m.store.UpdateCustomDomain(context.Background(), accountID, d); err != nil {
log.WithFields(log.Fields{
"accountID": accountID,
"domainID": domainID,
"domain": d.Domain,
}).WithError(err).Error("update custom domain in store")
return
}
} else {
log.WithFields(log.Fields{
"accountID": accountID,
"domainID": domainID,
"domain": d.Domain,
"targetCluster": targetCluster,
}).Warn("domain validation failed - CNAME does not match target cluster")
}
}
// proxyURLAllowList retrieves a list of currently connected proxies and
// their URLs
func (m Manager) proxyURLAllowList() []string {
var reverseProxyAddresses []string
if m.proxyURLProvider != nil {
reverseProxyAddresses = m.proxyURLProvider.GetConnectedProxyURLs()
}
return reverseProxyAddresses
}
// DeriveClusterFromDomain determines the proxy cluster for a given domain.
// For free domains (those ending with a known cluster suffix), the cluster is extracted from the domain.
// For custom domains, the cluster is determined by checking the registered custom domain's target cluster.
func (m Manager) DeriveClusterFromDomain(ctx context.Context, accountID, domain string) (string, error) {
allowList := m.proxyURLAllowList()
if len(allowList) == 0 {
return "", fmt.Errorf("no proxy clusters available")
}
if cluster, ok := ExtractClusterFromFreeDomain(domain, allowList); ok {
return cluster, nil
}
customDomains, err := m.store.ListCustomDomains(ctx, accountID)
if err != nil {
return "", fmt.Errorf("list custom domains: %w", err)
}
targetCluster, valid := extractClusterFromCustomDomains(domain, customDomains)
if valid {
return targetCluster, nil
}
return "", fmt.Errorf("domain %s does not match any available proxy cluster", domain)
}
func extractClusterFromCustomDomains(domain string, customDomains []*domain.Domain) (string, bool) {
for _, customDomain := range customDomains {
if strings.HasSuffix(domain, "."+customDomain.Domain) {
return customDomain.TargetCluster, true
}
}
return "", false
}
// ExtractClusterFromFreeDomain extracts the cluster address from a free domain.
// Free domains have the format: <name>.<nonce>.<cluster> (e.g., myapp.abc123.eu.proxy.netbird.io)
// It matches the domain suffix against available clusters and returns the matching cluster.
func ExtractClusterFromFreeDomain(domain string, availableClusters []string) (string, bool) {
for _, cluster := range availableClusters {
if strings.HasSuffix(domain, "."+cluster) {
return cluster, true
}
}
return "", false
}

View File

@@ -1,88 +0,0 @@
package domain
import (
"context"
"net"
"strings"
log "github.com/sirupsen/logrus"
)
type resolver interface {
LookupCNAME(context.Context, string) (string, error)
}
type Validator struct {
Resolver resolver
}
// NewValidator initializes a validator with a specific DNS Resolver.
// If a Validator is used without specifying a Resolver, then it will
// use the net.DefaultResolver.
func NewValidator(resolver resolver) *Validator {
return &Validator{
Resolver: resolver,
}
}
// IsValid looks up the CNAME record for the passed domain with a prefix
// and compares it against the acceptable domains.
// If the returned CNAME matches any accepted domain, it will return true,
// otherwise, including in the event of a DNS error, it will return false.
// The comparison is very simple, so wildcards will not match if included
// in the acceptable domain list.
func (v *Validator) IsValid(ctx context.Context, domain string, accept []string) bool {
_, valid := v.ValidateWithCluster(ctx, domain, accept)
return valid
}
// ValidateWithCluster validates a custom domain and returns the matched cluster address.
// Returns the cluster address and true if valid, or empty string and false if invalid.
func (v *Validator) ValidateWithCluster(ctx context.Context, domain string, accept []string) (string, bool) {
if v.Resolver == nil {
v.Resolver = net.DefaultResolver
}
lookupDomain := "validation." + domain
log.WithFields(log.Fields{
"domain": domain,
"lookupDomain": lookupDomain,
"acceptList": accept,
}).Debug("looking up CNAME for domain validation")
cname, err := v.Resolver.LookupCNAME(ctx, lookupDomain)
if err != nil {
log.WithFields(log.Fields{
"domain": domain,
"lookupDomain": lookupDomain,
}).WithError(err).Warn("CNAME lookup failed for domain validation")
return "", false
}
nakedCNAME := strings.TrimSuffix(cname, ".")
log.WithFields(log.Fields{
"domain": domain,
"cname": cname,
"nakedCNAME": nakedCNAME,
"acceptList": accept,
}).Debug("CNAME lookup result for domain validation")
for _, acceptDomain := range accept {
normalizedAccept := strings.TrimSuffix(acceptDomain, ".")
if nakedCNAME == normalizedAccept {
log.WithFields(log.Fields{
"domain": domain,
"cname": nakedCNAME,
"cluster": acceptDomain,
}).Info("domain CNAME matched cluster")
return acceptDomain, true
}
}
log.WithFields(log.Fields{
"domain": domain,
"cname": nakedCNAME,
"acceptList": accept,
}).Warn("domain CNAME does not match any accepted cluster")
return "", false
}

View File

@@ -1,56 +0,0 @@
package domain_test
import (
"context"
"testing"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
)
type resolver struct {
CNAME string
}
func (r resolver) LookupCNAME(_ context.Context, _ string) (string, error) {
return r.CNAME, nil
}
func TestIsValid(t *testing.T) {
tests := map[string]struct {
resolver interface {
LookupCNAME(context.Context, string) (string, error)
}
domain string
accept []string
expect bool
}{
"match": {
resolver: resolver{"bar.example.com."}, // Including trailing "." in response.
domain: "foo.example.com",
accept: []string{"bar.example.com"},
expect: true,
},
"no match": {
resolver: resolver{"invalid"},
domain: "foo.example.com",
accept: []string{"bar.example.com"},
expect: false,
},
"accept trailing dot": {
resolver: resolver{"bar.example.com."},
domain: "foo.example.com",
accept: []string{"bar.example.com."}, // Including trailing "." in accept.
expect: true,
},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
validator := domain.NewValidator(test.resolver)
actual := validator.IsValid(t.Context(), test.domain, test.accept)
if test.expect != actual {
t.Errorf("Incorrect return value:\nexpect: %v\nactual: %v", test.expect, actual)
}
})
}
}

View File

@@ -1,23 +0,0 @@
package reverseproxy
//go:generate go run github.com/golang/mock/mockgen -package reverseproxy -destination=interface_mock.go -source=./interface.go -build_flags=-mod=mod
import (
"context"
)
type Manager interface {
GetAllServices(ctx context.Context, accountID, userID string) ([]*Service, error)
GetService(ctx context.Context, accountID, userID, serviceID string) (*Service, error)
CreateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error)
UpdateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error)
DeleteService(ctx context.Context, accountID, userID, serviceID string) error
SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error
SetStatus(ctx context.Context, accountID, serviceID string, status ProxyStatus) error
ReloadAllServicesForAccount(ctx context.Context, accountID string) error
ReloadService(ctx context.Context, accountID, serviceID string) error
GetGlobalServices(ctx context.Context) ([]*Service, error)
GetServiceByID(ctx context.Context, accountID, serviceID string) (*Service, error)
GetAccountServices(ctx context.Context, accountID string) ([]*Service, error)
GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error)
}

View File

@@ -1,225 +0,0 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: ./interface.go
// Package reverseproxy is a generated GoMock package.
package reverseproxy
import (
context "context"
reflect "reflect"
gomock "github.com/golang/mock/gomock"
)
// MockManager is a mock of Manager interface.
type MockManager struct {
ctrl *gomock.Controller
recorder *MockManagerMockRecorder
}
// MockManagerMockRecorder is the mock recorder for MockManager.
type MockManagerMockRecorder struct {
mock *MockManager
}
// NewMockManager creates a new mock instance.
func NewMockManager(ctrl *gomock.Controller) *MockManager {
mock := &MockManager{ctrl: ctrl}
mock.recorder = &MockManagerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockManager) EXPECT() *MockManagerMockRecorder {
return m.recorder
}
// CreateService mocks base method.
func (m *MockManager) CreateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateService", ctx, accountID, userID, service)
ret0, _ := ret[0].(*Service)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CreateService indicates an expected call of CreateService.
func (mr *MockManagerMockRecorder) CreateService(ctx, accountID, userID, service interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateService", reflect.TypeOf((*MockManager)(nil).CreateService), ctx, accountID, userID, service)
}
// DeleteService mocks base method.
func (m *MockManager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteService", ctx, accountID, userID, serviceID)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteService indicates an expected call of DeleteService.
func (mr *MockManagerMockRecorder) DeleteService(ctx, accountID, userID, serviceID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteService", reflect.TypeOf((*MockManager)(nil).DeleteService), ctx, accountID, userID, serviceID)
}
// GetAccountServices mocks base method.
func (m *MockManager) GetAccountServices(ctx context.Context, accountID string) ([]*Service, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetAccountServices", ctx, accountID)
ret0, _ := ret[0].([]*Service)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetAccountServices indicates an expected call of GetAccountServices.
func (mr *MockManagerMockRecorder) GetAccountServices(ctx, accountID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountServices", reflect.TypeOf((*MockManager)(nil).GetAccountServices), ctx, accountID)
}
// GetAllServices mocks base method.
func (m *MockManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*Service, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetAllServices", ctx, accountID, userID)
ret0, _ := ret[0].([]*Service)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetAllServices indicates an expected call of GetAllServices.
func (mr *MockManagerMockRecorder) GetAllServices(ctx, accountID, userID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllServices", reflect.TypeOf((*MockManager)(nil).GetAllServices), ctx, accountID, userID)
}
// GetGlobalServices mocks base method.
func (m *MockManager) GetGlobalServices(ctx context.Context) ([]*Service, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetGlobalServices", ctx)
ret0, _ := ret[0].([]*Service)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetGlobalServices indicates an expected call of GetGlobalServices.
func (mr *MockManagerMockRecorder) GetGlobalServices(ctx interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGlobalServices", reflect.TypeOf((*MockManager)(nil).GetGlobalServices), ctx)
}
// GetService mocks base method.
func (m *MockManager) GetService(ctx context.Context, accountID, userID, serviceID string) (*Service, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetService", ctx, accountID, userID, serviceID)
ret0, _ := ret[0].(*Service)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetService indicates an expected call of GetService.
func (mr *MockManagerMockRecorder) GetService(ctx, accountID, userID, serviceID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetService", reflect.TypeOf((*MockManager)(nil).GetService), ctx, accountID, userID, serviceID)
}
// GetServiceByID mocks base method.
func (m *MockManager) GetServiceByID(ctx context.Context, accountID, serviceID string) (*Service, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetServiceByID", ctx, accountID, serviceID)
ret0, _ := ret[0].(*Service)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetServiceByID indicates an expected call of GetServiceByID.
func (mr *MockManagerMockRecorder) GetServiceByID(ctx, accountID, serviceID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceByID", reflect.TypeOf((*MockManager)(nil).GetServiceByID), ctx, accountID, serviceID)
}
// GetServiceIDByTargetID mocks base method.
func (m *MockManager) GetServiceIDByTargetID(ctx context.Context, accountID, resourceID string) (string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetServiceIDByTargetID", ctx, accountID, resourceID)
ret0, _ := ret[0].(string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetServiceIDByTargetID indicates an expected call of GetServiceIDByTargetID.
func (mr *MockManagerMockRecorder) GetServiceIDByTargetID(ctx, accountID, resourceID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceIDByTargetID", reflect.TypeOf((*MockManager)(nil).GetServiceIDByTargetID), ctx, accountID, resourceID)
}
// ReloadAllServicesForAccount mocks base method.
func (m *MockManager) ReloadAllServicesForAccount(ctx context.Context, accountID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ReloadAllServicesForAccount", ctx, accountID)
ret0, _ := ret[0].(error)
return ret0
}
// ReloadAllServicesForAccount indicates an expected call of ReloadAllServicesForAccount.
func (mr *MockManagerMockRecorder) ReloadAllServicesForAccount(ctx, accountID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReloadAllServicesForAccount", reflect.TypeOf((*MockManager)(nil).ReloadAllServicesForAccount), ctx, accountID)
}
// ReloadService mocks base method.
func (m *MockManager) ReloadService(ctx context.Context, accountID, serviceID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ReloadService", ctx, accountID, serviceID)
ret0, _ := ret[0].(error)
return ret0
}
// ReloadService indicates an expected call of ReloadService.
func (mr *MockManagerMockRecorder) ReloadService(ctx, accountID, serviceID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReloadService", reflect.TypeOf((*MockManager)(nil).ReloadService), ctx, accountID, serviceID)
}
// SetCertificateIssuedAt mocks base method.
func (m *MockManager) SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SetCertificateIssuedAt", ctx, accountID, serviceID)
ret0, _ := ret[0].(error)
return ret0
}
// SetCertificateIssuedAt indicates an expected call of SetCertificateIssuedAt.
func (mr *MockManagerMockRecorder) SetCertificateIssuedAt(ctx, accountID, serviceID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetCertificateIssuedAt", reflect.TypeOf((*MockManager)(nil).SetCertificateIssuedAt), ctx, accountID, serviceID)
}
// SetStatus mocks base method.
func (m *MockManager) SetStatus(ctx context.Context, accountID, serviceID string, status ProxyStatus) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SetStatus", ctx, accountID, serviceID, status)
ret0, _ := ret[0].(error)
return ret0
}
// SetStatus indicates an expected call of SetStatus.
func (mr *MockManagerMockRecorder) SetStatus(ctx, accountID, serviceID, status interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetStatus", reflect.TypeOf((*MockManager)(nil).SetStatus), ctx, accountID, serviceID, status)
}
// UpdateService mocks base method.
func (m *MockManager) UpdateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateService", ctx, accountID, userID, service)
ret0, _ := ret[0].(*Service)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpdateService indicates an expected call of UpdateService.
func (mr *MockManagerMockRecorder) UpdateService(ctx, accountID, userID, service interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateService", reflect.TypeOf((*MockManager)(nil).UpdateService), ctx, accountID, userID, service)
}

View File

@@ -1,170 +0,0 @@
package manager
import (
"encoding/json"
"net/http"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager"
domainmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
)
type handler struct {
manager reverseproxy.Manager
}
// RegisterEndpoints registers all service HTTP endpoints.
func RegisterEndpoints(manager reverseproxy.Manager, domainManager domainmanager.Manager, accessLogsManager accesslogs.Manager, router *mux.Router) {
h := &handler{
manager: manager,
}
domainRouter := router.PathPrefix("/reverse-proxies").Subrouter()
domainmanager.RegisterEndpoints(domainRouter, domainManager)
accesslogsmanager.RegisterEndpoints(router, accessLogsManager)
router.HandleFunc("/reverse-proxies/services", h.getAllServices).Methods("GET", "OPTIONS")
router.HandleFunc("/reverse-proxies/services", h.createService).Methods("POST", "OPTIONS")
router.HandleFunc("/reverse-proxies/services/{serviceId}", h.getService).Methods("GET", "OPTIONS")
router.HandleFunc("/reverse-proxies/services/{serviceId}", h.updateService).Methods("PUT", "OPTIONS")
router.HandleFunc("/reverse-proxies/services/{serviceId}", h.deleteService).Methods("DELETE", "OPTIONS")
}
func (h *handler) getAllServices(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
allServices, err := h.manager.GetAllServices(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
apiServices := make([]*api.Service, 0, len(allServices))
for _, service := range allServices {
apiServices = append(apiServices, service.ToAPIResponse())
}
util.WriteJSONObject(r.Context(), w, apiServices)
}
func (h *handler) createService(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
var req api.ServiceRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
service := new(reverseproxy.Service)
service.FromAPIRequest(&req, userAuth.AccountId)
if err = service.Validate(); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
return
}
createdService, err := h.manager.CreateService(r.Context(), userAuth.AccountId, userAuth.UserId, service)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, createdService.ToAPIResponse())
}
func (h *handler) getService(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
serviceID := mux.Vars(r)["serviceId"]
if serviceID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "service ID is required"), w)
return
}
service, err := h.manager.GetService(r.Context(), userAuth.AccountId, userAuth.UserId, serviceID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, service.ToAPIResponse())
}
func (h *handler) updateService(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
serviceID := mux.Vars(r)["serviceId"]
if serviceID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "service ID is required"), w)
return
}
var req api.ServiceRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
service := new(reverseproxy.Service)
service.ID = serviceID
service.FromAPIRequest(&req, userAuth.AccountId)
if err = service.Validate(); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
return
}
updatedService, err := h.manager.UpdateService(r.Context(), userAuth.AccountId, userAuth.UserId, service)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, updatedService.ToAPIResponse())
}
func (h *handler) deleteService(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
serviceID := mux.Vars(r)["serviceId"]
if serviceID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "service ID is required"), w)
return
}
if err := h.manager.DeleteService(r.Context(), userAuth.AccountId, userAuth.UserId, serviceID); err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
}

View File

@@ -1,500 +0,0 @@
package manager
import (
"context"
"fmt"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/shared/management/status"
)
const unknownHostPlaceholder = "unknown"
// ClusterDeriver derives the proxy cluster from a domain.
type ClusterDeriver interface {
DeriveClusterFromDomain(ctx context.Context, accountID, domain string) (string, error)
}
type managerImpl struct {
store store.Store
accountManager account.Manager
permissionsManager permissions.Manager
proxyGRPCServer *nbgrpc.ProxyServiceServer
clusterDeriver ClusterDeriver
}
// NewManager creates a new service manager.
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, clusterDeriver ClusterDeriver) reverseproxy.Manager {
return &managerImpl{
store: store,
accountManager: accountManager,
permissionsManager: permissionsManager,
proxyGRPCServer: proxyGRPCServer,
clusterDeriver: clusterDeriver,
}
}
func (m *managerImpl) GetAllServices(ctx context.Context, accountID, userID string) ([]*reverseproxy.Service, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, fmt.Errorf("failed to get services: %w", err)
}
for _, service := range services {
err = m.replaceHostByLookup(ctx, accountID, service)
if err != nil {
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
}
}
return services, nil
}
func (m *managerImpl) replaceHostByLookup(ctx context.Context, accountID string, service *reverseproxy.Service) error {
for _, target := range service.Targets {
switch target.TargetType {
case reverseproxy.TargetTypePeer:
peer, err := m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
if err != nil {
log.WithContext(ctx).Warnf("failed to get peer by id %s for service %s: %v", target.TargetId, service.ID, err)
target.Host = unknownHostPlaceholder
continue
}
target.Host = peer.IP.String()
case reverseproxy.TargetTypeHost:
resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
if err != nil {
log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, service.ID, err)
target.Host = unknownHostPlaceholder
continue
}
target.Host = resource.Prefix.Addr().String()
case reverseproxy.TargetTypeDomain:
resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
if err != nil {
log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, service.ID, err)
target.Host = unknownHostPlaceholder
continue
}
target.Host = resource.Domain
case reverseproxy.TargetTypeSubnet:
// For subnets we do not do any lookups on the resource
default:
return fmt.Errorf("unknown target type: %s", target.TargetType)
}
}
return nil
}
func (m *managerImpl) GetService(ctx context.Context, accountID, userID, serviceID string) (*reverseproxy.Service, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
if err != nil {
return nil, fmt.Errorf("failed to get service: %w", err)
}
err = m.replaceHostByLookup(ctx, accountID, service)
if err != nil {
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
}
return service, nil
}
func (m *managerImpl) CreateService(ctx context.Context, accountID, userID string, service *reverseproxy.Service) (*reverseproxy.Service, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
var proxyCluster string
if m.clusterDeriver != nil {
proxyCluster, err = m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, service.Domain)
if err != nil {
log.WithError(err).Warnf("could not derive cluster from domain %s, updates will broadcast to all proxy servers", service.Domain)
return nil, status.Errorf(status.PreconditionFailed, "could not derive cluster from domain %s: %v", service.Domain, err)
}
}
service.AccountID = accountID
service.ProxyCluster = proxyCluster
service.InitNewRecord()
err = service.Auth.HashSecrets()
if err != nil {
return nil, fmt.Errorf("hash secrets: %w", err)
}
// Generate session JWT signing keys
keyPair, err := sessionkey.GenerateKeyPair()
if err != nil {
return nil, fmt.Errorf("generate session keys: %w", err)
}
service.SessionPrivateKey = keyPair.PrivateKey
service.SessionPublicKey = keyPair.PublicKey
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
// Check for duplicate domain
existingService, err := transaction.GetServiceByDomain(ctx, accountID, service.Domain)
if err != nil {
if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound {
return fmt.Errorf("failed to check existing service: %w", err)
}
}
if existingService != nil {
return status.Errorf(status.AlreadyExists, "service with domain %s already exists", service.Domain)
}
if err = validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil {
return err
}
if err = transaction.CreateService(ctx, service); err != nil {
return fmt.Errorf("failed to create service: %w", err)
}
return nil
})
if err != nil {
return nil, err
}
m.accountManager.StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceCreated, service.EventMeta())
err = m.replaceHostByLookup(ctx, accountID, service)
if err != nil {
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
}
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Create, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster)
m.accountManager.UpdateAccountPeers(ctx, accountID)
return service, nil
}
func (m *managerImpl) UpdateService(ctx context.Context, accountID, userID string, service *reverseproxy.Service) (*reverseproxy.Service, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Update)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
var oldCluster string
var domainChanged bool
var serviceEnabledChanged bool
err = service.Auth.HashSecrets()
if err != nil {
return nil, fmt.Errorf("hash secrets: %w", err)
}
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
existingService, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, service.ID)
if err != nil {
return err
}
oldCluster = existingService.ProxyCluster
if existingService.Domain != service.Domain {
domainChanged = true
conflictService, err := transaction.GetServiceByDomain(ctx, accountID, service.Domain)
if err != nil {
if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound {
return fmt.Errorf("check existing service: %w", err)
}
}
if conflictService != nil && conflictService.ID != service.ID {
return status.Errorf(status.AlreadyExists, "service with domain %s already exists", service.Domain)
}
if m.clusterDeriver != nil {
newCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, service.Domain)
if err != nil {
log.WithError(err).Warnf("could not derive cluster from domain %s", service.Domain)
}
service.ProxyCluster = newCluster
}
} else {
service.ProxyCluster = existingService.ProxyCluster
}
if service.Auth.PasswordAuth != nil && service.Auth.PasswordAuth.Enabled &&
existingService.Auth.PasswordAuth != nil && existingService.Auth.PasswordAuth.Enabled &&
service.Auth.PasswordAuth.Password == "" {
service.Auth.PasswordAuth = existingService.Auth.PasswordAuth
}
if service.Auth.PinAuth != nil && service.Auth.PinAuth.Enabled &&
existingService.Auth.PinAuth != nil && existingService.Auth.PinAuth.Enabled &&
service.Auth.PinAuth.Pin == "" {
service.Auth.PinAuth = existingService.Auth.PinAuth
}
service.Meta = existingService.Meta
service.SessionPrivateKey = existingService.SessionPrivateKey
service.SessionPublicKey = existingService.SessionPublicKey
serviceEnabledChanged = existingService.Enabled != service.Enabled
if err = validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil {
return err
}
if err = transaction.UpdateService(ctx, service); err != nil {
return fmt.Errorf("update service: %w", err)
}
return nil
})
if err != nil {
return nil, err
}
m.accountManager.StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceUpdated, service.EventMeta())
err = m.replaceHostByLookup(ctx, accountID, service)
if err != nil {
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
}
oidcCfg := m.proxyGRPCServer.GetOIDCValidationConfig()
switch {
case domainChanged && oldCluster != service.ProxyCluster:
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Delete, "", oidcCfg), oldCluster)
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Create, "", oidcCfg), service.ProxyCluster)
case !service.Enabled && serviceEnabledChanged:
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Delete, "", oidcCfg), service.ProxyCluster)
case service.Enabled && serviceEnabledChanged:
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Create, "", oidcCfg), service.ProxyCluster)
default:
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Update, "", oidcCfg), service.ProxyCluster)
}
m.accountManager.UpdateAccountPeers(ctx, accountID)
return service, nil
}
// validateTargetReferences checks that all target IDs reference existing peers or resources in the account.
func validateTargetReferences(ctx context.Context, transaction store.Store, accountID string, targets []*reverseproxy.Target) error {
for _, target := range targets {
switch target.TargetType {
case reverseproxy.TargetTypePeer:
if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil {
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
return status.Errorf(status.InvalidArgument, "peer target %q not found in account", target.TargetId)
}
return fmt.Errorf("look up peer target %q: %w", target.TargetId, err)
}
case reverseproxy.TargetTypeHost, reverseproxy.TargetTypeSubnet, reverseproxy.TargetTypeDomain:
if _, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil {
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
return status.Errorf(status.InvalidArgument, "resource target %q not found in account", target.TargetId)
}
return fmt.Errorf("look up resource target %q: %w", target.TargetId, err)
}
}
}
return nil
}
func (m *managerImpl) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !ok {
return status.NewPermissionDeniedError()
}
var service *reverseproxy.Service
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var err error
service, err = transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
if err != nil {
return err
}
if err = transaction.DeleteService(ctx, accountID, serviceID); err != nil {
return fmt.Errorf("failed to delete service: %w", err)
}
return nil
})
if err != nil {
return err
}
m.accountManager.StoreEvent(ctx, userID, serviceID, accountID, activity.ServiceDeleted, service.EventMeta())
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Delete, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster)
m.accountManager.UpdateAccountPeers(ctx, accountID)
return nil
}
// SetCertificateIssuedAt sets the certificate issued timestamp to the current time.
// Call this when receiving a gRPC notification that the certificate was issued.
func (m *managerImpl) SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error {
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
service, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
if err != nil {
return fmt.Errorf("failed to get service: %w", err)
}
service.Meta.CertificateIssuedAt = time.Now()
if err = transaction.UpdateService(ctx, service); err != nil {
return fmt.Errorf("failed to update service certificate timestamp: %w", err)
}
return nil
})
}
// SetStatus updates the status of the service (e.g., "active", "tunnel_not_created", etc.)
func (m *managerImpl) SetStatus(ctx context.Context, accountID, serviceID string, status reverseproxy.ProxyStatus) error {
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
service, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
if err != nil {
return fmt.Errorf("failed to get service: %w", err)
}
service.Meta.Status = string(status)
if err = transaction.UpdateService(ctx, service); err != nil {
return fmt.Errorf("failed to update service status: %w", err)
}
return nil
})
}
func (m *managerImpl) ReloadService(ctx context.Context, accountID, serviceID string) error {
service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
if err != nil {
return fmt.Errorf("failed to get service: %w", err)
}
err = m.replaceHostByLookup(ctx, accountID, service)
if err != nil {
return fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
}
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Update, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster)
m.accountManager.UpdateAccountPeers(ctx, accountID)
return nil
}
func (m *managerImpl) ReloadAllServicesForAccount(ctx context.Context, accountID string) error {
services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return fmt.Errorf("failed to get services: %w", err)
}
for _, service := range services {
err = m.replaceHostByLookup(ctx, accountID, service)
if err != nil {
return fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
}
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Update, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster)
}
m.accountManager.UpdateAccountPeers(ctx, accountID)
return nil
}
func (m *managerImpl) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) {
services, err := m.store.GetServices(ctx, store.LockingStrengthNone)
if err != nil {
return nil, fmt.Errorf("failed to get services: %w", err)
}
for _, service := range services {
err = m.replaceHostByLookup(ctx, service.AccountID, service)
if err != nil {
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
}
}
return services, nil
}
func (m *managerImpl) GetServiceByID(ctx context.Context, accountID, serviceID string) (*reverseproxy.Service, error) {
service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
if err != nil {
return nil, fmt.Errorf("failed to get service: %w", err)
}
err = m.replaceHostByLookup(ctx, accountID, service)
if err != nil {
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
}
return service, nil
}
func (m *managerImpl) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, fmt.Errorf("failed to get services: %w", err)
}
for _, service := range services {
err = m.replaceHostByLookup(ctx, accountID, service)
if err != nil {
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
}
}
return services, nil
}
func (m *managerImpl) GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error) {
target, err := m.store.GetServiceTargetByTargetID(ctx, store.LockingStrengthNone, accountID, resourceID)
if err != nil {
if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
return "", nil
}
return "", fmt.Errorf("failed to get service target by resource ID: %w", err)
}
if target == nil {
return "", nil
}
return target.ServiceID, nil
}

View File

@@ -1,463 +0,0 @@
package reverseproxy
import (
"errors"
"fmt"
"net"
"net/url"
"strconv"
"time"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/shared/hash/argon2id"
"github.com/netbirdio/netbird/util/crypt"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/proto"
)
type Operation string
const (
Create Operation = "create"
Update Operation = "update"
Delete Operation = "delete"
)
type ProxyStatus string
const (
StatusPending ProxyStatus = "pending"
StatusActive ProxyStatus = "active"
StatusTunnelNotCreated ProxyStatus = "tunnel_not_created"
StatusCertificatePending ProxyStatus = "certificate_pending"
StatusCertificateFailed ProxyStatus = "certificate_failed"
StatusError ProxyStatus = "error"
TargetTypePeer = "peer"
TargetTypeHost = "host"
TargetTypeDomain = "domain"
TargetTypeSubnet = "subnet"
)
type Target struct {
ID uint `gorm:"primaryKey" json:"-"`
AccountID string `gorm:"index:idx_target_account;not null" json:"-"`
ServiceID string `gorm:"index:idx_service_targets;not null" json:"-"`
Path *string `json:"path,omitempty"`
Host string `json:"host"` // the Host field is only used for subnet targets, otherwise ignored
Port int `gorm:"index:idx_target_port" json:"port"`
Protocol string `gorm:"index:idx_target_protocol" json:"protocol"`
TargetId string `gorm:"index:idx_target_id" json:"target_id"`
TargetType string `gorm:"index:idx_target_type" json:"target_type"`
Enabled bool `gorm:"index:idx_target_enabled" json:"enabled"`
}
type PasswordAuthConfig struct {
Enabled bool `json:"enabled"`
Password string `json:"password"`
}
type PINAuthConfig struct {
Enabled bool `json:"enabled"`
Pin string `json:"pin"`
}
type BearerAuthConfig struct {
Enabled bool `json:"enabled"`
DistributionGroups []string `json:"distribution_groups,omitempty" gorm:"serializer:json"`
}
type AuthConfig struct {
PasswordAuth *PasswordAuthConfig `json:"password_auth,omitempty" gorm:"serializer:json"`
PinAuth *PINAuthConfig `json:"pin_auth,omitempty" gorm:"serializer:json"`
BearerAuth *BearerAuthConfig `json:"bearer_auth,omitempty" gorm:"serializer:json"`
}
func (a *AuthConfig) HashSecrets() error {
if a.PasswordAuth != nil && a.PasswordAuth.Enabled && a.PasswordAuth.Password != "" {
hashedPassword, err := argon2id.Hash(a.PasswordAuth.Password)
if err != nil {
return fmt.Errorf("hash password: %w", err)
}
a.PasswordAuth.Password = hashedPassword
}
if a.PinAuth != nil && a.PinAuth.Enabled && a.PinAuth.Pin != "" {
hashedPin, err := argon2id.Hash(a.PinAuth.Pin)
if err != nil {
return fmt.Errorf("hash pin: %w", err)
}
a.PinAuth.Pin = hashedPin
}
return nil
}
func (a *AuthConfig) ClearSecrets() {
if a.PasswordAuth != nil {
a.PasswordAuth.Password = ""
}
if a.PinAuth != nil {
a.PinAuth.Pin = ""
}
}
type OIDCValidationConfig struct {
Issuer string
Audiences []string
KeysLocation string
MaxTokenAgeSeconds int64
}
type ServiceMeta struct {
CreatedAt time.Time
CertificateIssuedAt time.Time
Status string
}
type Service struct {
ID string `gorm:"primaryKey"`
AccountID string `gorm:"index"`
Name string
Domain string `gorm:"index"`
ProxyCluster string `gorm:"index"`
Targets []*Target `gorm:"foreignKey:ServiceID;constraint:OnDelete:CASCADE"`
Enabled bool
PassHostHeader bool
RewriteRedirects bool
Auth AuthConfig `gorm:"serializer:json"`
Meta ServiceMeta `gorm:"embedded;embeddedPrefix:meta_"`
SessionPrivateKey string `gorm:"column:session_private_key"`
SessionPublicKey string `gorm:"column:session_public_key"`
}
func NewService(accountID, name, domain, proxyCluster string, targets []*Target, enabled bool) *Service {
for _, target := range targets {
target.AccountID = accountID
}
s := &Service{
AccountID: accountID,
Name: name,
Domain: domain,
ProxyCluster: proxyCluster,
Targets: targets,
Enabled: enabled,
}
s.InitNewRecord()
return s
}
// InitNewRecord generates a new unique ID and resets metadata for a newly created
// Service record. This overwrites any existing ID and Meta fields and should
// only be called during initial creation, not for updates.
func (s *Service) InitNewRecord() {
s.ID = xid.New().String()
s.Meta = ServiceMeta{
CreatedAt: time.Now(),
Status: string(StatusPending),
}
}
func (s *Service) ToAPIResponse() *api.Service {
s.Auth.ClearSecrets()
authConfig := api.ServiceAuthConfig{}
if s.Auth.PasswordAuth != nil {
authConfig.PasswordAuth = &api.PasswordAuthConfig{
Enabled: s.Auth.PasswordAuth.Enabled,
Password: s.Auth.PasswordAuth.Password,
}
}
if s.Auth.PinAuth != nil {
authConfig.PinAuth = &api.PINAuthConfig{
Enabled: s.Auth.PinAuth.Enabled,
Pin: s.Auth.PinAuth.Pin,
}
}
if s.Auth.BearerAuth != nil {
authConfig.BearerAuth = &api.BearerAuthConfig{
Enabled: s.Auth.BearerAuth.Enabled,
DistributionGroups: &s.Auth.BearerAuth.DistributionGroups,
}
}
// Convert internal targets to API targets
apiTargets := make([]api.ServiceTarget, 0, len(s.Targets))
for _, target := range s.Targets {
apiTargets = append(apiTargets, api.ServiceTarget{
Path: target.Path,
Host: &target.Host,
Port: target.Port,
Protocol: api.ServiceTargetProtocol(target.Protocol),
TargetId: target.TargetId,
TargetType: api.ServiceTargetTargetType(target.TargetType),
Enabled: target.Enabled,
})
}
meta := api.ServiceMeta{
CreatedAt: s.Meta.CreatedAt,
Status: api.ServiceMetaStatus(s.Meta.Status),
}
if !s.Meta.CertificateIssuedAt.IsZero() {
meta.CertificateIssuedAt = &s.Meta.CertificateIssuedAt
}
resp := &api.Service{
Id: s.ID,
Name: s.Name,
Domain: s.Domain,
Targets: apiTargets,
Enabled: s.Enabled,
PassHostHeader: &s.PassHostHeader,
RewriteRedirects: &s.RewriteRedirects,
Auth: authConfig,
Meta: meta,
}
if s.ProxyCluster != "" {
resp.ProxyCluster = &s.ProxyCluster
}
return resp
}
func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConfig OIDCValidationConfig) *proto.ProxyMapping {
pathMappings := make([]*proto.PathMapping, 0, len(s.Targets))
for _, target := range s.Targets {
if !target.Enabled {
continue
}
// TODO: Make path prefix stripping configurable per-target.
// Currently the matching prefix is baked into the target URL path,
// so the proxy strips-then-re-adds it (effectively a no-op).
targetURL := url.URL{
Scheme: target.Protocol,
Host: target.Host,
Path: "/", // TODO: support service path
}
if target.Port > 0 && !isDefaultPort(target.Protocol, target.Port) {
targetURL.Host = net.JoinHostPort(targetURL.Host, strconv.Itoa(target.Port))
}
path := "/"
if target.Path != nil {
path = *target.Path
}
pathMappings = append(pathMappings, &proto.PathMapping{
Path: path,
Target: targetURL.String(),
})
}
auth := &proto.Authentication{
SessionKey: s.SessionPublicKey,
MaxSessionAgeSeconds: int64((time.Hour * 24).Seconds()),
}
if s.Auth.PasswordAuth != nil && s.Auth.PasswordAuth.Enabled {
auth.Password = true
}
if s.Auth.PinAuth != nil && s.Auth.PinAuth.Enabled {
auth.Pin = true
}
if s.Auth.BearerAuth != nil && s.Auth.BearerAuth.Enabled {
auth.Oidc = true
}
return &proto.ProxyMapping{
Type: operationToProtoType(operation),
Id: s.ID,
Domain: s.Domain,
Path: pathMappings,
AuthToken: authToken,
Auth: auth,
AccountId: s.AccountID,
PassHostHeader: s.PassHostHeader,
RewriteRedirects: s.RewriteRedirects,
}
}
func operationToProtoType(op Operation) proto.ProxyMappingUpdateType {
switch op {
case Create:
return proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED
case Update:
return proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED
case Delete:
return proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED
default:
log.Fatalf("unknown operation type: %v", op)
return proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED
}
}
// isDefaultPort reports whether port is the standard default for the given scheme
// (443 for https, 80 for http).
func isDefaultPort(scheme string, port int) bool {
return (scheme == "https" && port == 443) || (scheme == "http" && port == 80)
}
func (s *Service) FromAPIRequest(req *api.ServiceRequest, accountID string) {
s.Name = req.Name
s.Domain = req.Domain
s.AccountID = accountID
targets := make([]*Target, 0, len(req.Targets))
for _, apiTarget := range req.Targets {
target := &Target{
AccountID: accountID,
Path: apiTarget.Path,
Port: apiTarget.Port,
Protocol: string(apiTarget.Protocol),
TargetId: apiTarget.TargetId,
TargetType: string(apiTarget.TargetType),
Enabled: apiTarget.Enabled,
}
if apiTarget.Host != nil {
target.Host = *apiTarget.Host
}
targets = append(targets, target)
}
s.Targets = targets
s.Enabled = req.Enabled
if req.PassHostHeader != nil {
s.PassHostHeader = *req.PassHostHeader
}
if req.RewriteRedirects != nil {
s.RewriteRedirects = *req.RewriteRedirects
}
if req.Auth.PasswordAuth != nil {
s.Auth.PasswordAuth = &PasswordAuthConfig{
Enabled: req.Auth.PasswordAuth.Enabled,
Password: req.Auth.PasswordAuth.Password,
}
}
if req.Auth.PinAuth != nil {
s.Auth.PinAuth = &PINAuthConfig{
Enabled: req.Auth.PinAuth.Enabled,
Pin: req.Auth.PinAuth.Pin,
}
}
if req.Auth.BearerAuth != nil {
bearerAuth := &BearerAuthConfig{
Enabled: req.Auth.BearerAuth.Enabled,
}
if req.Auth.BearerAuth.DistributionGroups != nil {
bearerAuth.DistributionGroups = *req.Auth.BearerAuth.DistributionGroups
}
s.Auth.BearerAuth = bearerAuth
}
}
func (s *Service) Validate() error {
if s.Name == "" {
return errors.New("service name is required")
}
if len(s.Name) > 255 {
return errors.New("service name exceeds maximum length of 255 characters")
}
if s.Domain == "" {
return errors.New("service domain is required")
}
if len(s.Targets) == 0 {
return errors.New("at least one target is required")
}
for i, target := range s.Targets {
switch target.TargetType {
case TargetTypePeer, TargetTypeHost, TargetTypeDomain:
// host field will be ignored
case TargetTypeSubnet:
if target.Host == "" {
return fmt.Errorf("target %d has empty host but target_type is %q", i, target.TargetType)
}
default:
return fmt.Errorf("target %d has invalid target_type %q", i, target.TargetType)
}
if target.TargetId == "" {
return fmt.Errorf("target %d has empty target_id", i)
}
}
return nil
}
func (s *Service) EventMeta() map[string]any {
return map[string]any{"name": s.Name, "domain": s.Domain, "proxy_cluster": s.ProxyCluster}
}
func (s *Service) Copy() *Service {
targets := make([]*Target, len(s.Targets))
for i, target := range s.Targets {
targetCopy := *target
targets[i] = &targetCopy
}
return &Service{
ID: s.ID,
AccountID: s.AccountID,
Name: s.Name,
Domain: s.Domain,
ProxyCluster: s.ProxyCluster,
Targets: targets,
Enabled: s.Enabled,
PassHostHeader: s.PassHostHeader,
RewriteRedirects: s.RewriteRedirects,
Auth: s.Auth,
Meta: s.Meta,
SessionPrivateKey: s.SessionPrivateKey,
SessionPublicKey: s.SessionPublicKey,
}
}
func (s *Service) EncryptSensitiveData(enc *crypt.FieldEncrypt) error {
if enc == nil {
return nil
}
if s.SessionPrivateKey != "" {
var err error
s.SessionPrivateKey, err = enc.Encrypt(s.SessionPrivateKey)
if err != nil {
return err
}
}
return nil
}
func (s *Service) DecryptSensitiveData(enc *crypt.FieldEncrypt) error {
if enc == nil {
return nil
}
if s.SessionPrivateKey != "" {
var err error
s.SessionPrivateKey, err = enc.Decrypt(s.SessionPrivateKey)
if err != nil {
return err
}
}
return nil
}

View File

@@ -1,405 +0,0 @@
package reverseproxy
import (
"errors"
"fmt"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/hash/argon2id"
"github.com/netbirdio/netbird/shared/management/proto"
)
func validProxy() *Service {
return &Service{
Name: "test",
Domain: "example.com",
Targets: []*Target{
{TargetId: "peer-1", TargetType: TargetTypePeer, Host: "10.0.0.1", Port: 80, Protocol: "http", Enabled: true},
},
}
}
func TestValidate_Valid(t *testing.T) {
require.NoError(t, validProxy().Validate())
}
func TestValidate_EmptyName(t *testing.T) {
rp := validProxy()
rp.Name = ""
assert.ErrorContains(t, rp.Validate(), "name is required")
}
func TestValidate_EmptyDomain(t *testing.T) {
rp := validProxy()
rp.Domain = ""
assert.ErrorContains(t, rp.Validate(), "domain is required")
}
func TestValidate_NoTargets(t *testing.T) {
rp := validProxy()
rp.Targets = nil
assert.ErrorContains(t, rp.Validate(), "at least one target")
}
func TestValidate_EmptyTargetId(t *testing.T) {
rp := validProxy()
rp.Targets[0].TargetId = ""
assert.ErrorContains(t, rp.Validate(), "empty target_id")
}
func TestValidate_InvalidTargetType(t *testing.T) {
rp := validProxy()
rp.Targets[0].TargetType = "invalid"
assert.ErrorContains(t, rp.Validate(), "invalid target_type")
}
func TestValidate_ResourceTarget(t *testing.T) {
rp := validProxy()
rp.Targets = append(rp.Targets, &Target{
TargetId: "resource-1",
TargetType: TargetTypeHost,
Host: "example.org",
Port: 443,
Protocol: "https",
Enabled: true,
})
require.NoError(t, rp.Validate())
}
func TestValidate_MultipleTargetsOneInvalid(t *testing.T) {
rp := validProxy()
rp.Targets = append(rp.Targets, &Target{
TargetId: "",
TargetType: TargetTypePeer,
Host: "10.0.0.2",
Port: 80,
Protocol: "http",
Enabled: true,
})
err := rp.Validate()
require.Error(t, err)
assert.Contains(t, err.Error(), "target 1")
assert.Contains(t, err.Error(), "empty target_id")
}
func TestIsDefaultPort(t *testing.T) {
tests := []struct {
scheme string
port int
want bool
}{
{"http", 80, true},
{"https", 443, true},
{"http", 443, false},
{"https", 80, false},
{"http", 8080, false},
{"https", 8443, false},
{"http", 0, false},
{"https", 0, false},
}
for _, tt := range tests {
t.Run(fmt.Sprintf("%s/%d", tt.scheme, tt.port), func(t *testing.T) {
assert.Equal(t, tt.want, isDefaultPort(tt.scheme, tt.port))
})
}
}
func TestToProtoMapping_PortInTargetURL(t *testing.T) {
oidcConfig := OIDCValidationConfig{}
tests := []struct {
name string
protocol string
host string
port int
wantTarget string
}{
{
name: "http with default port 80 omits port",
protocol: "http",
host: "10.0.0.1",
port: 80,
wantTarget: "http://10.0.0.1/",
},
{
name: "https with default port 443 omits port",
protocol: "https",
host: "10.0.0.1",
port: 443,
wantTarget: "https://10.0.0.1/",
},
{
name: "port 0 omits port",
protocol: "http",
host: "10.0.0.1",
port: 0,
wantTarget: "http://10.0.0.1/",
},
{
name: "non-default port is included",
protocol: "http",
host: "10.0.0.1",
port: 8080,
wantTarget: "http://10.0.0.1:8080/",
},
{
name: "https with non-default port is included",
protocol: "https",
host: "10.0.0.1",
port: 8443,
wantTarget: "https://10.0.0.1:8443/",
},
{
name: "http port 443 is included",
protocol: "http",
host: "10.0.0.1",
port: 443,
wantTarget: "http://10.0.0.1:443/",
},
{
name: "https port 80 is included",
protocol: "https",
host: "10.0.0.1",
port: 80,
wantTarget: "https://10.0.0.1:80/",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rp := &Service{
ID: "test-id",
AccountID: "acc-1",
Domain: "example.com",
Targets: []*Target{
{
TargetId: "peer-1",
TargetType: TargetTypePeer,
Host: tt.host,
Port: tt.port,
Protocol: tt.protocol,
Enabled: true,
},
},
}
pm := rp.ToProtoMapping(Create, "token", oidcConfig)
require.Len(t, pm.Path, 1, "should have one path mapping")
assert.Equal(t, tt.wantTarget, pm.Path[0].Target)
})
}
}
func TestToProtoMapping_DisabledTargetSkipped(t *testing.T) {
rp := &Service{
ID: "test-id",
AccountID: "acc-1",
Domain: "example.com",
Targets: []*Target{
{TargetId: "peer-1", TargetType: TargetTypePeer, Host: "10.0.0.1", Port: 8080, Protocol: "http", Enabled: false},
{TargetId: "peer-2", TargetType: TargetTypePeer, Host: "10.0.0.2", Port: 9090, Protocol: "http", Enabled: true},
},
}
pm := rp.ToProtoMapping(Create, "token", OIDCValidationConfig{})
require.Len(t, pm.Path, 1)
assert.Equal(t, "http://10.0.0.2:9090/", pm.Path[0].Target)
}
func TestToProtoMapping_OperationTypes(t *testing.T) {
rp := validProxy()
tests := []struct {
op Operation
want proto.ProxyMappingUpdateType
}{
{Create, proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED},
{Update, proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED},
{Delete, proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED},
}
for _, tt := range tests {
t.Run(string(tt.op), func(t *testing.T) {
pm := rp.ToProtoMapping(tt.op, "", OIDCValidationConfig{})
assert.Equal(t, tt.want, pm.Type)
})
}
}
func TestAuthConfig_HashSecrets(t *testing.T) {
tests := []struct {
name string
config *AuthConfig
wantErr bool
validate func(*testing.T, *AuthConfig)
}{
{
name: "hash password successfully",
config: &AuthConfig{
PasswordAuth: &PasswordAuthConfig{
Enabled: true,
Password: "testPassword123",
},
},
wantErr: false,
validate: func(t *testing.T, config *AuthConfig) {
if !strings.HasPrefix(config.PasswordAuth.Password, "$argon2id$") {
t.Errorf("Password not hashed with argon2id, got: %s", config.PasswordAuth.Password)
}
// Verify the hash can be verified
if err := argon2id.Verify("testPassword123", config.PasswordAuth.Password); err != nil {
t.Errorf("Hash verification failed: %v", err)
}
},
},
{
name: "hash PIN successfully",
config: &AuthConfig{
PinAuth: &PINAuthConfig{
Enabled: true,
Pin: "123456",
},
},
wantErr: false,
validate: func(t *testing.T, config *AuthConfig) {
if !strings.HasPrefix(config.PinAuth.Pin, "$argon2id$") {
t.Errorf("PIN not hashed with argon2id, got: %s", config.PinAuth.Pin)
}
// Verify the hash can be verified
if err := argon2id.Verify("123456", config.PinAuth.Pin); err != nil {
t.Errorf("Hash verification failed: %v", err)
}
},
},
{
name: "hash both password and PIN",
config: &AuthConfig{
PasswordAuth: &PasswordAuthConfig{
Enabled: true,
Password: "password",
},
PinAuth: &PINAuthConfig{
Enabled: true,
Pin: "9999",
},
},
wantErr: false,
validate: func(t *testing.T, config *AuthConfig) {
if !strings.HasPrefix(config.PasswordAuth.Password, "$argon2id$") {
t.Errorf("Password not hashed with argon2id")
}
if !strings.HasPrefix(config.PinAuth.Pin, "$argon2id$") {
t.Errorf("PIN not hashed with argon2id")
}
if err := argon2id.Verify("password", config.PasswordAuth.Password); err != nil {
t.Errorf("Password hash verification failed: %v", err)
}
if err := argon2id.Verify("9999", config.PinAuth.Pin); err != nil {
t.Errorf("PIN hash verification failed: %v", err)
}
},
},
{
name: "skip disabled password auth",
config: &AuthConfig{
PasswordAuth: &PasswordAuthConfig{
Enabled: false,
Password: "password",
},
},
wantErr: false,
validate: func(t *testing.T, config *AuthConfig) {
if config.PasswordAuth.Password != "password" {
t.Errorf("Disabled password auth should not be hashed")
}
},
},
{
name: "skip empty password",
config: &AuthConfig{
PasswordAuth: &PasswordAuthConfig{
Enabled: true,
Password: "",
},
},
wantErr: false,
validate: func(t *testing.T, config *AuthConfig) {
if config.PasswordAuth.Password != "" {
t.Errorf("Empty password should remain empty")
}
},
},
{
name: "skip nil password auth",
config: &AuthConfig{
PasswordAuth: nil,
PinAuth: &PINAuthConfig{
Enabled: true,
Pin: "1234",
},
},
wantErr: false,
validate: func(t *testing.T, config *AuthConfig) {
if config.PasswordAuth != nil {
t.Errorf("PasswordAuth should remain nil")
}
if !strings.HasPrefix(config.PinAuth.Pin, "$argon2id$") {
t.Errorf("PIN should still be hashed")
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.config.HashSecrets()
if (err != nil) != tt.wantErr {
t.Errorf("HashSecrets() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.validate != nil {
tt.validate(t, tt.config)
}
})
}
}
func TestAuthConfig_HashSecrets_VerifyIncorrectSecret(t *testing.T) {
config := &AuthConfig{
PasswordAuth: &PasswordAuthConfig{
Enabled: true,
Password: "correctPassword",
},
}
if err := config.HashSecrets(); err != nil {
t.Fatalf("HashSecrets() error = %v", err)
}
// Verify with wrong password should fail
err := argon2id.Verify("wrongPassword", config.PasswordAuth.Password)
if !errors.Is(err, argon2id.ErrMismatchedHashAndPassword) {
t.Errorf("Expected ErrMismatchedHashAndPassword, got %v", err)
}
}
func TestAuthConfig_ClearSecrets(t *testing.T) {
config := &AuthConfig{
PasswordAuth: &PasswordAuthConfig{
Enabled: true,
Password: "hashedPassword",
},
PinAuth: &PINAuthConfig{
Enabled: true,
Pin: "hashedPin",
},
}
config.ClearSecrets()
if config.PasswordAuth.Password != "" {
t.Errorf("Password not cleared, got: %s", config.PasswordAuth.Password)
}
if config.PinAuth.Pin != "" {
t.Errorf("PIN not cleared, got: %s", config.PinAuth.Pin)
}
}

View File

@@ -1,69 +0,0 @@
package sessionkey
import (
"crypto/ed25519"
"crypto/rand"
"encoding/base64"
"fmt"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/netbirdio/netbird/proxy/auth"
)
type KeyPair struct {
PrivateKey string
PublicKey string
}
type Claims struct {
jwt.RegisteredClaims
Method auth.Method `json:"method"`
}
func GenerateKeyPair() (*KeyPair, error) {
pub, priv, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
return nil, fmt.Errorf("generate ed25519 key: %w", err)
}
return &KeyPair{
PrivateKey: base64.StdEncoding.EncodeToString(priv),
PublicKey: base64.StdEncoding.EncodeToString(pub),
}, nil
}
func SignToken(privKeyB64, userID, domain string, method auth.Method, expiration time.Duration) (string, error) {
privKeyBytes, err := base64.StdEncoding.DecodeString(privKeyB64)
if err != nil {
return "", fmt.Errorf("decode private key: %w", err)
}
if len(privKeyBytes) != ed25519.PrivateKeySize {
return "", fmt.Errorf("invalid private key size: got %d, want %d", len(privKeyBytes), ed25519.PrivateKeySize)
}
privKey := ed25519.PrivateKey(privKeyBytes)
now := time.Now()
claims := Claims{
RegisteredClaims: jwt.RegisteredClaims{
Issuer: auth.SessionJWTIssuer,
Subject: userID,
Audience: jwt.ClaimStrings{domain},
ExpiresAt: jwt.NewNumericDate(now.Add(expiration)),
IssuedAt: jwt.NewNumericDate(now),
NotBefore: jwt.NewNumericDate(now),
},
Method: method,
}
token := jwt.NewWithClaims(jwt.SigningMethodEdDSA, claims)
signedToken, err := token.SignedString(privKey)
if err != nil {
return "", fmt.Errorf("sign token: %w", err)
}
return signedToken, nil
}

View File

@@ -21,8 +21,6 @@ import (
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/formatter/hook"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/activity"
nbContext "github.com/netbirdio/netbird/management/server/context"
@@ -94,7 +92,7 @@ func (s *BaseServer) EventStore() activity.Store {
func (s *BaseServer) APIHandler() http.Handler {
return Create(s, func() http.Handler {
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ReverseProxyManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies)
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager())
if err != nil {
log.Fatalf("failed to create API handler: %v", err)
}
@@ -122,13 +120,11 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
realip.WithTrustedProxiesCount(trustedProxiesCount),
realip.WithHeaders([]string{realip.XForwardedFor, realip.XRealIp}),
}
proxyUnary, proxyStream, proxyAuthClose := nbgrpc.NewProxyAuthInterceptors(s.Store())
s.proxyAuthClose = proxyAuthClose
gRPCOpts := []grpc.ServerOption{
grpc.KeepaliveEnforcementPolicy(kaep),
grpc.KeepaliveParams(kasp),
grpc.ChainUnaryInterceptor(realip.UnaryServerInterceptorOpts(realipOpts...), unaryInterceptor, proxyUnary),
grpc.ChainStreamInterceptor(realip.StreamServerInterceptorOpts(realipOpts...), streamInterceptor, proxyStream),
grpc.ChainUnaryInterceptor(realip.UnaryServerInterceptorOpts(realipOpts...), unaryInterceptor),
grpc.ChainStreamInterceptor(realip.StreamServerInterceptorOpts(realipOpts...), streamInterceptor),
}
if s.Config.HttpConfig.LetsEncryptDomain != "" {
@@ -154,53 +150,10 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
}
mgmtProto.RegisterManagementServiceServer(gRPCAPIHandler, srv)
mgmtProto.RegisterProxyServiceServer(gRPCAPIHandler, s.ReverseProxyGRPCServer())
log.Info("ProxyService registered on gRPC server")
return gRPCAPIHandler
})
}
func (s *BaseServer) ReverseProxyGRPCServer() *nbgrpc.ProxyServiceServer {
return Create(s, func() *nbgrpc.ProxyServiceServer {
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager())
s.AfterInit(func(s *BaseServer) {
proxyService.SetProxyManager(s.ReverseProxyManager())
})
return proxyService
})
}
func (s *BaseServer) proxyOIDCConfig() nbgrpc.ProxyOIDCConfig {
return Create(s, func() nbgrpc.ProxyOIDCConfig {
return nbgrpc.ProxyOIDCConfig{
Issuer: s.Config.HttpConfig.AuthIssuer,
// todo: double check auth clientID value
ClientID: s.Config.HttpConfig.AuthClientID, // Reuse dashboard client
Scopes: []string{"openid", "profile", "email"},
CallbackURL: s.Config.HttpConfig.AuthCallbackURL,
HMACKey: []byte(s.Config.DataStoreEncryptionKey), // Use the datastore encryption key for OIDC state HMACs, this should ensure all management instances are using the same key.
Audience: s.Config.HttpConfig.AuthAudience,
KeysLocation: s.Config.HttpConfig.AuthKeysLocation,
}
})
}
func (s *BaseServer) ProxyTokenStore() *nbgrpc.OneTimeTokenStore {
return Create(s, func() *nbgrpc.OneTimeTokenStore {
tokenStore := nbgrpc.NewOneTimeTokenStore(1 * time.Minute)
log.Info("One-time token store initialized for proxy authentication")
return tokenStore
})
}
func (s *BaseServer) AccessLogsManager() accesslogs.Manager {
return Create(s, func() accesslogs.Manager {
accessLogManager := accesslogsmanager.NewManager(s.Store(), s.PermissionsManager(), s.GeoLocationManager())
return accessLogManager
})
}
func loadTLSConfig(certFile string, certKey string) (*tls.Config, error) {
// Load server's certificate and private key
serverCert, err := tls.LoadX509KeyPair(certFile, certKey)

View File

@@ -100,8 +100,6 @@ type HttpServerConfig struct {
CertFile string
// CertKey is the location of the certificate private key
CertKey string
// AuthClientID is the client id used for proxy SSO auth
AuthClientID string
// AuthAudience identifies the recipients that the JWT is intended for (aud in JWT)
AuthAudience string
// CLIAuthAudience identifies the client app recipients that the JWT is intended for (aud in JWT)
@@ -119,8 +117,6 @@ type HttpServerConfig struct {
IdpSignKeyRefreshEnabled bool
// Extra audience
ExtraAuthAudience string
// AuthCallbackDomain contains the callback domain
AuthCallbackURL string
}
// Host represents a Netbird host (e.g. STUN, TURN, Signal)

View File

@@ -8,9 +8,6 @@ import (
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
nbreverseproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/manager"
"github.com/netbirdio/netbird/management/internals/modules/zones"
zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager"
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
@@ -72,14 +69,7 @@ func (s *BaseServer) UsersManager() users.Manager {
func (s *BaseServer) SettingsManager() settings.Manager {
return Create(s, func() settings.Manager {
extraSettingsManager := integrations.NewManager(s.EventStore())
idpConfig := settings.IdpConfig{}
if s.Config.EmbeddedIdP != nil && s.Config.EmbeddedIdP.Enabled {
idpConfig.EmbeddedIdpEnabled = true
idpConfig.LocalAuthDisabled = s.Config.EmbeddedIdP.LocalAuthDisabled
}
return settings.NewManager(s.Store(), s.UsersManager(), extraSettingsManager, s.PermissionsManager(), idpConfig)
return settings.NewManager(s.Store(), s.UsersManager(), extraSettingsManager, s.PermissionsManager())
})
}
@@ -101,11 +91,6 @@ func (s *BaseServer) AccountManager() account.Manager {
if err != nil {
log.Fatalf("failed to create account manager: %v", err)
}
s.AfterInit(func(s *BaseServer) {
accountManager.SetServiceManager(s.ReverseProxyManager())
})
return accountManager
})
}
@@ -162,7 +147,7 @@ func (s *BaseServer) GroupsManager() groups.Manager {
func (s *BaseServer) ResourcesManager() resources.Manager {
return Create(s, func() resources.Manager {
return resources.NewManager(s.Store(), s.PermissionsManager(), s.GroupsManager(), s.AccountManager(), s.ReverseProxyManager())
return resources.NewManager(s.Store(), s.PermissionsManager(), s.GroupsManager(), s.AccountManager())
})
}
@@ -189,16 +174,3 @@ func (s *BaseServer) RecordsManager() records.Manager {
return recordsManager.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager())
})
}
func (s *BaseServer) ReverseProxyManager() reverseproxy.Manager {
return Create(s, func() reverseproxy.Manager {
return nbreverseproxy.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.ReverseProxyGRPCServer(), s.ReverseProxyDomainManager())
})
}
func (s *BaseServer) ReverseProxyDomainManager() *manager.Manager {
return Create(s, func() *manager.Manager {
m := manager.NewManager(s.Store(), s.ReverseProxyGRPCServer(), s.PermissionsManager())
return &m
})
}

View File

@@ -11,6 +11,7 @@ import (
"time"
"github.com/google/uuid"
"github.com/netbirdio/netbird/management/server/idp"
log "github.com/sirupsen/logrus"
"go.opentelemetry.io/otel/metric"
"golang.org/x/crypto/acme/autocert"
@@ -20,7 +21,6 @@ import (
"github.com/netbirdio/netbird/encryption"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/metrics"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/util/wsproxy"
@@ -58,8 +58,6 @@ type BaseServer struct {
mgmtMetricsPort int
mgmtPort int
proxyAuthClose func()
listener net.Listener
certManager *autocert.Manager
update *version.Update
@@ -140,14 +138,6 @@ func (s *BaseServer) Start(ctx context.Context) error {
go metricsWorker.Run(srvCtx)
}
// Run afterInit hooks before starting any servers
// This allows registering additional gRPC services (e.g., Signal) before Serve() is called
for _, fn := range s.afterInit {
if fn != nil {
fn(s)
}
}
var compatListener net.Listener
if s.mgmtPort != ManagementLegacyPort {
// The Management gRPC server was running on port 33073 previously. Old agents that are already connected to it
@@ -188,6 +178,12 @@ func (s *BaseServer) Start(ctx context.Context) error {
}
}
for _, fn := range s.afterInit {
if fn != nil {
fn(s)
}
}
log.WithContext(ctx).Infof("management server version %s", version.NetbirdVersion())
log.WithContext(ctx).Infof("running HTTP server and gRPC server on the same port: %s", s.listener.Addr().String())
s.serveGRPCWithHTTP(ctx, s.listener, rootHandler, tlsEnabled)
@@ -219,11 +215,6 @@ func (s *BaseServer) Stop() error {
_ = s.certManager.Listener().Close()
}
s.GRPCServer().Stop()
s.ReverseProxyGRPCServer().Close()
if s.proxyAuthClose != nil {
s.proxyAuthClose()
s.proxyAuthClose = nil
}
_ = s.Store().Close(ctx)
_ = s.EventStore().Close(ctx)
if s.update != nil {
@@ -264,23 +255,7 @@ func (s *BaseServer) SetContainer(key string, container any) {
log.Tracef("container with key %s set successfully", key)
}
// SetHandlerFunc allows overriding the default HTTP handler function.
// This is useful for multiplexing additional services on the same port.
func (s *BaseServer) SetHandlerFunc(handler http.Handler) {
s.container["customHandler"] = handler
log.Tracef("custom handler set successfully")
}
func (s *BaseServer) handlerFunc(_ context.Context, gRPCHandler *grpc.Server, httpHandler http.Handler, meter metric.Meter) http.Handler {
// Check if a custom handler was set (for multiplexing additional services)
if customHandler, ok := s.GetContainer("customHandler"); ok {
if handler, ok := customHandler.(http.Handler); ok {
log.Tracef("using custom handler")
return handler
}
}
// Use default handler
wsProxy := wsproxyserver.New(gRPCHandler, wsproxyserver.WithOTelMeter(meter))
return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {

View File

@@ -1,167 +0,0 @@
package grpc
import (
"crypto/rand"
"crypto/subtle"
"encoding/base64"
"fmt"
"sync"
"time"
log "github.com/sirupsen/logrus"
)
// OneTimeTokenStore manages short-lived, single-use authentication tokens
// for proxy-to-management RPC authentication. Tokens are generated when
// a service is created and must be used exactly once by the proxy
// to authenticate a subsequent RPC call.
type OneTimeTokenStore struct {
tokens map[string]*tokenMetadata
mu sync.RWMutex
cleanup *time.Ticker
cleanupDone chan struct{}
}
// tokenMetadata stores information about a one-time token
type tokenMetadata struct {
ServiceID string
AccountID string
ExpiresAt time.Time
CreatedAt time.Time
}
// NewOneTimeTokenStore creates a new token store with automatic cleanup
// of expired tokens. The cleanupInterval determines how often expired
// tokens are removed from memory.
func NewOneTimeTokenStore(cleanupInterval time.Duration) *OneTimeTokenStore {
store := &OneTimeTokenStore{
tokens: make(map[string]*tokenMetadata),
cleanup: time.NewTicker(cleanupInterval),
cleanupDone: make(chan struct{}),
}
// Start background cleanup goroutine
go store.cleanupExpired()
return store
}
// GenerateToken creates a new cryptographically secure one-time token
// with the specified TTL. The token is associated with a specific
// accountID and serviceID for validation purposes.
//
// Returns the generated token string or an error if random generation fails.
func (s *OneTimeTokenStore) GenerateToken(accountID, serviceID string, ttl time.Duration) (string, error) {
// Generate 32 bytes (256 bits) of cryptographically secure random data
randomBytes := make([]byte, 32)
if _, err := rand.Read(randomBytes); err != nil {
return "", fmt.Errorf("failed to generate random token: %w", err)
}
// Encode as URL-safe base64 for easy transmission in gRPC
token := base64.URLEncoding.EncodeToString(randomBytes)
s.mu.Lock()
defer s.mu.Unlock()
s.tokens[token] = &tokenMetadata{
ServiceID: serviceID,
AccountID: accountID,
ExpiresAt: time.Now().Add(ttl),
CreatedAt: time.Now(),
}
log.Debugf("Generated one-time token for proxy %s in account %s (expires in %s)",
serviceID, accountID, ttl)
return token, nil
}
// ValidateAndConsume verifies the token against the provided accountID and
// serviceID, checks expiration, and then deletes it to enforce single-use.
//
// This method uses constant-time comparison to prevent timing attacks.
//
// Returns nil on success, or an error if:
// - Token doesn't exist
// - Token has expired
// - Account ID doesn't match
// - Reverse proxy ID doesn't match
func (s *OneTimeTokenStore) ValidateAndConsume(token, accountID, serviceID string) error {
s.mu.Lock()
defer s.mu.Unlock()
metadata, exists := s.tokens[token]
if !exists {
log.Warnf("Token validation failed: token not found (proxy: %s, account: %s)",
serviceID, accountID)
return fmt.Errorf("invalid token")
}
// Check expiration
if time.Now().After(metadata.ExpiresAt) {
delete(s.tokens, token)
log.Warnf("Token validation failed: token expired (proxy: %s, account: %s)",
serviceID, accountID)
return fmt.Errorf("token expired")
}
// Validate account ID using constant-time comparison (prevents timing attacks)
if subtle.ConstantTimeCompare([]byte(metadata.AccountID), []byte(accountID)) != 1 {
log.Warnf("Token validation failed: account ID mismatch (expected: %s, got: %s)",
metadata.AccountID, accountID)
return fmt.Errorf("account ID mismatch")
}
// Validate service ID using constant-time comparison
if subtle.ConstantTimeCompare([]byte(metadata.ServiceID), []byte(serviceID)) != 1 {
log.Warnf("Token validation failed: service ID mismatch (expected: %s, got: %s)",
metadata.ServiceID, serviceID)
return fmt.Errorf("service ID mismatch")
}
// Delete token immediately to enforce single-use
delete(s.tokens, token)
log.Infof("Token validated and consumed for proxy %s in account %s",
serviceID, accountID)
return nil
}
// cleanupExpired removes expired tokens in the background to prevent memory leaks
func (s *OneTimeTokenStore) cleanupExpired() {
for {
select {
case <-s.cleanup.C:
s.mu.Lock()
now := time.Now()
removed := 0
for token, metadata := range s.tokens {
if now.After(metadata.ExpiresAt) {
delete(s.tokens, token)
removed++
}
}
if removed > 0 {
log.Debugf("Cleaned up %d expired one-time tokens", removed)
}
s.mu.Unlock()
case <-s.cleanupDone:
return
}
}
}
// Close stops the cleanup goroutine and releases resources
func (s *OneTimeTokenStore) Close() {
s.cleanup.Stop()
close(s.cleanupDone)
}
// GetTokenCount returns the current number of tokens in the store (for debugging/metrics)
func (s *OneTimeTokenStore) GetTokenCount() int {
s.mu.RLock()
defer s.mu.RUnlock()
return len(s.tokens)
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,234 +0,0 @@
package grpc
import (
"context"
"strings"
"sync"
"time"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
)
const (
// lastUsedUpdateInterval is the minimum interval between last_used updates for the same token.
lastUsedUpdateInterval = time.Minute
// lastUsedCleanupInterval is how often stale lastUsed entries are removed.
lastUsedCleanupInterval = 2 * time.Minute
)
type proxyTokenContextKey struct{}
// ProxyTokenContextKey is the typed key used to store validated token info in context.
var ProxyTokenContextKey = proxyTokenContextKey{}
// proxyTokenID identifies a proxy access token by its database ID.
type proxyTokenID = string
// proxyTokenStore defines the store interface needed for token validation
type proxyTokenStore interface {
GetProxyAccessTokenByHashedToken(ctx context.Context, lockStrength store.LockingStrength, hashedToken types.HashedProxyToken) (*types.ProxyAccessToken, error)
MarkProxyAccessTokenUsed(ctx context.Context, tokenID string) error
}
// proxyAuthInterceptor holds state for proxy authentication interceptors.
type proxyAuthInterceptor struct {
store proxyTokenStore
failureLimiter *authFailureLimiter
// lastUsedMu protects lastUsedTimes
lastUsedMu sync.Mutex
lastUsedTimes map[proxyTokenID]time.Time
cancel context.CancelFunc
}
func newProxyAuthInterceptor(tokenStore proxyTokenStore) *proxyAuthInterceptor {
ctx, cancel := context.WithCancel(context.Background())
i := &proxyAuthInterceptor{
store: tokenStore,
failureLimiter: newAuthFailureLimiter(),
lastUsedTimes: make(map[proxyTokenID]time.Time),
cancel: cancel,
}
go i.lastUsedCleanupLoop(ctx)
return i
}
// NewProxyAuthInterceptors creates gRPC unary and stream interceptors that validate proxy access tokens.
// They only intercept ProxyService methods. Both interceptors share state for last-used and failure rate limiting.
// The returned close function must be called on shutdown to stop background goroutines.
func NewProxyAuthInterceptors(tokenStore proxyTokenStore) (grpc.UnaryServerInterceptor, grpc.StreamServerInterceptor, func()) {
interceptor := newProxyAuthInterceptor(tokenStore)
unary := func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
if !strings.HasPrefix(info.FullMethod, "/management.ProxyService/") {
return handler(ctx, req)
}
token, err := interceptor.validateProxyToken(ctx)
if err != nil {
// Log auth failures explicitly; gRPC doesn't log these by default.
log.WithContext(ctx).Warnf("proxy auth failed: %v", err)
return nil, err
}
ctx = context.WithValue(ctx, ProxyTokenContextKey, token)
return handler(ctx, req)
}
stream := func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
if !strings.HasPrefix(info.FullMethod, "/management.ProxyService/") {
return handler(srv, ss)
}
token, err := interceptor.validateProxyToken(ss.Context())
if err != nil {
// Log auth failures explicitly; gRPC doesn't log these by default.
log.WithContext(ss.Context()).Warnf("proxy auth failed: %v", err)
return err
}
ctx := context.WithValue(ss.Context(), ProxyTokenContextKey, token)
wrapped := &wrappedServerStream{
ServerStream: ss,
ctx: ctx,
}
return handler(srv, wrapped)
}
return unary, stream, interceptor.close
}
func (i *proxyAuthInterceptor) validateProxyToken(ctx context.Context) (*types.ProxyAccessToken, error) {
clientIP := peerIPFromContext(ctx)
if clientIP != "" && i.failureLimiter.isLimited(clientIP) {
return nil, status.Errorf(codes.ResourceExhausted, "too many failed authentication attempts")
}
token, err := i.doValidateProxyToken(ctx)
if err != nil {
if clientIP != "" {
i.failureLimiter.recordFailure(clientIP)
}
return nil, err
}
i.maybeUpdateLastUsed(ctx, token.ID)
return token, nil
}
func (i *proxyAuthInterceptor) doValidateProxyToken(ctx context.Context) (*types.ProxyAccessToken, error) {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return nil, status.Errorf(codes.Unauthenticated, "missing metadata")
}
authValues := md.Get("authorization")
if len(authValues) == 0 {
return nil, status.Errorf(codes.Unauthenticated, "missing authorization header")
}
authValue := authValues[0]
if !strings.HasPrefix(authValue, "Bearer ") {
return nil, status.Errorf(codes.Unauthenticated, "invalid authorization format")
}
plainToken := types.PlainProxyToken(strings.TrimPrefix(authValue, "Bearer "))
if err := plainToken.Validate(); err != nil {
return nil, status.Errorf(codes.Unauthenticated, "invalid token format")
}
token, err := i.store.GetProxyAccessTokenByHashedToken(ctx, store.LockingStrengthNone, plainToken.Hash())
if err != nil {
return nil, status.Errorf(codes.Unauthenticated, "invalid token")
}
// TODO: Enforce AccountID scope for "bring your own proxy" feature.
// Currently tokens are management-wide; AccountID field is reserved for future use.
if !token.IsValid() {
return nil, status.Errorf(codes.Unauthenticated, "token expired or revoked")
}
return token, nil
}
// maybeUpdateLastUsed updates the last_used timestamp if enough time has passed since the last update.
func (i *proxyAuthInterceptor) maybeUpdateLastUsed(ctx context.Context, tokenID string) {
now := time.Now()
i.lastUsedMu.Lock()
lastUpdate, exists := i.lastUsedTimes[tokenID]
if exists && now.Sub(lastUpdate) < lastUsedUpdateInterval {
i.lastUsedMu.Unlock()
return
}
i.lastUsedTimes[tokenID] = now
i.lastUsedMu.Unlock()
if err := i.store.MarkProxyAccessTokenUsed(ctx, tokenID); err != nil {
log.WithContext(ctx).Debugf("failed to mark proxy token as used: %v", err)
}
}
func (i *proxyAuthInterceptor) lastUsedCleanupLoop(ctx context.Context) {
ticker := time.NewTicker(lastUsedCleanupInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
i.cleanupStaleLastUsed()
case <-ctx.Done():
return
}
}
}
// cleanupStaleLastUsed removes entries older than 2x the update interval.
func (i *proxyAuthInterceptor) cleanupStaleLastUsed() {
i.lastUsedMu.Lock()
defer i.lastUsedMu.Unlock()
now := time.Now()
staleThreshold := 2 * lastUsedUpdateInterval
for id, lastUpdate := range i.lastUsedTimes {
if now.Sub(lastUpdate) > staleThreshold {
delete(i.lastUsedTimes, id)
}
}
}
func (i *proxyAuthInterceptor) close() {
i.cancel()
i.failureLimiter.stop()
}
// GetProxyTokenFromContext retrieves the validated proxy token from the context
func GetProxyTokenFromContext(ctx context.Context) *types.ProxyAccessToken {
token, ok := ctx.Value(ProxyTokenContextKey).(*types.ProxyAccessToken)
if !ok {
return nil
}
return token
}
// wrappedServerStream wraps a grpc.ServerStream to provide a custom context
type wrappedServerStream struct {
grpc.ServerStream
ctx context.Context
}
func (w *wrappedServerStream) Context() context.Context {
return w.ctx
}

View File

@@ -1,134 +0,0 @@
package grpc
import (
"context"
"net"
"sync"
"time"
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip"
"golang.org/x/time/rate"
"google.golang.org/grpc/peer"
)
const (
// proxyAuthFailureBurst is the maximum number of failed attempts before rate limiting kicks in.
proxyAuthFailureBurst = 5
// proxyAuthLimiterCleanup is how often stale limiters are removed.
proxyAuthLimiterCleanup = 5 * time.Minute
// proxyAuthLimiterTTL is how long a limiter is kept after the last failure.
proxyAuthLimiterTTL = 15 * time.Minute
)
// defaultProxyAuthFailureRate is the token replenishment rate for failed auth attempts.
// One token every 12 seconds = 5 per minute.
var defaultProxyAuthFailureRate = rate.Every(12 * time.Second)
// clientIP identifies a client by its IP address for rate limiting purposes.
type clientIP = string
type limiterEntry struct {
limiter *rate.Limiter
lastAccess time.Time
}
// authFailureLimiter tracks per-IP rate limits for failed proxy authentication attempts.
type authFailureLimiter struct {
mu sync.Mutex
limiters map[clientIP]*limiterEntry
failureRate rate.Limit
cancel context.CancelFunc
}
func newAuthFailureLimiter() *authFailureLimiter {
return newAuthFailureLimiterWithRate(defaultProxyAuthFailureRate)
}
func newAuthFailureLimiterWithRate(failureRate rate.Limit) *authFailureLimiter {
ctx, cancel := context.WithCancel(context.Background())
l := &authFailureLimiter{
limiters: make(map[clientIP]*limiterEntry),
failureRate: failureRate,
cancel: cancel,
}
go l.cleanupLoop(ctx)
return l
}
// isLimited returns true if the given IP has exhausted its failure budget.
func (l *authFailureLimiter) isLimited(ip clientIP) bool {
l.mu.Lock()
defer l.mu.Unlock()
entry, exists := l.limiters[ip]
if !exists {
return false
}
return entry.limiter.Tokens() < 1
}
// recordFailure consumes a token from the rate limiter for the given IP.
func (l *authFailureLimiter) recordFailure(ip clientIP) {
l.mu.Lock()
defer l.mu.Unlock()
now := time.Now()
entry, exists := l.limiters[ip]
if !exists {
entry = &limiterEntry{
limiter: rate.NewLimiter(l.failureRate, proxyAuthFailureBurst),
}
l.limiters[ip] = entry
}
entry.lastAccess = now
entry.limiter.Allow()
}
func (l *authFailureLimiter) cleanupLoop(ctx context.Context) {
ticker := time.NewTicker(proxyAuthLimiterCleanup)
defer ticker.Stop()
for {
select {
case <-ticker.C:
l.cleanup()
case <-ctx.Done():
return
}
}
}
func (l *authFailureLimiter) cleanup() {
l.mu.Lock()
defer l.mu.Unlock()
now := time.Now()
for ip, entry := range l.limiters {
if now.Sub(entry.lastAccess) > proxyAuthLimiterTTL {
delete(l.limiters, ip)
}
}
}
func (l *authFailureLimiter) stop() {
l.cancel()
}
// peerIPFromContext extracts the client IP from the gRPC context.
// Uses realip (from trusted proxy headers) first, falls back to the transport peer address.
func peerIPFromContext(ctx context.Context) clientIP {
if addr, ok := realip.FromContext(ctx); ok {
return addr.String()
}
if p, ok := peer.FromContext(ctx); ok {
host, _, err := net.SplitHostPort(p.Addr.String())
if err != nil {
return p.Addr.String()
}
return host
}
return ""
}

View File

@@ -1,98 +0,0 @@
package grpc
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/time/rate"
)
func TestAuthFailureLimiter_NotLimitedInitially(t *testing.T) {
l := newAuthFailureLimiter()
defer l.stop()
assert.False(t, l.isLimited("192.168.1.1"), "new IP should not be rate limited")
}
func TestAuthFailureLimiter_LimitedAfterBurst(t *testing.T) {
l := newAuthFailureLimiter()
defer l.stop()
ip := "192.168.1.1"
for i := 0; i < proxyAuthFailureBurst; i++ {
l.recordFailure(ip)
}
assert.True(t, l.isLimited(ip), "IP should be limited after exhausting burst")
}
func TestAuthFailureLimiter_DifferentIPsIndependent(t *testing.T) {
l := newAuthFailureLimiter()
defer l.stop()
for i := 0; i < proxyAuthFailureBurst; i++ {
l.recordFailure("192.168.1.1")
}
assert.True(t, l.isLimited("192.168.1.1"))
assert.False(t, l.isLimited("192.168.1.2"), "different IP should not be affected")
}
func TestAuthFailureLimiter_RecoveryOverTime(t *testing.T) {
l := newAuthFailureLimiterWithRate(rate.Limit(100)) // 100 tokens/sec for fast recovery
defer l.stop()
ip := "10.0.0.1"
// Exhaust burst
for i := 0; i < proxyAuthFailureBurst; i++ {
l.recordFailure(ip)
}
require.True(t, l.isLimited(ip))
// Wait for token replenishment
time.Sleep(50 * time.Millisecond)
assert.False(t, l.isLimited(ip), "should recover after tokens replenish")
}
func TestAuthFailureLimiter_Cleanup(t *testing.T) {
l := newAuthFailureLimiter()
defer l.stop()
l.recordFailure("10.0.0.1")
l.mu.Lock()
require.Len(t, l.limiters, 1)
// Backdate the entry so it looks stale
l.limiters["10.0.0.1"].lastAccess = time.Now().Add(-proxyAuthLimiterTTL - time.Minute)
l.mu.Unlock()
l.cleanup()
l.mu.Lock()
assert.Empty(t, l.limiters, "stale entries should be cleaned up")
l.mu.Unlock()
}
func TestAuthFailureLimiter_CleanupKeepsFresh(t *testing.T) {
l := newAuthFailureLimiter()
defer l.stop()
l.recordFailure("10.0.0.1")
l.recordFailure("10.0.0.2")
l.mu.Lock()
// Only backdate one entry
l.limiters["10.0.0.1"].lastAccess = time.Now().Add(-proxyAuthLimiterTTL - time.Minute)
l.mu.Unlock()
l.cleanup()
l.mu.Lock()
assert.Len(t, l.limiters, 1, "only stale entries should be removed")
assert.Contains(t, l.limiters, "10.0.0.2")
l.mu.Unlock()
}

View File

@@ -1,381 +0,0 @@
package grpc
import (
"context"
"errors"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/server/types"
)
type mockReverseProxyManager struct {
proxiesByAccount map[string][]*reverseproxy.Service
err error
}
func (m *mockReverseProxyManager) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
if m.err != nil {
return nil, m.err
}
return m.proxiesByAccount[accountID], nil
}
func (m *mockReverseProxyManager) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) {
return nil, nil
}
func (m *mockReverseProxyManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*reverseproxy.Service, error) {
return []*reverseproxy.Service{}, nil
}
func (m *mockReverseProxyManager) GetService(ctx context.Context, accountID, userID, reverseProxyID string) (*reverseproxy.Service, error) {
return &reverseproxy.Service{}, nil
}
func (m *mockReverseProxyManager) CreateService(ctx context.Context, accountID, userID string, rp *reverseproxy.Service) (*reverseproxy.Service, error) {
return &reverseproxy.Service{}, nil
}
func (m *mockReverseProxyManager) UpdateService(ctx context.Context, accountID, userID string, rp *reverseproxy.Service) (*reverseproxy.Service, error) {
return &reverseproxy.Service{}, nil
}
func (m *mockReverseProxyManager) DeleteService(ctx context.Context, accountID, userID, reverseProxyID string) error {
return nil
}
func (m *mockReverseProxyManager) SetCertificateIssuedAt(ctx context.Context, accountID, reverseProxyID string) error {
return nil
}
func (m *mockReverseProxyManager) SetStatus(ctx context.Context, accountID, reverseProxyID string, status reverseproxy.ProxyStatus) error {
return nil
}
func (m *mockReverseProxyManager) ReloadAllServicesForAccount(ctx context.Context, accountID string) error {
return nil
}
func (m *mockReverseProxyManager) ReloadService(ctx context.Context, accountID, reverseProxyID string) error {
return nil
}
func (m *mockReverseProxyManager) GetServiceByID(ctx context.Context, accountID, reverseProxyID string) (*reverseproxy.Service, error) {
return &reverseproxy.Service{}, nil
}
func (m *mockReverseProxyManager) GetServiceIDByTargetID(_ context.Context, _, _ string) (string, error) {
return "", nil
}
type mockUsersManager struct {
users map[string]*types.User
err error
}
func (m *mockUsersManager) GetUser(ctx context.Context, userID string) (*types.User, error) {
if m.err != nil {
return nil, m.err
}
user, ok := m.users[userID]
if !ok {
return nil, errors.New("user not found")
}
return user, nil
}
func TestValidateUserGroupAccess(t *testing.T) {
tests := []struct {
name string
domain string
userID string
proxiesByAccount map[string][]*reverseproxy.Service
users map[string]*types.User
proxyErr error
userErr error
expectErr bool
expectErrMsg string
}{
{
name: "user not found",
domain: "app.example.com",
userID: "unknown-user",
proxiesByAccount: map[string][]*reverseproxy.Service{
"account1": {{Domain: "app.example.com", AccountID: "account1"}},
},
users: map[string]*types.User{},
expectErr: true,
expectErrMsg: "user not found",
},
{
name: "proxy not found in user's account",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{},
users: map[string]*types.User{
"user1": {Id: "user1", AccountID: "account1"},
},
expectErr: true,
expectErrMsg: "service not found",
},
{
name: "proxy exists in different account - not accessible",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{
"account2": {{Domain: "app.example.com", AccountID: "account2"}},
},
users: map[string]*types.User{
"user1": {Id: "user1", AccountID: "account1"},
},
expectErr: true,
expectErrMsg: "service not found",
},
{
name: "no bearer auth configured - same account allows access",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{
"account1": {{Domain: "app.example.com", AccountID: "account1", Auth: reverseproxy.AuthConfig{}}},
},
users: map[string]*types.User{
"user1": {Id: "user1", AccountID: "account1"},
},
expectErr: false,
},
{
name: "bearer auth disabled - same account allows access",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{
"account1": {{
Domain: "app.example.com",
AccountID: "account1",
Auth: reverseproxy.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{Enabled: false},
},
}},
},
users: map[string]*types.User{
"user1": {Id: "user1", AccountID: "account1"},
},
expectErr: false,
},
{
name: "bearer auth enabled but no groups configured - same account allows access",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{
"account1": {{
Domain: "app.example.com",
AccountID: "account1",
Auth: reverseproxy.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{
Enabled: true,
DistributionGroups: []string{},
},
},
}},
},
users: map[string]*types.User{
"user1": {Id: "user1", AccountID: "account1"},
},
expectErr: false,
},
{
name: "user not in allowed groups",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{
"account1": {{
Domain: "app.example.com",
AccountID: "account1",
Auth: reverseproxy.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{
Enabled: true,
DistributionGroups: []string{"group1", "group2"},
},
},
}},
},
users: map[string]*types.User{
"user1": {Id: "user1", AccountID: "account1", AutoGroups: []string{"group3", "group4"}},
},
expectErr: true,
expectErrMsg: "not in allowed groups",
},
{
name: "user in one of the allowed groups - allow access",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{
"account1": {{
Domain: "app.example.com",
AccountID: "account1",
Auth: reverseproxy.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{
Enabled: true,
DistributionGroups: []string{"group1", "group2"},
},
},
}},
},
users: map[string]*types.User{
"user1": {Id: "user1", AccountID: "account1", AutoGroups: []string{"group2", "group3"}},
},
expectErr: false,
},
{
name: "user in all allowed groups - allow access",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{
"account1": {{
Domain: "app.example.com",
AccountID: "account1",
Auth: reverseproxy.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{
Enabled: true,
DistributionGroups: []string{"group1", "group2"},
},
},
}},
},
users: map[string]*types.User{
"user1": {Id: "user1", AccountID: "account1", AutoGroups: []string{"group1", "group2", "group3"}},
},
expectErr: false,
},
{
name: "proxy manager error",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: nil,
proxyErr: errors.New("database error"),
users: map[string]*types.User{
"user1": {Id: "user1", AccountID: "account1"},
},
expectErr: true,
expectErrMsg: "get account services",
},
{
name: "multiple proxies in account - finds correct one",
domain: "app2.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{
"account1": {
{Domain: "app1.example.com", AccountID: "account1"},
{Domain: "app2.example.com", AccountID: "account1", Auth: reverseproxy.AuthConfig{}},
{Domain: "app3.example.com", AccountID: "account1"},
},
},
users: map[string]*types.User{
"user1": {Id: "user1", AccountID: "account1"},
},
expectErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
server := &ProxyServiceServer{
reverseProxyManager: &mockReverseProxyManager{
proxiesByAccount: tt.proxiesByAccount,
err: tt.proxyErr,
},
usersManager: &mockUsersManager{
users: tt.users,
err: tt.userErr,
},
}
err := server.ValidateUserGroupAccess(context.Background(), tt.domain, tt.userID)
if tt.expectErr {
require.Error(t, err)
assert.Contains(t, err.Error(), tt.expectErrMsg)
} else {
require.NoError(t, err)
}
})
}
}
func TestGetAccountProxyByDomain(t *testing.T) {
tests := []struct {
name string
accountID string
domain string
proxiesByAccount map[string][]*reverseproxy.Service
err error
expectProxy bool
expectErr bool
}{
{
name: "proxy found",
accountID: "account1",
domain: "app.example.com",
proxiesByAccount: map[string][]*reverseproxy.Service{
"account1": {
{Domain: "other.example.com", AccountID: "account1"},
{Domain: "app.example.com", AccountID: "account1"},
},
},
expectProxy: true,
expectErr: false,
},
{
name: "proxy not found in account",
accountID: "account1",
domain: "unknown.example.com",
proxiesByAccount: map[string][]*reverseproxy.Service{
"account1": {{Domain: "app.example.com", AccountID: "account1"}},
},
expectProxy: false,
expectErr: true,
},
{
name: "empty proxy list for account",
accountID: "account1",
domain: "app.example.com",
proxiesByAccount: map[string][]*reverseproxy.Service{},
expectProxy: false,
expectErr: true,
},
{
name: "manager error",
accountID: "account1",
domain: "app.example.com",
proxiesByAccount: nil,
err: errors.New("database error"),
expectProxy: false,
expectErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
server := &ProxyServiceServer{
reverseProxyManager: &mockReverseProxyManager{
proxiesByAccount: tt.proxiesByAccount,
err: tt.err,
},
}
proxy, err := server.getAccountServiceByDomain(context.Background(), tt.accountID, tt.domain)
if tt.expectErr {
require.Error(t, err)
assert.Nil(t, proxy)
} else {
require.NoError(t, err)
require.NotNil(t, proxy)
assert.Equal(t, tt.domain, proxy.Domain)
}
})
}
}

View File

@@ -1,232 +0,0 @@
package grpc
import (
"crypto/rand"
"encoding/base64"
"strings"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/management/proto"
)
// registerFakeProxy adds a fake proxy connection to the server's internal maps
// and returns the channel where messages will be received.
func registerFakeProxy(s *ProxyServiceServer, proxyID, clusterAddr string) chan *proto.ProxyMapping {
ch := make(chan *proto.ProxyMapping, 10)
conn := &proxyConnection{
proxyID: proxyID,
address: clusterAddr,
sendChan: ch,
}
s.connectedProxies.Store(proxyID, conn)
proxySet, _ := s.clusterProxies.LoadOrStore(clusterAddr, &sync.Map{})
proxySet.(*sync.Map).Store(proxyID, struct{}{})
return ch
}
func drainChannel(ch chan *proto.ProxyMapping) *proto.ProxyMapping {
select {
case msg := <-ch:
return msg
case <-time.After(time.Second):
return nil
}
}
func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) {
tokenStore := NewOneTimeTokenStore(time.Hour)
defer tokenStore.Close()
s := &ProxyServiceServer{
tokenStore: tokenStore,
updatesChan: make(chan *proto.ProxyMapping, 100),
}
const cluster = "proxy.example.com"
const numProxies = 3
channels := make([]chan *proto.ProxyMapping, numProxies)
for i := range numProxies {
id := "proxy-" + string(rune('a'+i))
channels[i] = registerFakeProxy(s, id, cluster)
}
update := &proto.ProxyMapping{
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
Id: "service-1",
AccountId: "account-1",
Domain: "test.example.com",
Path: []*proto.PathMapping{
{Path: "/", Target: "http://10.0.0.1:8080/"},
},
}
s.SendServiceUpdateToCluster(update, cluster)
tokens := make([]string, numProxies)
for i, ch := range channels {
msg := drainChannel(ch)
require.NotNil(t, msg, "proxy %d should receive a message", i)
assert.Equal(t, update.Domain, msg.Domain)
assert.Equal(t, update.Id, msg.Id)
assert.NotEmpty(t, msg.AuthToken, "proxy %d should have a non-empty token", i)
tokens[i] = msg.AuthToken
}
// All tokens must be unique
tokenSet := make(map[string]struct{})
for i, tok := range tokens {
_, exists := tokenSet[tok]
assert.False(t, exists, "proxy %d got duplicate token", i)
tokenSet[tok] = struct{}{}
}
// Each token must be independently consumable
for i, tok := range tokens {
err := tokenStore.ValidateAndConsume(tok, "account-1", "service-1")
assert.NoError(t, err, "proxy %d token should validate successfully", i)
}
}
func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) {
tokenStore := NewOneTimeTokenStore(time.Hour)
defer tokenStore.Close()
s := &ProxyServiceServer{
tokenStore: tokenStore,
updatesChan: make(chan *proto.ProxyMapping, 100),
}
const cluster = "proxy.example.com"
ch1 := registerFakeProxy(s, "proxy-a", cluster)
ch2 := registerFakeProxy(s, "proxy-b", cluster)
update := &proto.ProxyMapping{
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED,
Id: "service-1",
AccountId: "account-1",
Domain: "test.example.com",
}
s.SendServiceUpdateToCluster(update, cluster)
msg1 := drainChannel(ch1)
msg2 := drainChannel(ch2)
require.NotNil(t, msg1)
require.NotNil(t, msg2)
// Delete operations should not generate tokens
assert.Empty(t, msg1.AuthToken)
assert.Empty(t, msg2.AuthToken)
// No tokens should have been created
assert.Equal(t, 0, tokenStore.GetTokenCount())
}
func TestSendServiceUpdate_UniqueTokensPerProxy(t *testing.T) {
tokenStore := NewOneTimeTokenStore(time.Hour)
defer tokenStore.Close()
s := &ProxyServiceServer{
tokenStore: tokenStore,
updatesChan: make(chan *proto.ProxyMapping, 100),
}
// Register proxies in different clusters (SendServiceUpdate broadcasts to all)
ch1 := registerFakeProxy(s, "proxy-a", "cluster-a")
ch2 := registerFakeProxy(s, "proxy-b", "cluster-b")
update := &proto.ProxyMapping{
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
Id: "service-1",
AccountId: "account-1",
Domain: "test.example.com",
}
s.SendServiceUpdate(update)
msg1 := drainChannel(ch1)
msg2 := drainChannel(ch2)
require.NotNil(t, msg1)
require.NotNil(t, msg2)
assert.NotEmpty(t, msg1.AuthToken)
assert.NotEmpty(t, msg2.AuthToken)
assert.NotEqual(t, msg1.AuthToken, msg2.AuthToken, "tokens must be unique per proxy")
// Both tokens should validate
assert.NoError(t, tokenStore.ValidateAndConsume(msg1.AuthToken, "account-1", "service-1"))
assert.NoError(t, tokenStore.ValidateAndConsume(msg2.AuthToken, "account-1", "service-1"))
}
// generateState creates a state using the same format as GetOIDCURL.
func generateState(s *ProxyServiceServer, redirectURL string) string {
nonce := make([]byte, 16)
_, _ = rand.Read(nonce)
nonceB64 := base64.URLEncoding.EncodeToString(nonce)
payload := redirectURL + "|" + nonceB64
hmacSum := s.generateHMAC(payload)
return base64.URLEncoding.EncodeToString([]byte(redirectURL)) + "|" + nonceB64 + "|" + hmacSum
}
func TestOAuthState_NeverTheSame(t *testing.T) {
s := &ProxyServiceServer{
oidcConfig: ProxyOIDCConfig{
HMACKey: []byte("test-hmac-key"),
},
}
redirectURL := "https://app.example.com/callback"
// Generate 100 states for the same redirect URL
states := make(map[string]bool)
for i := 0; i < 100; i++ {
state := generateState(s, redirectURL)
// State must have 3 parts: base64(url)|nonce|hmac
parts := strings.Split(state, "|")
require.Equal(t, 3, len(parts), "state must have 3 parts")
// State must be unique
require.False(t, states[state], "state %d is a duplicate", i)
states[state] = true
}
}
func TestValidateState_RejectsOldTwoPartFormat(t *testing.T) {
s := &ProxyServiceServer{
oidcConfig: ProxyOIDCConfig{
HMACKey: []byte("test-hmac-key"),
},
}
// Old format had only 2 parts: base64(url)|hmac
s.pkceVerifiers.Store("base64url|hmac", pkceEntry{verifier: "test", createdAt: time.Now()})
_, _, err := s.ValidateState("base64url|hmac")
require.Error(t, err)
assert.Contains(t, err.Error(), "invalid state format")
}
func TestValidateState_RejectsInvalidHMAC(t *testing.T) {
s := &ProxyServiceServer{
oidcConfig: ProxyOIDCConfig{
HMACKey: []byte("test-hmac-key"),
},
}
// Store with tampered HMAC
s.pkceVerifiers.Store("dGVzdA==|nonce|wrong-hmac", pkceEntry{verifier: "test", createdAt: time.Now()})
_, _, err := s.ValidateState("dGVzdA==|nonce|wrong-hmac")
require.Error(t, err)
assert.Contains(t, err.Error(), "invalid state signature")
}

View File

@@ -77,9 +77,8 @@ type Server struct {
oAuthConfigProvider idp.OAuthConfigProvider
syncSem atomic.Int32
syncLimEnabled bool
syncLim int32
syncSem atomic.Int32
syncLim int32
}
// NewServer creates a new Management server
@@ -109,7 +108,6 @@ func NewServer(
blockPeersWithSameConfig := strings.ToLower(os.Getenv(envBlockPeers)) == "true"
syncLim := int32(defaultSyncLim)
syncLimEnabled := true
if syncLimStr := os.Getenv(envConcurrentSyncs); syncLimStr != "" {
syncLimParsed, err := strconv.Atoi(syncLimStr)
if err != nil {
@@ -117,9 +115,6 @@ func NewServer(
} else {
//nolint:gosec
syncLim = int32(syncLimParsed)
if syncLim < 0 {
syncLimEnabled = false
}
}
}
@@ -139,8 +134,7 @@ func NewServer(
loginFilter: newLoginFilter(),
syncLim: syncLim,
syncLimEnabled: syncLimEnabled,
syncLim: syncLim,
}, nil
}
@@ -218,7 +212,7 @@ func (s *Server) Job(srv proto.ManagementService_JobServer) error {
// Sync validates the existence of a connecting peer, sends an initial state (all available for the connecting peers) and
// notifies the connected peer of any updates (e.g. new peers under the same account)
func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error {
if s.syncLimEnabled && s.syncSem.Load() >= s.syncLim {
if s.syncSem.Load() >= s.syncLim {
return status.Errorf(codes.ResourceExhausted, "too many concurrent sync requests, please try again later")
}
s.syncSem.Add(1)
@@ -300,7 +294,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
metahash := metaHash(peerMeta, realIP.String())
s.loginFilter.addLogin(peerKey.String(), metahash)
peer, netMap, postureChecks, dnsFwdPort, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP, reqStart)
peer, netMap, postureChecks, dnsFwdPort, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP)
if err != nil {
log.WithContext(ctx).Debugf("error while syncing peer %s: %v", peerKey.String(), err)
s.syncSem.Add(-1)
@@ -311,7 +305,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
if err != nil {
log.WithContext(ctx).Debugf("error while sending initial sync for %s: %v", peerKey.String(), err)
s.syncSem.Add(-1)
s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, reqStart)
s.cancelPeerRoutines(ctx, accountID, peer)
return err
}
@@ -319,7 +313,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
if err != nil {
log.WithContext(ctx).Debugf("error while notify peer connected for %s: %v", peerKey.String(), err)
s.syncSem.Add(-1)
s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, reqStart)
s.cancelPeerRoutines(ctx, accountID, peer)
return err
}
@@ -336,7 +330,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
s.syncSem.Add(-1)
return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv, reqStart)
return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv)
}
func (s *Server) handleHandshake(ctx context.Context, srv proto.ManagementService_JobServer) (wgtypes.Key, error) {
@@ -404,20 +398,11 @@ func (s *Server) sendJobsLoop(ctx context.Context, accountID string, peerKey wgt
}
// handleUpdates sends updates to the connected peer until the updates channel is closed.
// It implements a backpressure mechanism that sends the first update immediately,
// then debounces subsequent rapid updates, ensuring only the latest update is sent
// after a quiet period.
func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *network_map.UpdateMessage, srv proto.ManagementService_SyncServer, streamStartTime time.Time) error {
func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *network_map.UpdateMessage, srv proto.ManagementService_SyncServer) error {
log.WithContext(ctx).Tracef("starting to handle updates for peer %s", peerKey.String())
// Create a debouncer for this peer connection
debouncer := NewUpdateDebouncer(1000 * time.Millisecond)
defer debouncer.Stop()
for {
select {
// condition when there are some updates
// todo set the updates channel size to 1
case update, open := <-updates:
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().UpdateChannelQueueLength(len(updates) + 1)
@@ -425,38 +410,20 @@ func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wg
if !open {
log.WithContext(ctx).Debugf("updates channel for peer %s was closed", peerKey.String())
s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime)
s.cancelPeerRoutines(ctx, accountID, peer)
return nil
}
log.WithContext(ctx).Debugf("received an update for peer %s", peerKey.String())
if debouncer.ProcessUpdate(update) {
// Send immediately (first update or after quiet period)
if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv, streamStartTime); err != nil {
log.WithContext(ctx).Debugf("error while sending an update to peer %s: %v", peerKey.String(), err)
return err
}
}
// Timer expired - quiet period reached, send pending updates if any
case <-debouncer.TimerChannel():
pendingUpdates := debouncer.GetPendingUpdates()
if len(pendingUpdates) == 0 {
continue
}
log.WithContext(ctx).Debugf("sending %d debounced update(s) for peer %s", len(pendingUpdates), peerKey.String())
for _, pendingUpdate := range pendingUpdates {
if err := s.sendUpdate(ctx, accountID, peerKey, peer, pendingUpdate, srv, streamStartTime); err != nil {
log.WithContext(ctx).Debugf("error while sending an update to peer %s: %v", peerKey.String(), err)
return err
}
if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv); err != nil {
log.WithContext(ctx).Debugf("error while sending an update to peer %s: %v", peerKey.String(), err)
return err
}
// condition when client <-> server connection has been terminated
case <-srv.Context().Done():
// happens when connection drops, e.g. client disconnects
log.WithContext(ctx).Debugf("stream of peer %s has been closed", peerKey.String())
s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime)
s.cancelPeerRoutines(ctx, accountID, peer)
return srv.Context().Err()
}
}
@@ -464,16 +431,16 @@ func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wg
// sendUpdate encrypts the update message using the peer key and the server's wireguard key,
// then sends the encrypted message to the connected peer via the sync server.
func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, update *network_map.UpdateMessage, srv proto.ManagementService_SyncServer, streamStartTime time.Time) error {
func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, update *network_map.UpdateMessage, srv proto.ManagementService_SyncServer) error {
key, err := s.secretsManager.GetWGKey()
if err != nil {
s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime)
s.cancelPeerRoutines(ctx, accountID, peer)
return status.Errorf(codes.Internal, "failed processing update message")
}
encryptedResp, err := encryption.EncryptMessage(peerKey, key, update.Update)
if err != nil {
s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime)
s.cancelPeerRoutines(ctx, accountID, peer)
return status.Errorf(codes.Internal, "failed processing update message")
}
err = srv.Send(&proto.EncryptedMessage{
@@ -481,7 +448,7 @@ func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtyp
Body: encryptedResp,
})
if err != nil {
s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime)
s.cancelPeerRoutines(ctx, accountID, peer)
return status.Errorf(codes.Internal, "failed sending update message")
}
log.WithContext(ctx).Debugf("sent an update to peer %s", peerKey.String())
@@ -513,15 +480,11 @@ func (s *Server) sendJob(ctx context.Context, peerKey wgtypes.Key, job *job.Even
return nil
}
func (s *Server) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer, streamStartTime time.Time) {
func (s *Server) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer) {
unlock := s.acquirePeerLockByUID(ctx, peer.Key)
defer unlock()
s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, streamStartTime)
}
func (s *Server) cancelPeerRoutinesWithoutLock(ctx context.Context, accountID string, peer *nbpeer.Peer, streamStartTime time.Time) {
err := s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key, streamStartTime)
err := s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key)
if err != nil {
log.WithContext(ctx).Errorf("failed to disconnect peer %s properly: %v", peer.Key, err)
}

Some files were not shown because too many files have changed in this diff Show More