Compare commits

...

19 Commits

Author SHA1 Message Date
Zoltán Papp
2ea7fb7b21 add support for disabling relay via environment variable 2026-02-18 18:59:56 +01:00
Zoltán Papp
93f530637d log duration of updateChecksIfNew execution 2026-02-18 18:43:48 +01:00
Zoltán Papp
3a84475d14 add test for adding multiple allowed IPs to userspace interface 2026-02-18 18:39:27 +01:00
Zoltán Papp
3d7368e51f set readyChan in Engine to signal when first sync is received 2026-02-18 15:58:05 +01:00
Zoltan Papp
318cf59d66 [relay] reduce QUIC initial packet size to 1280 (IPv6 min MTU) (#5374)
* [relay] reduce QUIC initial packet size to 1280 (IPv6 min MTU)

* adjust QUIC initial packet size to 1232 based on RFC 9000 §14
2026-02-18 10:58:14 +01:00
Pascal Fischer
e9b2a6e808 [managment] add flag to disable the old legacy grpc endpoint (#5372) 2026-02-17 19:53:14 +01:00
Zoltan Papp
2dbdb5c1a7 [client] Refactor WG endpoint setup with role-based proxy activation (#5277)
* Refactor WG endpoint setup with role-based proxy activation

For relay connections, the controller (initiator) now activates the
wgProxy before configuring the WG endpoint, while the non-controller
(responder) configures the endpoint first with a delayed update, then
activates the proxy after. This prevents the responder from sending
traffic through the proxy before WireGuard is ready to receive it,
avoiding handshake congestion when both sides try to initiate
simultaneously.

For ICE connections, pass hasRelayBackup as the setEndpointNow flag
so the responder sets the endpoint immediately when a relay fallback
exists (avoiding the delayed update path since relay is already
available as backup).

On ICE disconnect with relay fallback, remove the duplicate
wgProxyRelay.Work() calls — the relay proxy is already active from
initial setup, so re-activating it is unnecessary.

In EndpointUpdater, split ConfigureWGEndpoint into explicit
configureAsInitiator and configureAsResponder paths, and add the
setEndpointNow parameter to let the caller control whether the
responder applies the endpoint immediately or defers it. Add unused
SwitchWGEndpoint and RemoveEndpointAddress methods. Remove the
wgConfigWorkaround sleep from the relay setup path.

* Fix redundant wgProxyRelay.Work() call during relay fallback setup

* Simplify WireGuard endpoint configuration by removing unused parameters and redundant logic
2026-02-17 19:28:26 +01:00
Pascal Fischer
2cdab6d7b7 [proxy] remove unused oidc config flags (#5369) 2026-02-17 18:04:30 +01:00
Diego Noguês
e49c0e8862 [infrastructure] Proxy infra changes (#5365)
* chore: remove docker extra_hosts settings

* chore: remove unnecessary envc from proxy.env
2026-02-17 17:37:44 +01:00
Misha Bragin
e7c84d0ead Start Management if external IdP is down (#5367)
Set ContinueOnConnectorFailure: true in the embedded Dex config so that the Management server starts successfully even when an external IdP connector is unreachable at boot time.
2026-02-17 16:08:41 +01:00
Zoltan Papp
1c934cca64 Ignore false lint alert (#5370) 2026-02-17 16:07:35 +01:00
Vlad
4aff4a6424 [management] fix utc difference on last seen status for a peer (#5348) 2026-02-17 13:29:32 +01:00
Zoltan Papp
1bd7190954 [proxy] Support WebSocket (#5312)
* Fix WebSocket support by implementing Hijacker interface

Add responsewriter.PassthroughWriter to preserve optional HTTP interfaces
(Hijacker, Flusher, Pusher) when wrapping http.ResponseWriter in middleware.

Without this delegation:
 - WebSocket connections fail (can't hijack the connection)
 - Streaming breaks (can't flush buffers)
 - HTTP/2 push doesn't work

* Add HijackTracker to manage hijacked connections during graceful shutdown

* Refactor HijackTracker to use middleware for tracking hijacked connections

* Refactor server handler chain setup for improved readability and maintainability
2026-02-17 12:53:34 +01:00
Viktor Liu
0146e39714 Add listener side proxy protocol support and enable it in traefik (#5332)
Co-authored-by: mlsmaycon <mlsmaycon@gmail.com>
2026-02-16 23:40:10 +01:00
Zoltan Papp
baed6e46ec Reset WireGuard endpoint on ICE session change during relay fallback (#5283)
When an ICE connection disconnects and falls back to relay, reset the
WireGuard endpoint and handshake watcher if the remote peer's ICE session
has changed. This ensures the controller re-establishes a fresh WireGuard
handshake rather than waiting on a stale endpoint from the previous session.
2026-02-16 20:59:29 +01:00
Maycon Santos
0d1ffba75f [misc] add additional cname example (#5341) 2026-02-16 13:30:58 +01:00
Diego Romar
1024d45698 [mobile] Export lazy connection environment variables for mobile clients (#5310)
* [client] Export lazy connection env vars

Both for Android and iOS

* [client] Separate comments
2026-02-16 09:04:45 -03:00
Zoltan Papp
e5d4947d60 [client] Optimize Windows DNS performance with domain batching and batch mode (#5264)
* Optimize Windows DNS performance with domain batching and batch mode

Implement two-layer optimization to reduce Windows NRPT registry operations:

1. Domain Batching (host_windows.go):
  - Batch domains per NRPT
  - Reduces NRPT rules by ~97% (e.g., 184 domains: 184 rules → 4 rules)
  - Modified addDNSMatchPolicy() to create batched NRPT entries
  - Added comprehensive tests in host_windows_test.go

2. Batch Mode (server.go):
  - Added BeginBatch/EndBatch methods to defer DNS updates
  - Modified RegisterHandler/DeregisterHandler to skip applyHostConfig in batch mode
  - Protected all applyHostConfig() calls with batch mode checks
  - Updated route manager to wrap route operations with batch calls

* Update tests

* Fix log line

* Fix NRPT rule index to ensure cleanup covers partially created rules

* Ensure NRPT entry count updates even on errors to improve cleanup reliability

* Switch DNS batch mode logging from Info to Debug level

* Fix batch mode to not suppress critical DNS config updates

Batch mode should only defer applyHostConfig() for RegisterHandler/
DeregisterHandler operations. Management updates and upstream nameserver
failures (deactivate/reactivate callbacks) need immediate DNS config
updates regardless of batch mode to ensure timely failover.

Without this fix, if a nameserver goes down during a route update,
the system DNS config won't be updated until EndBatch(), potentially
delaying failover by several seconds.

Or if you prefer a shorter version:

Fix batch mode to allow immediate DNS updates for critical paths

Batch mode now only affects RegisterHandler/DeregisterHandler.
Management updates and nameserver failures always trigger immediate
DNS config updates to ensure timely failover.

* Add DNS batch cancellation to rollback partial changes on errors

Introduces CancelBatch() method to the DNS server interface to handle error
scenarios during batch operations. When route updates fail partway through, the DNS
server can now discard accumulated changes instead of applying partial state. This
prevents leaving the DNS configuration in an inconsistent state when route manager
operations encounter errors.

The changes add error-aware batch handling to prevent partial DNS configuration
updates when route operations fail, which improves system reliability.
2026-02-15 22:10:26 +01:00
Maycon Santos
cb9b39b950 [misc] add extra proxy domain instructions (#5328)
improve proxy domain instructions
expose wireguard port
2026-02-15 12:51:46 +01:00
40 changed files with 1165 additions and 263 deletions

View File

@@ -1,10 +1,19 @@
package android
import "github.com/netbirdio/netbird/client/internal/peer"
import (
"github.com/netbirdio/netbird/client/internal/lazyconn"
"github.com/netbirdio/netbird/client/internal/peer"
)
var (
// EnvKeyNBForceRelay Exported for Android java client
// EnvKeyNBForceRelay Exported for Android java client to force relay connections
EnvKeyNBForceRelay = peer.EnvKeyNBForceRelay
// EnvKeyNBLazyConn Exported for Android java client to configure lazy connection
EnvKeyNBLazyConn = lazyconn.EnvEnableLazyConn
// EnvKeyNBInactivityThreshold Exported for Android java client to configure connection inactivity threshold
EnvKeyNBInactivityThreshold = lazyconn.EnvInactivityThreshold
)
// EnvList wraps a Go map for export to Java

View File

@@ -358,9 +358,9 @@ func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 {
// Fast path for IPv4 addresses (4 bytes) - most common case
if len(oldBytes) == 4 && len(newBytes) == 4 {
sum += uint32(^binary.BigEndian.Uint16(oldBytes[0:2]))
sum += uint32(^binary.BigEndian.Uint16(oldBytes[2:4]))
sum += uint32(^binary.BigEndian.Uint16(oldBytes[2:4])) //nolint:gosec // length checked above
sum += uint32(binary.BigEndian.Uint16(newBytes[0:2]))
sum += uint32(binary.BigEndian.Uint16(newBytes[2:4]))
sum += uint32(binary.BigEndian.Uint16(newBytes[2:4])) //nolint:gosec // length checked above
} else {
// Fallback for other lengths
for i := 0; i < len(oldBytes)-1; i += 2 {

View File

@@ -589,6 +589,101 @@ func Test_ConnectPeers(t *testing.T) {
}
func Test_UserSpaceAddAllowedIPs(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+5)
wgIP := "10.99.99.21/30"
wgPort := 33105
newNet, err := stdnet.NewNet(context.Background(), nil)
if err != nil {
t.Fatal(err)
}
opts := WGIFaceOpts{
IFaceName: ifaceName,
Address: wgIP,
WGPort: wgPort,
WGPrivKey: key,
MTU: DefaultMTU,
TransportNet: newNet,
}
iface, err := NewWGIFace(opts)
if err != nil {
t.Fatal(err)
}
err = iface.Create()
if err != nil {
t.Fatal(err)
}
defer func() {
if err := iface.Close(); err != nil {
t.Error(err)
}
}()
_, err = iface.Up()
if err != nil {
t.Fatal(err)
}
keepAlive := 15 * time.Second
initialAllowedIP := netip.MustParsePrefix("10.99.99.22/32")
endpoint, err := net.ResolveUDPAddr("udp", "127.0.0.1:9905")
if err != nil {
t.Fatal(err)
}
// Add peer with initial endpoint and first allowed IP
err = iface.UpdatePeer(peerPubKey, []netip.Prefix{initialAllowedIP}, keepAlive, endpoint, nil)
if err != nil {
t.Fatal(err)
}
// Phase 1: generate 500 allowed IPs into a list
const extraIPs = 500
addedPrefixes := make([]netip.Prefix, 0, extraIPs)
for i := 0; i < extraIPs; i++ {
// Use 172.16.x.y/32 range: i encoded as two octets
prefix := netip.MustParsePrefix(fmt.Sprintf("172.16.%d.%d/32", i/256, i%256))
addedPrefixes = append(addedPrefixes, prefix)
}
// Phase 2: iterate over the list and add each allowed IP to the peer
phase2Start := time.Now()
for _, prefix := range addedPrefixes {
if addErr := iface.AddAllowedIP(peerPubKey, prefix); addErr != nil {
t.Fatalf("failed to add allowed IP %s: %v", prefix, addErr)
}
}
t.Logf("Phase 2 (add %d IPs to peer): %s", extraIPs, time.Since(phase2Start))
// Verify the peer has all 101 allowed IPs (1 initial + 100 added)
peer, err := getPeer(ifaceName, peerPubKey)
if err != nil {
t.Fatal(err)
}
if peer.Endpoint.String() != endpoint.String() {
t.Fatalf("expected endpoint %s, got %s", endpoint, peer.Endpoint)
}
allExpected := append([]netip.Prefix{initialAllowedIP}, addedPrefixes...)
if len(peer.AllowedIPs) != len(allExpected) {
t.Fatalf("expected %d allowed IPs, got %d", len(allExpected), len(peer.AllowedIPs))
}
allowedIPSet := make(map[string]struct{}, len(peer.AllowedIPs))
for _, aip := range peer.AllowedIPs {
allowedIPSet[aip.String()] = struct{}{}
}
for _, expected := range allExpected {
if _, found := allowedIPSet[expected.String()]; !found {
t.Errorf("expected allowed IP %s not found in peer config", expected)
}
}
}
func getPeer(ifaceName, peerPubKey string) (wgtypes.Peer, error) {
wg, err := wgctrl.New()
if err != nil {

View File

@@ -290,6 +290,10 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
return wrapErr(err)
}
if relayClient.IsDisableRelay() {
relayURLs = []string{}
}
relayManager := relayClient.NewManager(engineCtx, relayURLs, myPrivateKey.PublicKey().String(), engineConfig.MTU)
c.statusRecorder.SetRelayMgr(relayManager)
if len(relayURLs) > 0 {
@@ -310,6 +314,8 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
c.engineMutex.Lock()
engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks, stateManager)
engine.SetSyncResponsePersistence(c.persistSyncResponse)
engine.SetReadyChan(runningChan)
runningChan = nil
c.engine = engine
c.engineMutex.Unlock()
@@ -330,11 +336,6 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
state.Set(StatusConnected)
if runningChan != nil {
close(runningChan)
runningChan = nil
}
<-engineCtx.Done()
c.engineMutex.Lock()

View File

@@ -42,6 +42,8 @@ const (
dnsPolicyConfigConfigOptionsKey = "ConfigOptions"
dnsPolicyConfigConfigOptionsValue = 0x8
nrptMaxDomainsPerRule = 50
interfaceConfigPath = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces`
interfaceConfigNameServerKey = "NameServer"
interfaceConfigSearchListKey = "SearchList"
@@ -198,10 +200,11 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
if len(matchDomains) != 0 {
count, err := r.addDNSMatchPolicy(matchDomains, config.ServerIP)
// Update count even on error to ensure cleanup covers partially created rules
r.nrptEntryCount = count
if err != nil {
return fmt.Errorf("add dns match policy: %w", err)
}
r.nrptEntryCount = count
} else {
r.nrptEntryCount = 0
}
@@ -239,23 +242,33 @@ func (r *registryConfigurator) addDNSSetupForAll(ip netip.Addr) error {
func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr) (int, error) {
// if the gpo key is present, we need to put our DNS settings there, otherwise our config might be ignored
// see https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-gpnrpt/8cc31cb9-20cb-4140-9e85-3e08703b4745
for i, domain := range domains {
localPath := fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)
gpoPath := fmt.Sprintf("%s-%d", gpoDnsPolicyConfigMatchPath, i)
singleDomain := []string{domain}
// We need to batch domains into chunks and create one NRPT rule per batch.
ruleIndex := 0
for i := 0; i < len(domains); i += nrptMaxDomainsPerRule {
end := i + nrptMaxDomainsPerRule
if end > len(domains) {
end = len(domains)
}
batchDomains := domains[i:end]
if err := r.configureDNSPolicy(localPath, singleDomain, ip); err != nil {
return i, fmt.Errorf("configure DNS Local policy for domain %s: %w", domain, err)
localPath := fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, ruleIndex)
gpoPath := fmt.Sprintf("%s-%d", gpoDnsPolicyConfigMatchPath, ruleIndex)
if err := r.configureDNSPolicy(localPath, batchDomains, ip); err != nil {
return ruleIndex, fmt.Errorf("configure DNS Local policy for rule %d: %w", ruleIndex, err)
}
// Increment immediately so the caller's cleanup path knows about this rule
ruleIndex++
if r.gpo {
if err := r.configureDNSPolicy(gpoPath, singleDomain, ip); err != nil {
return i, fmt.Errorf("configure gpo DNS policy: %w", err)
if err := r.configureDNSPolicy(gpoPath, batchDomains, ip); err != nil {
return ruleIndex, fmt.Errorf("configure gpo DNS policy for rule %d: %w", ruleIndex-1, err)
}
}
log.Debugf("added NRPT entry for domain: %s", domain)
log.Debugf("added NRPT rule %d with %d domains", ruleIndex-1, len(batchDomains))
}
if r.gpo {
@@ -264,8 +277,8 @@ func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr
}
}
log.Infof("added %d separate NRPT entries. Domain list: %s", len(domains), domains)
return len(domains), nil
log.Infof("added %d NRPT rules for %d domains. Domain list: %v", ruleIndex, len(domains), domains)
return ruleIndex, nil
}
func (r *registryConfigurator) configureDNSPolicy(policyPath string, domains []string, ip netip.Addr) error {

View File

@@ -12,6 +12,7 @@ import (
// TestNRPTEntriesCleanupOnConfigChange tests that old NRPT entries are properly cleaned up
// when the number of match domains decreases between configuration changes.
// With batching enabled (50 domains per rule), we need enough domains to create multiple rules.
func TestNRPTEntriesCleanupOnConfigChange(t *testing.T) {
if testing.Short() {
t.Skip("skipping registry integration test in short mode")
@@ -37,51 +38,60 @@ func TestNRPTEntriesCleanupOnConfigChange(t *testing.T) {
gpo: false,
}
config5 := HostDNSConfig{
ServerIP: testIP,
Domains: []DomainConfig{
{Domain: "domain1.com", MatchOnly: true},
{Domain: "domain2.com", MatchOnly: true},
{Domain: "domain3.com", MatchOnly: true},
{Domain: "domain4.com", MatchOnly: true},
{Domain: "domain5.com", MatchOnly: true},
},
// Create 125 domains which will result in 3 NRPT rules (50+50+25)
domains125 := make([]DomainConfig, 125)
for i := 0; i < 125; i++ {
domains125[i] = DomainConfig{
Domain: fmt.Sprintf("domain%d.com", i+1),
MatchOnly: true,
}
}
err = cfg.applyDNSConfig(config5, nil)
config125 := HostDNSConfig{
ServerIP: testIP,
Domains: domains125,
}
err = cfg.applyDNSConfig(config125, nil)
require.NoError(t, err)
// Verify all 5 entries exist
for i := 0; i < 5; i++ {
// Verify 3 NRPT rules exist
assert.Equal(t, 3, cfg.nrptEntryCount, "Should create 3 NRPT rules for 125 domains")
for i := 0; i < 3; i++ {
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
require.NoError(t, err)
assert.True(t, exists, "Entry %d should exist after first config", i)
assert.True(t, exists, "NRPT rule %d should exist after first config", i)
}
config2 := HostDNSConfig{
// Reduce to 75 domains which will result in 2 NRPT rules (50+25)
domains75 := make([]DomainConfig, 75)
for i := 0; i < 75; i++ {
domains75[i] = DomainConfig{
Domain: fmt.Sprintf("domain%d.com", i+1),
MatchOnly: true,
}
}
config75 := HostDNSConfig{
ServerIP: testIP,
Domains: []DomainConfig{
{Domain: "domain1.com", MatchOnly: true},
{Domain: "domain2.com", MatchOnly: true},
},
Domains: domains75,
}
err = cfg.applyDNSConfig(config2, nil)
err = cfg.applyDNSConfig(config75, nil)
require.NoError(t, err)
// Verify first 2 entries exist
// Verify first 2 NRPT rules exist
assert.Equal(t, 2, cfg.nrptEntryCount, "Should create 2 NRPT rules for 75 domains")
for i := 0; i < 2; i++ {
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
require.NoError(t, err)
assert.True(t, exists, "Entry %d should exist after second config", i)
assert.True(t, exists, "NRPT rule %d should exist after second config", i)
}
// Verify entries 2-4 are cleaned up
for i := 2; i < 5; i++ {
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
require.NoError(t, err)
assert.False(t, exists, "Entry %d should NOT exist after reducing to 2 domains", i)
}
// Verify rule 2 is cleaned up
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, 2))
require.NoError(t, err)
assert.False(t, exists, "NRPT rule 2 should NOT exist after reducing to 75 domains")
}
func registryKeyExists(path string) (bool, error) {
@@ -97,6 +107,106 @@ func registryKeyExists(path string) (bool, error) {
}
func cleanupRegistryKeys(*testing.T) {
cfg := &registryConfigurator{nrptEntryCount: 10}
// Clean up more entries to account for batching tests with many domains
cfg := &registryConfigurator{nrptEntryCount: 20}
_ = cfg.removeDNSMatchPolicies()
}
// TestNRPTDomainBatching verifies that domains are correctly batched into NRPT rules.
func TestNRPTDomainBatching(t *testing.T) {
if testing.Short() {
t.Skip("skipping registry integration test in short mode")
}
defer cleanupRegistryKeys(t)
cleanupRegistryKeys(t)
testIP := netip.MustParseAddr("100.64.0.1")
// Create a test interface registry key so updateSearchDomains doesn't fail
testGUID := "{12345678-1234-1234-1234-123456789ABC}"
interfacePath := `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\` + testGUID
testKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, interfacePath, registry.SET_VALUE)
require.NoError(t, err, "Should create test interface registry key")
testKey.Close()
defer func() {
_ = registry.DeleteKey(registry.LOCAL_MACHINE, interfacePath)
}()
cfg := &registryConfigurator{
guid: testGUID,
gpo: false,
}
testCases := []struct {
name string
domainCount int
expectedRuleCount int
}{
{
name: "Less than 50 domains (single rule)",
domainCount: 30,
expectedRuleCount: 1,
},
{
name: "Exactly 50 domains (single rule)",
domainCount: 50,
expectedRuleCount: 1,
},
{
name: "51 domains (two rules)",
domainCount: 51,
expectedRuleCount: 2,
},
{
name: "100 domains (two rules)",
domainCount: 100,
expectedRuleCount: 2,
},
{
name: "125 domains (three rules: 50+50+25)",
domainCount: 125,
expectedRuleCount: 3,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Clean up before each subtest
cleanupRegistryKeys(t)
// Generate domains
domains := make([]DomainConfig, tc.domainCount)
for i := 0; i < tc.domainCount; i++ {
domains[i] = DomainConfig{
Domain: fmt.Sprintf("domain%d.com", i+1),
MatchOnly: true,
}
}
config := HostDNSConfig{
ServerIP: testIP,
Domains: domains,
}
err := cfg.applyDNSConfig(config, nil)
require.NoError(t, err)
// Verify that exactly expectedRuleCount rules were created
assert.Equal(t, tc.expectedRuleCount, cfg.nrptEntryCount,
"Should create %d NRPT rules for %d domains", tc.expectedRuleCount, tc.domainCount)
// Verify all expected rules exist
for i := 0; i < tc.expectedRuleCount; i++ {
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
require.NoError(t, err)
assert.True(t, exists, "NRPT rule %d should exist", i)
}
// Verify no extra rules were created
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, tc.expectedRuleCount))
require.NoError(t, err)
assert.False(t, exists, "No NRPT rule should exist at index %d", tc.expectedRuleCount)
})
}
}

View File

@@ -84,3 +84,18 @@ func (m *MockServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error {
func (m *MockServer) PopulateManagementDomain(mgmtURL *url.URL) error {
return nil
}
// BeginBatch mock implementation of BeginBatch from Server interface
func (m *MockServer) BeginBatch() {
// Mock implementation - no-op
}
// EndBatch mock implementation of EndBatch from Server interface
func (m *MockServer) EndBatch() {
// Mock implementation - no-op
}
// CancelBatch mock implementation of CancelBatch from Server interface
func (m *MockServer) CancelBatch() {
// Mock implementation - no-op
}

View File

@@ -45,6 +45,9 @@ type IosDnsManager interface {
type Server interface {
RegisterHandler(domains domain.List, handler dns.Handler, priority int)
DeregisterHandler(domains domain.List, priority int)
BeginBatch()
EndBatch()
CancelBatch()
Initialize() error
Stop()
DnsIP() netip.Addr
@@ -87,6 +90,7 @@ type DefaultServer struct {
currentConfigHash uint64
handlerChain *HandlerChain
extraDomains map[domain.Domain]int
batchMode bool
mgmtCacheResolver *mgmt.Resolver
@@ -234,7 +238,9 @@ func (s *DefaultServer) RegisterHandler(domains domain.List, handler dns.Handler
// convert to zone with simple ref counter
s.extraDomains[toZone(domain)]++
}
s.applyHostConfig()
if !s.batchMode {
s.applyHostConfig()
}
}
func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, priority int) {
@@ -263,9 +269,41 @@ func (s *DefaultServer) DeregisterHandler(domains domain.List, priority int) {
delete(s.extraDomains, zone)
}
}
if !s.batchMode {
s.applyHostConfig()
}
}
// BeginBatch starts batch mode for DNS handler registration/deregistration.
// In batch mode, applyHostConfig() is not called after each handler operation,
// allowing multiple handlers to be registered/deregistered efficiently.
// Must be followed by EndBatch() to apply the accumulated changes.
func (s *DefaultServer) BeginBatch() {
s.mux.Lock()
defer s.mux.Unlock()
log.Debugf("DNS batch mode enabled")
s.batchMode = true
}
// EndBatch ends batch mode and applies all accumulated DNS configuration changes.
func (s *DefaultServer) EndBatch() {
s.mux.Lock()
defer s.mux.Unlock()
log.Debugf("DNS batch mode disabled, applying accumulated changes")
s.batchMode = false
s.applyHostConfig()
}
// CancelBatch cancels batch mode without applying accumulated changes.
// This is useful when operations fail partway through and you want to
// discard partial state rather than applying it.
func (s *DefaultServer) CancelBatch() {
s.mux.Lock()
defer s.mux.Unlock()
log.Debugf("DNS batch mode cancelled, discarding accumulated changes")
s.batchMode = false
}
func (s *DefaultServer) deregisterHandler(domains []string, priority int) {
log.Debugf("deregistering handler with priority %d for %v", priority, domains)
@@ -523,6 +561,7 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
s.currentConfig.RouteAll = false
}
// Always apply host config for management updates, regardless of batch mode
s.applyHostConfig()
s.shutdownWg.Add(1)
@@ -887,6 +926,7 @@ func (s *DefaultServer) upstreamCallbacks(
}
}
// Always apply host config when nameserver goes down, regardless of batch mode
s.applyHostConfig()
go func() {
@@ -922,6 +962,7 @@ func (s *DefaultServer) upstreamCallbacks(
s.registerHandler([]string{nbdns.RootZone}, handler, priority)
}
// Always apply host config when nameserver reactivates, regardless of batch mode
s.applyHostConfig()
s.updateNSState(nsGroup, nil, true)

View File

@@ -18,7 +18,12 @@ func TestGetServerDns(t *testing.T) {
t.Errorf("invalid dns server instance: %s", err)
}
if srvB != srv {
mockSrvB, ok := srvB.(*MockServer)
if !ok {
t.Errorf("returned server is not a MockServer")
}
if mockSrvB != srv {
t.Errorf("mismatch dns instances")
}
}

View File

@@ -28,8 +28,8 @@ import (
"github.com/netbirdio/netbird/client/firewall"
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/device"
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/internal/acl"
"github.com/netbirdio/netbird/client/internal/debug"
@@ -217,6 +217,10 @@ type Engine struct {
// WireGuard interface monitor
wgIfaceMonitor *WGIfaceMonitor
// readyChan is closed when the first sync message is received from management
readyChan chan struct{}
readyChanOnce sync.Once
// shutdownWg tracks all long-running goroutines to ensure clean shutdown
shutdownWg sync.WaitGroup
@@ -275,6 +279,10 @@ func NewEngine(
return engine
}
func (e *Engine) SetReadyChan(ch chan struct{}) {
e.readyChan = ch
}
func (e *Engine) Stop() error {
if e == nil {
// this seems to be a very odd case but there was the possibility if the netbird down command comes before the engine is fully started
@@ -834,6 +842,13 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
defer func() {
log.Infof("sync finished in %s", time.Since(started))
}()
e.readyChanOnce.Do(func() {
if e.readyChan != nil {
close(e.readyChan)
}
})
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
@@ -880,9 +895,11 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
// todo update signal
}
uCheckTime := time.Now()
if err := e.updateChecksIfNew(update.Checks); err != nil {
return err
}
log.Infof("update check finished in %s", time.Since(uCheckTime))
nm := update.GetNetworkMap()
if nm == nil {
@@ -925,7 +942,9 @@ func (e *Engine) handleRelayUpdate(update *mgmProto.RelayConfig) error {
return fmt.Errorf("update relay token: %w", err)
}
e.relayManager.UpdateServerURLs(update.Urls)
if !relayClient.IsDisableRelay() {
e.relayManager.UpdateServerURLs(update.Urls)
}
// Just in case the agent started with an MGM server where the relay was disabled but was later enabled.
// We can ignore all errors because the guard will manage the reconnection retries.

View File

@@ -410,7 +410,7 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr)
}
func (conn *Conn) onICEStateDisconnected() {
func (conn *Conn) onICEStateDisconnected(sessionChanged bool) {
conn.mu.Lock()
defer conn.mu.Unlock()
@@ -430,14 +430,18 @@ func (conn *Conn) onICEStateDisconnected() {
if conn.isReadyToUpgrade() {
conn.Log.Infof("ICE disconnected, set Relay to active connection")
conn.dumpState.SwitchToRelay()
if sessionChanged {
conn.resetEndpoint()
}
// todo consider to move after the ConfigureWGEndpoint
conn.wgProxyRelay.Work()
presharedKey := conn.presharedKey(conn.rosenpassRemoteKey)
if err := conn.endpointUpdater.ConfigureWGEndpoint(conn.wgProxyRelay.EndpointAddr(), presharedKey); err != nil {
if err := conn.endpointUpdater.SwitchWGEndpoint(conn.wgProxyRelay.EndpointAddr(), presharedKey); err != nil {
conn.Log.Errorf("failed to switch to relay conn: %v", err)
}
conn.wgProxyRelay.Work()
conn.currentConnPriority = conntype.Relay
} else {
conn.Log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", conntype.None.String())
@@ -499,20 +503,22 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
return
}
wgProxy.Work()
presharedKey := conn.presharedKey(rci.rosenpassPubKey)
controller := isController(conn.config)
if controller {
wgProxy.Work()
}
conn.enableWgWatcherIfNeeded()
if err := conn.endpointUpdater.ConfigureWGEndpoint(wgProxy.EndpointAddr(), presharedKey); err != nil {
if err := conn.endpointUpdater.ConfigureWGEndpoint(wgProxy.EndpointAddr(), conn.presharedKey(rci.rosenpassPubKey)); err != nil {
if err := wgProxy.CloseConn(); err != nil {
conn.Log.Warnf("Failed to close relay connection: %v", err)
}
conn.Log.Errorf("Failed to update WireGuard peer configuration: %v", err)
return
}
wgConfigWorkaround()
if !controller {
wgProxy.Work()
}
conn.rosenpassRemoteKey = rci.rosenpassPubKey
conn.currentConnPriority = conntype.Relay
conn.statusRelay.SetConnected()
@@ -757,6 +763,17 @@ func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) {
return wgProxy, nil
}
func (conn *Conn) resetEndpoint() {
if !isController(conn.config) {
return
}
conn.Log.Infof("reset wg endpoint")
conn.wgWatcher.Reset()
if err := conn.endpointUpdater.RemoveEndpointAddress(); err != nil {
conn.Log.Warnf("failed to remove endpoint address before update: %v", err)
}
}
func (conn *Conn) isReadyToUpgrade() bool {
return conn.wgProxyRelay != nil && conn.currentConnPriority != conntype.Relay
}
@@ -862,9 +879,3 @@ func isController(config ConnConfig) bool {
func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool {
return remoteRosenpassPubKey != nil
}
// wgConfigWorkaround is a workaround for the issue with WireGuard configuration update
// When update a peer configuration in near to each other time, the second update can be ignored by WireGuard
func wgConfigWorkaround() {
time.Sleep(100 * time.Millisecond)
}

View File

@@ -34,28 +34,27 @@ func NewEndpointUpdater(log *logrus.Entry, wgConfig WgConfig, initiator bool) *E
}
}
// ConfigureWGEndpoint sets up the WireGuard endpoint configuration.
// The initiator immediately configures the endpoint, while the non-initiator
// waits for a fallback period before configuring to avoid handshake congestion.
func (e *EndpointUpdater) ConfigureWGEndpoint(addr *net.UDPAddr, presharedKey *wgtypes.Key) error {
e.mu.Lock()
defer e.mu.Unlock()
if e.initiator {
e.log.Debugf("configure up WireGuard as initiatr")
return e.updateWireGuardPeer(addr, presharedKey)
e.log.Debugf("configure up WireGuard as initiator")
return e.configureAsInitiator(addr, presharedKey)
}
e.log.Debugf("configure up WireGuard as responder")
return e.configureAsResponder(addr, presharedKey)
}
func (e *EndpointUpdater) SwitchWGEndpoint(addr *net.UDPAddr, presharedKey *wgtypes.Key) error {
e.mu.Lock()
defer e.mu.Unlock()
// prevent to run new update while cancel the previous update
e.waitForCloseTheDelayedUpdate()
var ctx context.Context
ctx, e.cancelFunc = context.WithCancel(context.Background())
e.updateWg.Add(1)
go e.scheduleDelayedUpdate(ctx, addr, presharedKey)
e.log.Debugf("configure up WireGuard and wait for handshake")
return e.updateWireGuardPeer(nil, presharedKey)
return e.updateWireGuardPeer(addr, presharedKey)
}
func (e *EndpointUpdater) RemoveWgPeer() error {
@@ -66,6 +65,38 @@ func (e *EndpointUpdater) RemoveWgPeer() error {
return e.wgConfig.WgInterface.RemovePeer(e.wgConfig.RemoteKey)
}
func (e *EndpointUpdater) RemoveEndpointAddress() error {
e.mu.Lock()
defer e.mu.Unlock()
e.waitForCloseTheDelayedUpdate()
return e.wgConfig.WgInterface.RemoveEndpointAddress(e.wgConfig.RemoteKey)
}
func (e *EndpointUpdater) configureAsInitiator(addr *net.UDPAddr, presharedKey *wgtypes.Key) error {
if err := e.updateWireGuardPeer(addr, presharedKey); err != nil {
return err
}
return nil
}
func (e *EndpointUpdater) configureAsResponder(addr *net.UDPAddr, presharedKey *wgtypes.Key) error {
// prevent to run new update while cancel the previous update
e.waitForCloseTheDelayedUpdate()
e.log.Debugf("configure up WireGuard and wait for handshake")
var ctx context.Context
ctx, e.cancelFunc = context.WithCancel(context.Background())
e.updateWg.Add(1)
go e.scheduleDelayedUpdate(ctx, addr, presharedKey)
if err := e.updateWireGuardPeer(nil, presharedKey); err != nil {
e.waitForCloseTheDelayedUpdate()
return err
}
return nil
}
func (e *EndpointUpdater) waitForCloseTheDelayedUpdate() {
if e.cancelFunc == nil {
return
@@ -101,3 +132,9 @@ func (e *EndpointUpdater) updateWireGuardPeer(endpoint *net.UDPAddr, presharedKe
presharedKey,
)
}
// wgConfigWorkaround is a workaround for the issue with WireGuard configuration update
// When update a peer configuration in near to each other time, the second update can be ignored by WireGuard
func wgConfigWorkaround() {
time.Sleep(100 * time.Millisecond)
}

View File

@@ -12,6 +12,7 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/stdnet"
relayClient "github.com/netbirdio/netbird/shared/relay/client"
)
const (
@@ -125,6 +126,10 @@ func GenerateICECredentials() (string, string, error) {
}
func CandidateTypes() []ice.CandidateType {
if relayClient.IsDisableRelay() {
return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive, ice.CandidateTypeRelay}
}
if hasICEForceRelayConn() {
return []ice.CandidateType{ice.CandidateTypeRelay}
}

View File

@@ -32,6 +32,8 @@ type WGWatcher struct {
enabled bool
muEnabled sync.RWMutex
resetCh chan struct{}
}
func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey string, stateDump *stateDump) *WGWatcher {
@@ -40,6 +42,7 @@ func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey strin
wgIfaceStater: wgIfaceStater,
peerKey: peerKey,
stateDump: stateDump,
resetCh: make(chan struct{}, 1),
}
}
@@ -76,6 +79,15 @@ func (w *WGWatcher) IsEnabled() bool {
return w.enabled
}
// Reset signals the watcher that the WireGuard peer has been reset and a new
// handshake is expected. This restarts the handshake timeout from scratch.
func (w *WGWatcher) Reset() {
select {
case w.resetCh <- struct{}{}:
default:
}
}
// wgStateCheck help to check the state of the WireGuard handshake and relay connection
func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, onDisconnectedFn func(), enabledTime time.Time, initialHandshake time.Time) {
w.log.Infof("WireGuard watcher started")
@@ -105,6 +117,12 @@ func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, onDisconnectedFn
w.stateDump.WGcheckSuccess()
w.log.Debugf("WireGuard watcher reset timer: %v", resetTime)
case <-w.resetCh:
w.log.Infof("WireGuard watcher received peer reset, restarting handshake timeout")
lastHandshake = time.Time{}
enabledTime = time.Now()
timer.Stop()
timer.Reset(wgHandshakeOvertime)
case <-ctx.Done():
w.log.Infof("WireGuard watcher stopped")
return

View File

@@ -52,8 +52,9 @@ type WorkerICE struct {
// increase by one when disconnecting the agent
// with it the remote peer can discard the already deprecated offer/answer
// Without it the remote peer may recreate a workable ICE connection
sessionID ICESessionID
muxAgent sync.Mutex
sessionID ICESessionID
remoteSessionChanged bool
muxAgent sync.Mutex
localUfrag string
localPwd string
@@ -106,6 +107,7 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
return
}
w.log.Debugf("agent already exists, recreate the connection")
w.remoteSessionChanged = true
w.agentDialerCancel()
if w.agent != nil {
if err := w.agent.Close(); err != nil {
@@ -306,13 +308,17 @@ func (w *WorkerICE) connect(ctx context.Context, agent *icemaker.ThreadSafeAgent
w.conn.onICEConnectionIsReady(selectedPriority(pair), ci)
}
func (w *WorkerICE) closeAgent(agent *icemaker.ThreadSafeAgent, cancel context.CancelFunc) {
func (w *WorkerICE) closeAgent(agent *icemaker.ThreadSafeAgent, cancel context.CancelFunc) bool {
cancel()
if err := agent.Close(); err != nil {
w.log.Warnf("failed to close ICE agent: %s", err)
}
w.muxAgent.Lock()
defer w.muxAgent.Unlock()
sessionChanged := w.remoteSessionChanged
w.remoteSessionChanged = false
if w.agent == agent {
// consider to remove from here and move to the OnNewOffer
@@ -325,7 +331,7 @@ func (w *WorkerICE) closeAgent(agent *icemaker.ThreadSafeAgent, cancel context.C
w.agentConnecting = false
w.remoteSessionID = ""
}
w.muxAgent.Unlock()
return sessionChanged
}
func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) {
@@ -426,11 +432,11 @@ func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dia
// ice.ConnectionStateClosed happens when we recreate the agent. For the P2P to TURN switch important to
// notify the conn.onICEStateDisconnected changes to update the current used priority
w.closeAgent(agent, dialerCancel)
sessionChanged := w.closeAgent(agent, dialerCancel)
if w.lastKnownState == ice.ConnectionStateConnected {
w.lastKnownState = ice.ConnectionStateDisconnected
w.conn.onICEStateDisconnected()
w.conn.onICEStateDisconnected(sessionChanged)
}
default:
return

View File

@@ -346,6 +346,23 @@ func (m *DefaultManager) updateSystemRoutes(newRoutes route.HAMap) error {
}
var merr *multierror.Error
// Begin batch mode to avoid calling applyHostConfig() after each DNS handler operation
batchStarted := false
if m.dnsServer != nil {
m.dnsServer.BeginBatch()
batchStarted = true
defer func() {
if merr != nil {
// On error, cancel batch to discard partial DNS state
m.dnsServer.CancelBatch()
} else {
// On success, apply accumulated DNS changes
m.dnsServer.EndBatch()
}
}()
}
for id, handler := range toRemove {
if err := handler.RemoveRoute(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove route %s: %w", handler.String(), err))
@@ -376,6 +393,7 @@ func (m *DefaultManager) updateSystemRoutes(newRoutes route.HAMap) error {
m.activeRoutes[id] = handler
}
_ = batchStarted // Mark as used
return nberrors.FormatErrorOrNil(merr)
}

View File

@@ -2,7 +2,10 @@
package NetBirdSDK
import "github.com/netbirdio/netbird/client/internal/peer"
import (
"github.com/netbirdio/netbird/client/internal/lazyconn"
"github.com/netbirdio/netbird/client/internal/peer"
)
// EnvList is an exported struct to be bound by gomobile
type EnvList struct {
@@ -32,3 +35,13 @@ func (el *EnvList) AllItems() map[string]string {
func GetEnvKeyNBForceRelay() string {
return peer.EnvKeyNBForceRelay
}
// GetEnvKeyNBLazyConn Exports the environment variable for the iOS client
func GetEnvKeyNBLazyConn() string {
return lazyconn.EnvEnableLazyConn
}
// GetEnvKeyNBInactivityThreshold Exports the environment variable for the iOS client
func GetEnvKeyNBInactivityThreshold() string {
return lazyconn.EnvInactivityThreshold
}

View File

@@ -488,15 +488,17 @@ func createManagementServer(cfg *CombinedConfig, mgmtConfig *nbconfig.Config) (*
mgmtPort, _ := strconv.Atoi(portStr)
mgmtSrv := mgmtServer.NewServer(
mgmtConfig,
dnsDomain,
singleAccModeDomain,
mgmtPort,
cfg.Server.MetricsPort,
mgmt.DisableAnonymousMetrics,
mgmt.DisableGeoliteUpdate,
// Always enable user deletion from IDP in combined server (embedded IdP is always enabled)
true,
&mgmtServer.Config{
NbConfig: mgmtConfig,
DNSDomain: dnsDomain,
MgmtSingleAccModeDomain: singleAccModeDomain,
MgmtPort: mgmtPort,
MgmtMetricsPort: cfg.Server.MetricsPort,
DisableMetrics: mgmt.DisableAnonymousMetrics,
DisableGeoliteUpdate: mgmt.DisableGeoliteUpdate,
// Always enable user deletion from IDP in combined server (embedded IdP is always enabled)
UserDeleteFromIDPEnabled: true,
},
)
return mgmtSrv, nil

3
go.mod
View File

@@ -40,7 +40,7 @@ require (
github.com/c-robinson/iplib v1.0.3
github.com/caddyserver/certmagic v0.21.3
github.com/cilium/ebpf v0.15.0
github.com/coder/websocket v1.8.13
github.com/coder/websocket v1.8.14
github.com/coreos/go-iptables v0.7.0
github.com/coreos/go-oidc/v3 v3.14.1
github.com/creack/pty v1.1.24
@@ -83,6 +83,7 @@ require (
github.com/pion/stun/v3 v3.1.0
github.com/pion/transport/v3 v3.1.1
github.com/pion/turn/v3 v3.0.1
github.com/pires/go-proxyproto v0.11.0
github.com/pkg/sftp v1.13.9
github.com/prometheus/client_golang v1.23.2
github.com/quic-go/quic-go v0.55.0

6
go.sum
View File

@@ -107,8 +107,8 @@ github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UF
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cilium/ebpf v0.15.0 h1:7NxJhNiBT3NG8pZJ3c+yfrVdHY8ScgKD27sScgjLMMk=
github.com/cilium/ebpf v0.15.0/go.mod h1:DHp1WyrLeiBh19Cf/tfiSMhqheEiK8fXFZ4No0P1Hso=
github.com/coder/websocket v1.8.13 h1:f3QZdXy7uGVz+4uCJy2nTZyM0yTBj8yANEHhqlXZ9FE=
github.com/coder/websocket v1.8.13/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs=
github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g=
github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg=
github.com/containerd/containerd v1.7.29 h1:90fWABQsaN9mJhGkoVnuzEY+o1XDPbg9BTC9QTAHnuE=
github.com/containerd/containerd v1.7.29/go.mod h1:azUkWcOvHrWvaiUjSQH0fjzuHIwSPg1WL5PshGP4Szs=
github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I=
@@ -474,6 +474,8 @@ github.com/pion/turn/v3 v3.0.1 h1:wLi7BTQr6/Q20R0vt/lHbjv6y4GChFtC33nkYbasoT8=
github.com/pion/turn/v3 v3.0.1/go.mod h1:MrJDKgqryDyWy1/4NT9TWfXWGMC7UHT6pJIv1+gMeNE=
github.com/pion/turn/v4 v4.1.1 h1:9UnY2HB99tpDyz3cVVZguSxcqkJ1DsTSZ+8TGruh4fc=
github.com/pion/turn/v4 v4.1.1/go.mod h1:2123tHk1O++vmjI5VSD0awT50NywDAq5A2NNNU4Jjs8=
github.com/pires/go-proxyproto v0.11.0 h1:gUQpS85X/VJMdUsYyEgyn59uLJvGqPhJV5YvG68wXH4=
github.com/pires/go-proxyproto v0.11.0/go.mod h1:ZKAAyp3cgy5Y5Mo4n9AlScrkCZwUy0g3Jf+slqQVcuU=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=

View File

@@ -99,15 +99,16 @@ func NewProvider(ctx context.Context, config *Config) (*Provider, error) {
// Build Dex server config - use Dex's types directly
dexConfig := server.Config{
Issuer: issuer,
Storage: stor,
SkipApprovalScreen: true,
SupportedResponseTypes: []string{"code"},
Logger: logger,
PrometheusRegistry: prometheus.NewRegistry(),
RotateKeysAfter: 6 * time.Hour,
IDTokensValidFor: 24 * time.Hour,
RefreshTokenPolicy: refreshPolicy,
Issuer: issuer,
Storage: stor,
SkipApprovalScreen: true,
SupportedResponseTypes: []string{"code"},
ContinueOnConnectorFailure: true,
Logger: logger,
PrometheusRegistry: prometheus.NewRegistry(),
RotateKeysAfter: 6 * time.Hour,
IDTokensValidFor: 24 * time.Hour,
RefreshTokenPolicy: refreshPolicy,
Web: server.WebConfig{
Issuer: "NetBird",
},
@@ -260,6 +261,7 @@ func buildDexConfig(yamlConfig *YAMLConfig, stor storage.Storage, logger *slog.L
if len(cfg.SupportedResponseTypes) == 0 {
cfg.SupportedResponseTypes = []string{"code"}
}
cfg.ContinueOnConnectorFailure = true
return cfg
}

View File

@@ -2,6 +2,7 @@ package dex
import (
"context"
"log/slog"
"os"
"path/filepath"
"testing"
@@ -195,3 +196,64 @@ enablePasswordDB: true
t.Logf("User lookup successful: rawID=%s, connectorID=%s", rawID, connID)
}
func TestNewProvider_ContinueOnConnectorFailure(t *testing.T) {
ctx := context.Background()
tmpDir, err := os.MkdirTemp("", "dex-connector-failure-*")
require.NoError(t, err)
defer os.RemoveAll(tmpDir)
config := &Config{
Issuer: "http://localhost:5556/dex",
Port: 5556,
DataDir: tmpDir,
}
provider, err := NewProvider(ctx, config)
require.NoError(t, err)
defer func() { _ = provider.Stop(ctx) }()
// The provider should have started successfully even though
// ContinueOnConnectorFailure is an internal Dex config field.
// We verify the provider is functional by performing a basic operation.
assert.NotNil(t, provider.dexServer)
assert.NotNil(t, provider.storage)
}
func TestBuildDexConfig_ContinueOnConnectorFailure(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "dex-build-config-*")
require.NoError(t, err)
defer os.RemoveAll(tmpDir)
yamlContent := `
issuer: http://localhost:5556/dex
storage:
type: sqlite3
config:
file: ` + filepath.Join(tmpDir, "dex.db") + `
web:
http: 127.0.0.1:5556
enablePasswordDB: true
`
configPath := filepath.Join(tmpDir, "config.yaml")
err = os.WriteFile(configPath, []byte(yamlContent), 0644)
require.NoError(t, err)
yamlConfig, err := LoadConfig(configPath)
require.NoError(t, err)
ctx := context.Background()
stor, err := yamlConfig.Storage.OpenStorage(slog.New(slog.NewTextHandler(os.Stderr, nil)))
require.NoError(t, err)
defer stor.Close()
err = initializeStorage(ctx, stor, yamlConfig)
require.NoError(t, err)
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
cfg := buildDexConfig(yamlConfig, stor, logger)
assert.True(t, cfg.ContinueOnConnectorFailure,
"buildDexConfig must set ContinueOnConnectorFailure to true so management starts even if an external IdP is down")
}

View File

@@ -183,14 +183,14 @@ read_enable_proxy() {
}
read_proxy_domain() {
local suggested_proxy="proxy.${NETBIRD_DOMAIN}"
local suggested_proxy="proxy.${BASE_DOMAIN}"
echo "" > /dev/stderr
echo "NOTE: The proxy domain must be different from the management domain ($NETBIRD_DOMAIN)" > /dev/stderr
echo "to avoid TLS certificate conflicts." > /dev/stderr
echo "" > /dev/stderr
echo "You also need to add a wildcard DNS record for the proxy domain," > /dev/stderr
echo "e.g. *.${suggested_proxy} pointing to the same server IP as $NETBIRD_DOMAIN." > /dev/stderr
echo "e.g. *.${suggested_proxy} pointing to the same server domain as $NETBIRD_DOMAIN with a CNAME record." > /dev/stderr
echo "" > /dev/stderr
echo -n "Enter the domain for the NetBird Proxy (e.g. ${suggested_proxy}): " > /dev/stderr
read -r READ_PROXY_DOMAIN < /dev/tty
@@ -202,13 +202,16 @@ read_proxy_domain() {
fi
if [[ "$READ_PROXY_DOMAIN" == "$NETBIRD_DOMAIN" ]]; then
echo "The proxy domain cannot be the same as the management domain ($NETBIRD_DOMAIN)." > /dev/stderr
echo "" > /dev/stderr
echo "WARNING: The proxy domain cannot be the same as the management domain ($NETBIRD_DOMAIN)." > /dev/stderr
read_proxy_domain
return
fi
if [[ "$READ_PROXY_DOMAIN" == *".${NETBIRD_DOMAIN}" ]]; then
echo "The proxy domain cannot be a subdomain of the management domain ($NETBIRD_DOMAIN)." > /dev/stderr
echo ${READ_PROXY_DOMAIN} | grep ${NETBIRD_DOMAIN} > /dev/null
if [[ $? -eq 0 ]]; then
echo "" > /dev/stderr
echo "WARNING: The proxy domain cannot be a subdomain of the management domain ($NETBIRD_DOMAIN)." > /dev/stderr
read_proxy_domain
return
fi
@@ -326,6 +329,9 @@ initialize_default_values() {
BIND_LOCALHOST_ONLY="true"
EXTERNAL_PROXY_NETWORK=""
# Traefik static IP within the internal bridge network
TRAEFIK_IP="172.30.0.10"
# NetBird Proxy configuration
ENABLE_PROXY="false"
PROXY_DOMAIN=""
@@ -340,10 +346,12 @@ configure_domain() {
if [[ "$NETBIRD_DOMAIN" == "use-ip" ]]; then
NETBIRD_DOMAIN=$(get_main_ip_address)
BASE_DOMAIN=$NETBIRD_DOMAIN
else
NETBIRD_PORT=443
NETBIRD_HTTP_PROTOCOL="https"
NETBIRD_RELAY_PROTO="rels"
BASE_DOMAIN=$(echo $NETBIRD_DOMAIN | sed -E 's/^[^.]+\.//')
fi
return 0
}
@@ -388,7 +396,7 @@ check_existing_installation() {
echo "Generated files already exist, if you want to reinitialize the environment, please remove them first."
echo "You can use the following commands:"
echo " $DOCKER_COMPOSE_COMMAND down --volumes # to remove all containers and volumes"
echo " rm -f docker-compose.yml dashboard.env config.yaml proxy.env nginx-netbird.conf caddyfile-netbird.txt npm-advanced-config.txt"
echo " rm -f docker-compose.yml dashboard.env config.yaml proxy.env traefik-dynamic.yaml nginx-netbird.conf caddyfile-netbird.txt npm-advanced-config.txt"
echo "Be aware that this will remove all data from the database, and you will have to reconfigure the dashboard."
exit 1
fi
@@ -407,6 +415,8 @@ generate_configuration_files() {
# This will be overwritten with the actual token after netbird-server starts
echo "# Placeholder - will be updated with token after netbird-server starts" > proxy.env
echo "NB_PROXY_TOKEN=placeholder" >> proxy.env
# TCP ServersTransport for PROXY protocol v2 to the proxy backend
render_traefik_dynamic > traefik-dynamic.yaml
fi
;;
1)
@@ -554,18 +564,21 @@ init_environment() {
############################################
render_docker_compose_traefik_builtin() {
# Generate proxy service section if enabled
# Generate proxy service section and Traefik dynamic config if enabled
local proxy_service=""
local proxy_volumes=""
local traefik_file_provider=""
local traefik_dynamic_volume=""
if [[ "$ENABLE_PROXY" == "true" ]]; then
traefik_file_provider=' - "--providers.file.filename=/etc/traefik/dynamic.yaml"'
traefik_dynamic_volume=" - ./traefik-dynamic.yaml:/etc/traefik/dynamic.yaml:ro"
proxy_service="
# NetBird Proxy - exposes internal resources to the internet
proxy:
image: $NETBIRD_PROXY_IMAGE
container_name: netbird-proxy
# Hairpin NAT fix: route domain back to traefik's static IP within Docker
extra_hosts:
- \"$NETBIRD_DOMAIN:172.30.0.10\"
ports:
- 51820:51820/udp
restart: unless-stopped
networks: [netbird]
depends_on:
@@ -583,6 +596,7 @@ render_docker_compose_traefik_builtin() {
- traefik.tcp.routers.proxy-passthrough.service=proxy-tls
- traefik.tcp.routers.proxy-passthrough.priority=1
- traefik.tcp.services.proxy-tls.loadbalancer.server.port=8443
- traefik.tcp.services.proxy-tls.loadbalancer.serverstransport=pp-v2@file
logging:
driver: \"json-file\"
options:
@@ -602,7 +616,7 @@ services:
restart: unless-stopped
networks:
netbird:
ipv4_address: 172.30.0.10
ipv4_address: $TRAEFIK_IP
command:
# Logging
- "--log.level=INFO"
@@ -629,12 +643,14 @@ services:
# gRPC transport settings
- "--serverstransport.forwardingtimeouts.responseheadertimeout=0s"
- "--serverstransport.forwardingtimeouts.idleconntimeout=0s"
$traefik_file_provider
ports:
- '443:443'
- '80:80'
volumes:
- /var/run/docker.sock:/var/run/docker.sock:ro
- netbird_traefik_letsencrypt:/letsencrypt
$traefik_dynamic_volume
logging:
driver: "json-file"
options:
@@ -744,6 +760,10 @@ server:
cliRedirectURIs:
- "http://localhost:53000/"
reverseProxy:
trustedHTTPProxies:
- "$TRAEFIK_IP/32"
store:
engine: "sqlite"
encryptionKey: "$DATASTORE_ENCRYPTION_KEY"
@@ -773,6 +793,17 @@ EOF
return 0
}
render_traefik_dynamic() {
cat <<'EOF'
tcp:
serversTransports:
pp-v2:
proxyProtocol:
version: 2
EOF
return 0
}
render_proxy_env() {
cat <<EOF
# NetBird Proxy Configuration
@@ -788,10 +819,11 @@ NB_PROXY_TOKEN=$PROXY_TOKEN
NB_PROXY_CERTIFICATE_DIRECTORY=/certs
NB_PROXY_ACME_CERTIFICATES=true
NB_PROXY_ACME_CHALLENGE_TYPE=tls-alpn-01
NB_PROXY_OIDC_CLIENT_ID=netbird-proxy
NB_PROXY_OIDC_ENDPOINT=$NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN/oauth2
NB_PROXY_OIDC_SCOPES=openid,profile,email
NB_PROXY_FORWARDED_PROTO=https
# Enable PROXY protocol to preserve client IPs through L4 proxies (Traefik TCP passthrough)
NB_PROXY_PROXY_PROTOCOL=true
# Trust Traefik's IP for PROXY protocol headers
NB_PROXY_TRUSTED_PROXIES=$TRAEFIK_IP
EOF
return 0
}
@@ -1150,23 +1182,30 @@ print_builtin_traefik_instructions() {
echo " NETBIRD SETUP COMPLETE"
echo "$MSG_SEPARATOR"
echo ""
echo "You can access the NetBird dashboard at $NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN"
echo "You can access the NetBird dashboard at:"
echo " $NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN"
echo ""
echo "Follow the onboarding steps to set up your NetBird instance."
echo ""
echo "Traefik is handling TLS certificates automatically via Let's Encrypt."
echo "If you see certificate warnings, wait a moment for certificate issuance to complete."
echo ""
echo "Open ports:"
echo " - 443/tcp (HTTPS - all NetBird services)"
echo " - 80/tcp (HTTP - redirects to HTTPS)"
echo " - $NETBIRD_STUN_PORT/udp (STUN - required for NAT traversal)"
echo " - 443/tcp (HTTPS - all NetBird services)"
echo " - 80/tcp (HTTP - redirects to HTTPS)"
echo " - $NETBIRD_STUN_PORT/udp (STUN - required for NAT traversal)"
if [[ "$ENABLE_PROXY" == "true" ]]; then
echo " - 51820/udp (WIREGUARD - (optional) for P2P proxy connections)"
echo ""
echo "NetBird Proxy:"
echo " The proxy service is enabled and running."
echo " Any domain NOT matching $NETBIRD_DOMAIN will be passed through to the proxy."
echo " The proxy handles its own TLS certificates via ACME TLS-ALPN-01 challenge."
echo " Point your proxy domains (CNAMEs) to this server's IP address."
echo " Point your proxy domain to this server's domain address like in the examples below:"
echo ""
echo " $PROXY_DOMAIN CNAME $NETBIRD_DOMAIN"
echo " *.$PROXY_DOMAIN CNAME $NETBIRD_DOMAIN"
echo ""
fi
return 0
}

View File

@@ -29,11 +29,11 @@ import (
"github.com/netbirdio/netbird/util/crypt"
)
var newServer = func(config *nbconfig.Config, dnsDomain, mgmtSingleAccModeDomain string, mgmtPort int, mgmtMetricsPort int, disableMetrics, disableGeoliteUpdate, userDeleteFromIDPEnabled bool) server.Server {
return server.NewServer(config, dnsDomain, mgmtSingleAccModeDomain, mgmtPort, mgmtMetricsPort, disableMetrics, disableGeoliteUpdate, userDeleteFromIDPEnabled)
var newServer = func(cfg *server.Config) server.Server {
return server.NewServer(cfg)
}
func SetNewServer(fn func(config *nbconfig.Config, dnsDomain, mgmtSingleAccModeDomain string, mgmtPort int, mgmtMetricsPort int, disableMetrics, disableGeoliteUpdate, userDeleteFromIDPEnabled bool) server.Server) {
func SetNewServer(fn func(*server.Config) server.Server) {
newServer = fn
}
@@ -110,7 +110,17 @@ var (
mgmtSingleAccModeDomain = ""
}
srv := newServer(config, dnsDomain, mgmtSingleAccModeDomain, mgmtPort, mgmtMetricsPort, disableMetrics, disableGeoliteUpdate, userDeleteFromIDPEnabled)
srv := newServer(&server.Config{
NbConfig: config,
DNSDomain: dnsDomain,
MgmtSingleAccModeDomain: mgmtSingleAccModeDomain,
MgmtPort: mgmtPort,
MgmtMetricsPort: mgmtMetricsPort,
DisableLegacyManagementPort: disableLegacyManagementPort,
DisableMetrics: disableMetrics,
DisableGeoliteUpdate: disableGeoliteUpdate,
UserDeleteFromIDPEnabled: userDeleteFromIDPEnabled,
})
go func() {
if err := srv.Start(cmd.Context()); err != nil {
log.Fatalf("Server error: %v", err)

View File

@@ -16,21 +16,22 @@ const (
)
var (
dnsDomain string
mgmtDataDir string
logLevel string
logFile string
disableMetrics bool
disableSingleAccMode bool
disableGeoliteUpdate bool
idpSignKeyRefreshEnabled bool
userDeleteFromIDPEnabled bool
mgmtPort int
mgmtMetricsPort int
mgmtLetsencryptDomain string
mgmtSingleAccModeDomain string
certFile string
certKey string
dnsDomain string
mgmtDataDir string
logLevel string
logFile string
disableMetrics bool
disableSingleAccMode bool
disableGeoliteUpdate bool
idpSignKeyRefreshEnabled bool
userDeleteFromIDPEnabled bool
mgmtPort int
mgmtMetricsPort int
disableLegacyManagementPort bool
mgmtLetsencryptDomain string
mgmtSingleAccModeDomain string
certFile string
certKey string
rootCmd = &cobra.Command{
Use: "netbird-mgmt",
@@ -55,6 +56,7 @@ func Execute() error {
func init() {
mgmtCmd.Flags().IntVar(&mgmtPort, "port", 80, "server port to listen on (defaults to 443 if TLS is enabled, 80 otherwise")
mgmtCmd.Flags().BoolVar(&disableLegacyManagementPort, "disable-legacy-port", false, "disabling the old legacy port (33073)")
mgmtCmd.Flags().IntVar(&mgmtMetricsPort, "metrics-port", 9090, "metrics endpoint http port. Metrics are accessible under host:metrics-port/metrics")
mgmtCmd.Flags().StringVar(&mgmtDataDir, "datadir", defaultMgmtDataDir, "server data directory location")
mgmtCmd.Flags().StringVar(&nbconfig.MgmtConfigPath, "config", defaultMgmtConfig, "Netbird config file location. Config params specified via command line (e.g. datadir) have a precedence over configuration from this file")

View File

@@ -50,13 +50,14 @@ type BaseServer struct {
// AfterInit is a function that will be called after the server is initialized
afterInit []func(s *BaseServer)
disableMetrics bool
dnsDomain string
disableGeoliteUpdate bool
userDeleteFromIDPEnabled bool
mgmtSingleAccModeDomain string
mgmtMetricsPort int
mgmtPort int
disableMetrics bool
dnsDomain string
disableGeoliteUpdate bool
userDeleteFromIDPEnabled bool
mgmtSingleAccModeDomain string
mgmtMetricsPort int
mgmtPort int
disableLegacyManagementPort bool
proxyAuthClose func()
@@ -69,18 +70,32 @@ type BaseServer struct {
cancel context.CancelFunc
}
// Config holds the configuration parameters for creating a new server
type Config struct {
NbConfig *nbconfig.Config
DNSDomain string
MgmtSingleAccModeDomain string
MgmtPort int
MgmtMetricsPort int
DisableLegacyManagementPort bool
DisableMetrics bool
DisableGeoliteUpdate bool
UserDeleteFromIDPEnabled bool
}
// NewServer initializes and configures a new Server instance
func NewServer(config *nbconfig.Config, dnsDomain, mgmtSingleAccModeDomain string, mgmtPort, mgmtMetricsPort int, disableMetrics, disableGeoliteUpdate, userDeleteFromIDPEnabled bool) *BaseServer {
func NewServer(cfg *Config) *BaseServer {
return &BaseServer{
Config: config,
container: make(map[string]any),
dnsDomain: dnsDomain,
mgmtSingleAccModeDomain: mgmtSingleAccModeDomain,
disableMetrics: disableMetrics,
disableGeoliteUpdate: disableGeoliteUpdate,
userDeleteFromIDPEnabled: userDeleteFromIDPEnabled,
mgmtPort: mgmtPort,
mgmtMetricsPort: mgmtMetricsPort,
Config: cfg.NbConfig,
container: make(map[string]any),
dnsDomain: cfg.DNSDomain,
mgmtSingleAccModeDomain: cfg.MgmtSingleAccModeDomain,
disableMetrics: cfg.DisableMetrics,
disableGeoliteUpdate: cfg.DisableGeoliteUpdate,
userDeleteFromIDPEnabled: cfg.UserDeleteFromIDPEnabled,
mgmtPort: cfg.MgmtPort,
disableLegacyManagementPort: cfg.DisableLegacyManagementPort,
mgmtMetricsPort: cfg.MgmtMetricsPort,
}
}
@@ -152,7 +167,7 @@ func (s *BaseServer) Start(ctx context.Context) error {
}
var compatListener net.Listener
if s.mgmtPort != ManagementLegacyPort {
if s.mgmtPort != ManagementLegacyPort && !s.disableLegacyManagementPort {
// The Management gRPC server was running on port 33073 previously. Old agents that are already connected to it
// are using port 33073. For compatibility purposes we keep running a 2nd gRPC server on port 33073.
compatListener, err = s.serveGRPC(srvCtx, s.GRPCServer(), ManagementLegacyPort)

View File

@@ -224,6 +224,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
s.syncSem.Add(1)
reqStart := time.Now()
syncStart := reqStart.UTC()
ctx := srv.Context()
@@ -300,7 +301,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
metahash := metaHash(peerMeta, realIP.String())
s.loginFilter.addLogin(peerKey.String(), metahash)
peer, netMap, postureChecks, dnsFwdPort, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP, reqStart)
peer, netMap, postureChecks, dnsFwdPort, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP, syncStart)
if err != nil {
log.WithContext(ctx).Debugf("error while syncing peer %s: %v", peerKey.String(), err)
s.syncSem.Add(-1)
@@ -311,7 +312,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
if err != nil {
log.WithContext(ctx).Debugf("error while sending initial sync for %s: %v", peerKey.String(), err)
s.syncSem.Add(-1)
s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, reqStart)
s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, syncStart)
return err
}
@@ -319,7 +320,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
if err != nil {
log.WithContext(ctx).Debugf("error while notify peer connected for %s: %v", peerKey.String(), err)
s.syncSem.Add(-1)
s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, reqStart)
s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, syncStart)
return err
}
@@ -336,7 +337,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
s.syncSem.Add(-1)
return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv, reqStart)
return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv, syncStart)
}
func (s *Server) handleHandshake(ctx context.Context, srv proto.ManagementService_JobServer) (wgtypes.Key, error) {

View File

@@ -6,14 +6,14 @@ import (
"os"
"os/signal"
"strconv"
"strings"
"syscall"
"github.com/netbirdio/netbird/shared/management/domain"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"golang.org/x/crypto/acme"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/proxy"
nbacme "github.com/netbirdio/netbird/proxy/internal/acme"
"github.com/netbirdio/netbird/util"
@@ -46,16 +46,13 @@ var (
debugEndpoint bool
debugEndpointAddr string
healthAddr string
oidcClientID string
oidcClientSecret string
oidcEndpoint string
oidcScopes string
forwardedProto string
trustedProxies string
certFile string
certKeyFile string
certLockMethod string
wgPort int
proxyProtocol bool
)
var rootCmd = &cobra.Command{
@@ -80,16 +77,13 @@ func init() {
rootCmd.Flags().BoolVar(&debugEndpoint, "debug-endpoint", envBoolOrDefault("NB_PROXY_DEBUG_ENDPOINT", false), "Enable debug HTTP endpoint")
rootCmd.Flags().StringVar(&debugEndpointAddr, "debug-endpoint-addr", envStringOrDefault("NB_PROXY_DEBUG_ENDPOINT_ADDRESS", "localhost:8444"), "Address for the debug HTTP endpoint")
rootCmd.Flags().StringVar(&healthAddr, "health-addr", envStringOrDefault("NB_PROXY_HEALTH_ADDRESS", "localhost:8080"), "Address for the health probe endpoint (liveness/readiness/startup)")
rootCmd.Flags().StringVar(&oidcClientID, "oidc-id", envStringOrDefault("NB_PROXY_OIDC_CLIENT_ID", "netbird-proxy"), "The OAuth2 Client ID for OIDC User Authentication")
rootCmd.Flags().StringVar(&oidcClientSecret, "oidc-secret", envStringOrDefault("NB_PROXY_OIDC_CLIENT_SECRET", ""), "The OAuth2 Client Secret for OIDC User Authentication")
rootCmd.Flags().StringVar(&oidcEndpoint, "oidc-endpoint", envStringOrDefault("NB_PROXY_OIDC_ENDPOINT", ""), "The OIDC Endpoint for OIDC User Authentication")
rootCmd.Flags().StringVar(&oidcScopes, "oidc-scopes", envStringOrDefault("NB_PROXY_OIDC_SCOPES", "openid,profile,email"), "The OAuth2 scopes for OIDC User Authentication, comma separated")
rootCmd.Flags().StringVar(&forwardedProto, "forwarded-proto", envStringOrDefault("NB_PROXY_FORWARDED_PROTO", "auto"), "X-Forwarded-Proto value for backends: auto, http, or https")
rootCmd.Flags().StringVar(&trustedProxies, "trusted-proxies", envStringOrDefault("NB_PROXY_TRUSTED_PROXIES", ""), "Comma-separated list of trusted upstream proxy CIDR ranges (e.g. '10.0.0.0/8,192.168.1.1')")
rootCmd.Flags().StringVar(&certFile, "cert-file", envStringOrDefault("NB_PROXY_CERTIFICATE_FILE", "tls.crt"), "TLS certificate filename within the certificate directory")
rootCmd.Flags().StringVar(&certKeyFile, "cert-key-file", envStringOrDefault("NB_PROXY_CERTIFICATE_KEY_FILE", "tls.key"), "TLS certificate key filename within the certificate directory")
rootCmd.Flags().StringVar(&certLockMethod, "cert-lock-method", envStringOrDefault("NB_PROXY_CERT_LOCK_METHOD", "auto"), "Certificate lock method for cross-replica coordination: auto, flock, or k8s-lease")
rootCmd.Flags().IntVar(&wgPort, "wg-port", envIntOrDefault("NB_PROXY_WG_PORT", 0), "WireGuard listen port (0 = random). Fixed port only works with single-account deployments")
rootCmd.Flags().BoolVar(&proxyProtocol, "proxy-protocol", envBoolOrDefault("NB_PROXY_PROXY_PROTOCOL", false), "Enable PROXY protocol on TCP listeners to preserve client IPs behind L4 proxies")
}
// Execute runs the root command.
@@ -157,14 +151,11 @@ func runServer(cmd *cobra.Command, args []string) error {
DebugEndpointEnabled: debugEndpoint,
DebugEndpointAddress: debugEndpointAddr,
HealthAddress: healthAddr,
OIDCClientId: oidcClientID,
OIDCClientSecret: oidcClientSecret,
OIDCEndpoint: oidcEndpoint,
OIDCScopes: strings.Split(oidcScopes, ","),
ForwardedProto: forwardedProto,
TrustedProxies: parsedTrustedProxies,
CertLockMethod: nbacme.CertLockMethod(certLockMethod),
WireguardPort: wgPort,
ProxyProtocol: proxyProtocol,
}
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT)

View File

@@ -9,6 +9,7 @@ import (
"github.com/rs/xid"
"github.com/netbirdio/netbird/proxy/internal/proxy"
"github.com/netbirdio/netbird/proxy/internal/responsewriter"
"github.com/netbirdio/netbird/proxy/web"
)
@@ -27,8 +28,8 @@ func (l *Logger) Middleware(next http.Handler) http.Handler {
// Use a response writer wrapper so we can access the status code later.
sw := &statusWriter{
w: w,
status: http.StatusOK,
PassthroughWriter: responsewriter.New(w),
status: http.StatusOK,
}
// Resolve the source IP using trusted proxy configuration before passing

View File

@@ -1,26 +1,18 @@
package accesslog
import (
"net/http"
"github.com/netbirdio/netbird/proxy/internal/responsewriter"
)
// statusWriter is a simple wrapper around an http.ResponseWriter
// that captures the setting of the status code via the WriteHeader
// function and stores it so that it can be retrieved later.
// statusWriter captures the HTTP status code from WriteHeader calls.
// It embeds responsewriter.PassthroughWriter which handles all the optional
// interfaces (Hijacker, Flusher, Pusher) automatically.
type statusWriter struct {
w http.ResponseWriter
*responsewriter.PassthroughWriter
status int
}
func (w *statusWriter) Header() http.Header {
return w.w.Header()
}
func (w *statusWriter) Write(data []byte) (int, error) {
return w.w.Write(data)
}
func (w *statusWriter) WriteHeader(status int) {
w.status = status
w.w.WriteHeader(status)
w.PassthroughWriter.WriteHeader(status)
}

View File

@@ -0,0 +1,49 @@
package conntrack
import (
"bufio"
"net"
"net/http"
)
// trackedConn wraps a net.Conn and removes itself from the tracker on Close.
type trackedConn struct {
net.Conn
tracker *HijackTracker
}
func (c *trackedConn) Close() error {
c.tracker.conns.Delete(c)
return c.Conn.Close()
}
// trackingWriter wraps an http.ResponseWriter and intercepts Hijack calls
// to replace the raw connection with a trackedConn that auto-deregisters.
type trackingWriter struct {
http.ResponseWriter
tracker *HijackTracker
}
func (w *trackingWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
hijacker, ok := w.ResponseWriter.(http.Hijacker)
if !ok {
return nil, nil, http.ErrNotSupported
}
conn, buf, err := hijacker.Hijack()
if err != nil {
return nil, nil, err
}
tc := &trackedConn{Conn: conn, tracker: w.tracker}
w.tracker.conns.Store(tc, struct{}{})
return tc, buf, nil
}
func (w *trackingWriter) Flush() {
if flusher, ok := w.ResponseWriter.(http.Flusher); ok {
flusher.Flush()
}
}
func (w *trackingWriter) Unwrap() http.ResponseWriter {
return w.ResponseWriter
}

View File

@@ -0,0 +1,41 @@
package conntrack
import (
"net"
"net/http"
"sync"
)
// HijackTracker tracks connections that have been hijacked (e.g. WebSocket
// upgrades). http.Server.Shutdown does not close hijacked connections, so
// they must be tracked and closed explicitly during graceful shutdown.
//
// Use Middleware as the outermost HTTP middleware to ensure hijacked
// connections are tracked and automatically deregistered when closed.
type HijackTracker struct {
conns sync.Map // net.Conn → struct{}
}
// Middleware returns an HTTP middleware that wraps the ResponseWriter so that
// hijacked connections are tracked and automatically deregistered from the
// tracker when closed. This should be the outermost middleware in the chain.
func (t *HijackTracker) Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
next.ServeHTTP(&trackingWriter{ResponseWriter: w, tracker: t}, r)
})
}
// CloseAll closes all tracked hijacked connections and returns the number
// of connections that were closed.
func (t *HijackTracker) CloseAll() int {
var count int
t.conns.Range(func(key, _ any) bool {
if conn, ok := key.(net.Conn); ok {
_ = conn.Close()
count++
}
t.conns.Delete(key)
return true
})
return count
}

View File

@@ -5,9 +5,11 @@ import (
"strconv"
"time"
"github.com/netbirdio/netbird/proxy/internal/proxy"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"github.com/netbirdio/netbird/proxy/internal/proxy"
"github.com/netbirdio/netbird/proxy/internal/responsewriter"
)
type Metrics struct {
@@ -60,18 +62,18 @@ func New(reg prometheus.Registerer) *Metrics {
}
type responseInterceptor struct {
http.ResponseWriter
*responsewriter.PassthroughWriter
status int
size int
}
func (w *responseInterceptor) WriteHeader(status int) {
w.status = status
w.ResponseWriter.WriteHeader(status)
w.PassthroughWriter.WriteHeader(status)
}
func (w *responseInterceptor) Write(b []byte) (int, error) {
size, err := w.ResponseWriter.Write(b)
size, err := w.PassthroughWriter.Write(b)
w.size += size
return size, err
}
@@ -81,7 +83,7 @@ func (m *Metrics) Middleware(next http.Handler) http.Handler {
m.requestsTotal.Inc()
m.activeRequests.Inc()
interceptor := &responseInterceptor{ResponseWriter: w}
interceptor := &responseInterceptor{PassthroughWriter: responsewriter.New(w)}
start := time.Now()
next.ServeHTTP(interceptor, r)

View File

@@ -0,0 +1,53 @@
package responsewriter
import (
"bufio"
"net"
"net/http"
)
// PassthroughWriter wraps an http.ResponseWriter and preserves optional
// interfaces like Hijacker, Flusher, and Pusher by delegating to the underlying
// ResponseWriter if it supports them.
//
// This is the standard pattern for Go middleware that needs to wrap ResponseWriter
// while maintaining support for protocol upgrades (WebSocket), streaming (Flusher),
// and HTTP/2 server push.
type PassthroughWriter struct {
http.ResponseWriter
}
// New creates a new wrapper around the given ResponseWriter.
func New(w http.ResponseWriter) *PassthroughWriter {
return &PassthroughWriter{ResponseWriter: w}
}
// Hijack implements http.Hijacker interface if the underlying ResponseWriter supports it.
// This is required for WebSocket connections and other protocol upgrades.
func (w *PassthroughWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
if hijacker, ok := w.ResponseWriter.(http.Hijacker); ok {
return hijacker.Hijack()
}
return nil, nil, http.ErrNotSupported
}
// Flush implements http.Flusher interface if the underlying ResponseWriter supports it.
func (w *PassthroughWriter) Flush() {
if flusher, ok := w.ResponseWriter.(http.Flusher); ok {
flusher.Flush()
}
}
// Push implements http.Pusher interface if the underlying ResponseWriter supports it.
func (w *PassthroughWriter) Push(target string, opts *http.PushOptions) error {
if pusher, ok := w.ResponseWriter.(http.Pusher); ok {
return pusher.Push(target, opts)
}
return http.ErrNotSupported
}
// Unwrap returns the underlying ResponseWriter.
// This is required for http.ResponseController (Go 1.20+) to work correctly.
func (w *PassthroughWriter) Unwrap() http.ResponseWriter {
return w.ResponseWriter
}

106
proxy/proxyprotocol_test.go Normal file
View File

@@ -0,0 +1,106 @@
package proxy
import (
"net"
"net/netip"
"testing"
"time"
proxyproto "github.com/pires/go-proxyproto"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestWrapProxyProtocol_OverridesRemoteAddr(t *testing.T) {
srv := &Server{
Logger: log.StandardLogger(),
TrustedProxies: []netip.Prefix{netip.MustParsePrefix("127.0.0.1/32")},
ProxyProtocol: true,
}
raw, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer raw.Close()
ln := srv.wrapProxyProtocol(raw)
realClientIP := "203.0.113.50"
realClientPort := uint16(54321)
accepted := make(chan net.Conn, 1)
go func() {
conn, err := ln.Accept()
if err != nil {
return
}
accepted <- conn
}()
// Connect and send a PROXY v2 header.
conn, err := net.Dial("tcp", ln.Addr().String())
require.NoError(t, err)
defer conn.Close()
header := &proxyproto.Header{
Version: 2,
Command: proxyproto.PROXY,
TransportProtocol: proxyproto.TCPv4,
SourceAddr: &net.TCPAddr{IP: net.ParseIP(realClientIP), Port: int(realClientPort)},
DestinationAddr: &net.TCPAddr{IP: net.ParseIP("10.0.0.1"), Port: 443},
}
_, err = header.WriteTo(conn)
require.NoError(t, err)
select {
case accepted := <-accepted:
defer accepted.Close()
host, _, err := net.SplitHostPort(accepted.RemoteAddr().String())
require.NoError(t, err)
assert.Equal(t, realClientIP, host, "RemoteAddr should reflect the PROXY header source IP")
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for connection")
}
}
func TestProxyProtocolPolicy_TrustedRequires(t *testing.T) {
srv := &Server{
Logger: log.StandardLogger(),
TrustedProxies: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
}
opts := proxyproto.ConnPolicyOptions{
Upstream: &net.TCPAddr{IP: net.ParseIP("10.0.0.1"), Port: 1234},
}
policy, err := srv.proxyProtocolPolicy(opts)
require.NoError(t, err)
assert.Equal(t, proxyproto.REQUIRE, policy, "trusted source should require PROXY header")
}
func TestProxyProtocolPolicy_UntrustedIgnores(t *testing.T) {
srv := &Server{
Logger: log.StandardLogger(),
TrustedProxies: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
}
opts := proxyproto.ConnPolicyOptions{
Upstream: &net.TCPAddr{IP: net.ParseIP("203.0.113.50"), Port: 1234},
}
policy, err := srv.proxyProtocolPolicy(opts)
require.NoError(t, err)
assert.Equal(t, proxyproto.IGNORE, policy, "untrusted source should have PROXY header ignored")
}
func TestProxyProtocolPolicy_InvalidIPRejects(t *testing.T) {
srv := &Server{
Logger: log.StandardLogger(),
TrustedProxies: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
}
opts := proxyproto.ConnPolicyOptions{
Upstream: &net.UnixAddr{Name: "/tmp/test.sock", Net: "unix"},
}
policy, err := srv.proxyProtocolPolicy(opts)
require.NoError(t, err)
assert.Equal(t, proxyproto.REJECT, policy, "unparsable address should be rejected")
}

View File

@@ -23,6 +23,7 @@ import (
"time"
"github.com/cenkalti/backoff/v4"
"github.com/pires/go-proxyproto"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
log "github.com/sirupsen/logrus"
@@ -36,6 +37,7 @@ import (
"github.com/netbirdio/netbird/proxy/internal/acme"
"github.com/netbirdio/netbird/proxy/internal/auth"
"github.com/netbirdio/netbird/proxy/internal/certwatch"
"github.com/netbirdio/netbird/proxy/internal/conntrack"
"github.com/netbirdio/netbird/proxy/internal/debug"
proxygrpc "github.com/netbirdio/netbird/proxy/internal/grpc"
"github.com/netbirdio/netbird/proxy/internal/health"
@@ -63,6 +65,11 @@ type Server struct {
healthChecker *health.Checker
meter *metrics.Metrics
// hijackTracker tracks hijacked connections (e.g. WebSocket upgrades)
// so they can be closed during graceful shutdown, since http.Server.Shutdown
// does not handle them.
hijackTracker conntrack.HijackTracker
// Mostly used for debugging on management.
startTime time.Time
@@ -82,17 +89,13 @@ type Server struct {
ACMEChallengeType string
// CertLockMethod controls how ACME certificate locks are coordinated
// across replicas. Default: CertLockAuto (detect environment).
CertLockMethod acme.CertLockMethod
OIDCClientId string
OIDCClientSecret string
OIDCEndpoint string
OIDCScopes []string
CertLockMethod acme.CertLockMethod
// DebugEndpointEnabled enables the debug HTTP endpoint.
DebugEndpointEnabled bool
// DebugEndpointAddress is the address for the debug HTTP endpoint (default: ":8444").
DebugEndpointAddress string
// HealthAddress is the address for the health probe endpoint (default: "localhost:8080").
// HealthAddress is the address for the health probe endpoint.
HealthAddress string
// ProxyToken is the access token for authenticating with the management server.
ProxyToken string
@@ -107,6 +110,10 @@ type Server struct {
// random OS-assigned port. A fixed port only works with single-account
// deployments; multiple accounts will fail to bind the same port.
WireguardPort int
// ProxyProtocol enables PROXY protocol (v1/v2) on TCP listeners.
// When enabled, the real client IP is extracted from the PROXY header
// sent by upstream L4 proxies that support PROXY protocol.
ProxyProtocol bool
}
// NotifyStatus sends a status update to management about tunnel connectivity
@@ -137,23 +144,8 @@ func (s *Server) NotifyCertificateIssued(ctx context.Context, accountID, service
}
func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
s.startTime = time.Now()
s.initDefaults()
// If no ID is set then one can be generated.
if s.ID == "" {
s.ID = "netbird-proxy-" + s.startTime.Format("20060102150405")
}
// Fallback version option in case it is not set.
if s.Version == "" {
s.Version = "dev"
}
// If no logger is specified fallback to the standard logger.
if s.Logger == nil {
s.Logger = log.StandardLogger()
}
// Start up metrics gathering
reg := prometheus.NewRegistry()
s.meter = metrics.New(reg)
@@ -189,53 +181,41 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
s.healthChecker = health.NewChecker(s.Logger, s.netbird)
if s.DebugEndpointEnabled {
debugAddr := debugEndpointAddr(s.DebugEndpointAddress)
debugHandler := debug.NewHandler(s.netbird, s.healthChecker, s.Logger)
if s.acme != nil {
debugHandler.SetCertStatus(s.acme)
}
s.debug = &http.Server{
Addr: debugAddr,
Handler: debugHandler,
ErrorLog: newHTTPServerLogger(s.Logger, logtagValueDebug),
}
go func() {
s.Logger.Infof("starting debug endpoint on %s", debugAddr)
if err := s.debug.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
s.Logger.Errorf("debug endpoint error: %v", err)
}
}()
s.startDebugEndpoint()
if err := s.startHealthServer(reg); err != nil {
return err
}
// Start health probe server.
healthAddr := s.HealthAddress
if healthAddr == "" {
healthAddr = "localhost:8080"
}
s.healthServer = health.NewServer(healthAddr, s.healthChecker, s.Logger, promhttp.HandlerFor(reg, promhttp.HandlerOpts{}))
healthListener, err := net.Listen("tcp", healthAddr)
if err != nil {
return fmt.Errorf("health probe server listen on %s: %w", healthAddr, err)
}
go func() {
if err := s.healthServer.Serve(healthListener); err != nil && !errors.Is(err, http.ErrServerClosed) {
s.Logger.Errorf("health probe server: %v", err)
}
}()
// Build the handler chain from inside out.
handler := http.Handler(s.proxy)
handler = s.auth.Protect(handler)
handler = web.AssetHandler(handler)
handler = accessLog.Middleware(handler)
handler = s.meter.Middleware(handler)
handler = s.hijackTracker.Middleware(handler)
// Start the reverse proxy HTTPS server.
s.https = &http.Server{
Addr: addr,
Handler: s.meter.Middleware(accessLog.Middleware(web.AssetHandler(s.auth.Protect(s.proxy)))),
Handler: handler,
TLSConfig: tlsConfig,
ErrorLog: newHTTPServerLogger(s.Logger, logtagValueHTTPS),
}
lc := net.ListenConfig{}
ln, err := lc.Listen(ctx, "tcp", addr)
if err != nil {
return fmt.Errorf("listen on %s: %w", addr, err)
}
if s.ProxyProtocol {
ln = s.wrapProxyProtocol(ln)
}
httpsErr := make(chan error, 1)
go func() {
s.Logger.Debugf("starting reverse proxy server on %s", addr)
httpsErr <- s.https.ListenAndServeTLS("", "")
httpsErr <- s.https.ServeTLS(ln, "", "")
}()
select {
@@ -251,7 +231,115 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
}
}
// initDefaults sets fallback values for optional Server fields.
func (s *Server) initDefaults() {
s.startTime = time.Now()
// If no ID is set then one can be generated.
if s.ID == "" {
s.ID = "netbird-proxy-" + s.startTime.Format("20060102150405")
}
// Fallback version option in case it is not set.
if s.Version == "" {
s.Version = "dev"
}
// If no logger is specified fallback to the standard logger.
if s.Logger == nil {
s.Logger = log.StandardLogger()
}
}
// startDebugEndpoint launches the debug HTTP server if enabled.
func (s *Server) startDebugEndpoint() {
if !s.DebugEndpointEnabled {
return
}
debugAddr := debugEndpointAddr(s.DebugEndpointAddress)
debugHandler := debug.NewHandler(s.netbird, s.healthChecker, s.Logger)
if s.acme != nil {
debugHandler.SetCertStatus(s.acme)
}
s.debug = &http.Server{
Addr: debugAddr,
Handler: debugHandler,
ErrorLog: newHTTPServerLogger(s.Logger, logtagValueDebug),
}
go func() {
s.Logger.Infof("starting debug endpoint on %s", debugAddr)
if err := s.debug.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
s.Logger.Errorf("debug endpoint error: %v", err)
}
}()
}
// startHealthServer launches the health probe and metrics server.
func (s *Server) startHealthServer(reg *prometheus.Registry) error {
healthAddr := s.HealthAddress
if healthAddr == "" {
healthAddr = defaultHealthAddr
}
s.healthServer = health.NewServer(healthAddr, s.healthChecker, s.Logger, promhttp.HandlerFor(reg, promhttp.HandlerOpts{}))
healthListener, err := net.Listen("tcp", healthAddr)
if err != nil {
return fmt.Errorf("health probe server listen on %s: %w", healthAddr, err)
}
go func() {
if err := s.healthServer.Serve(healthListener); err != nil && !errors.Is(err, http.ErrServerClosed) {
s.Logger.Errorf("health probe server: %v", err)
}
}()
return nil
}
// wrapProxyProtocol wraps a listener with PROXY protocol support.
// When TrustedProxies is configured, only those sources may send PROXY headers;
// connections from untrusted sources have any PROXY header ignored.
func (s *Server) wrapProxyProtocol(ln net.Listener) net.Listener {
ppListener := &proxyproto.Listener{
Listener: ln,
ReadHeaderTimeout: proxyProtoHeaderTimeout,
}
if len(s.TrustedProxies) > 0 {
ppListener.ConnPolicy = s.proxyProtocolPolicy
} else {
s.Logger.Warn("PROXY protocol enabled without trusted proxies; any source may send PROXY headers")
}
s.Logger.Info("PROXY protocol enabled on listener")
return ppListener
}
// proxyProtocolPolicy returns whether to require, skip, or reject the PROXY
// header based on whether the connection source is in TrustedProxies.
func (s *Server) proxyProtocolPolicy(opts proxyproto.ConnPolicyOptions) (proxyproto.Policy, error) {
// No logging on reject to prevent abuse
tcpAddr, ok := opts.Upstream.(*net.TCPAddr)
if !ok {
return proxyproto.REJECT, nil
}
addr, ok := netip.AddrFromSlice(tcpAddr.IP)
if !ok {
return proxyproto.REJECT, nil
}
addr = addr.Unmap()
// called per accept
for _, prefix := range s.TrustedProxies {
if prefix.Contains(addr) {
return proxyproto.REQUIRE, nil
}
}
return proxyproto.IGNORE, nil
}
const (
defaultHealthAddr = "localhost:8080"
defaultDebugAddr = "localhost:8444"
// proxyProtoHeaderTimeout is the deadline for reading the PROXY protocol
// header after accepting a connection.
proxyProtoHeaderTimeout = 5 * time.Second
// shutdownPreStopDelay is the time to wait after receiving a shutdown signal
// before draining connections. This allows the load balancer to propagate
// the endpoint removal.
@@ -379,7 +467,12 @@ func (s *Server) gracefulShutdown() {
s.Logger.Warnf("https server drain: %v", err)
}
// Step 4: Stop all remaining background services.
// Step 4: Close hijacked connections (WebSocket) that Shutdown does not handle.
if n := s.hijackTracker.CloseAll(); n > 0 {
s.Logger.Infof("closed %d hijacked connection(s)", n)
}
// Step 5: Stop all remaining background services.
s.shutdownServices()
s.Logger.Info("graceful shutdown complete")
}
@@ -647,7 +740,7 @@ func (s *Server) protoToMapping(mapping *proto.ProxyMapping) proxy.Mapping {
// If addr is empty, it defaults to localhost:8444 for security.
func debugEndpointAddr(addr string) string {
if addr == "" {
return "localhost:8444"
return defaultDebugAddr
}
return addr
}

View File

@@ -11,6 +11,7 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/protocol"
nbRelay "github.com/netbirdio/netbird/shared/relay"
)
const Proto protocol.Protocol = "quic"
@@ -27,7 +28,7 @@ type Listener struct {
func (l *Listener) Listen(acceptFn func(conn net.Conn)) error {
quicCfg := &quic.Config{
EnableDatagrams: true,
InitialPacketSize: 1452,
InitialPacketSize: nbRelay.QUICInitialPacketSize,
}
listener, err := quic.ListenAddr(l.Address, l.TLSConfig, quicCfg)
if err != nil {

View File

@@ -12,6 +12,7 @@ import (
log "github.com/sirupsen/logrus"
nbnet "github.com/netbirdio/netbird/client/net"
nbRelay "github.com/netbirdio/netbird/shared/relay"
quictls "github.com/netbirdio/netbird/shared/relay/tls"
)
@@ -42,7 +43,7 @@ func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) {
KeepAlivePeriod: 30 * time.Second,
MaxIdleTimeout: 4 * time.Minute,
EnableDatagrams: true,
InitialPacketSize: 1452,
InitialPacketSize: nbRelay.QUICInitialPacketSize,
}
udpConn, err := nbnet.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})

View File

@@ -0,0 +1,15 @@
package client
import (
"os"
"strconv"
)
const (
envKeyNBDebugDisableRelay = "NB_DEBUG_DISABLE_RELAY"
)
func IsDisableRelay() bool {
v, _ := strconv.ParseBool(os.Getenv(envKeyNBDebugDisableRelay))
return v
}

View File

@@ -3,4 +3,9 @@ package relay
const (
// WebSocketURLPath is the path for the websocket relay connection
WebSocketURLPath = "/relay"
// QUICInitialPacketSize is the conservative initial QUIC packet size (bytes)
// for unknown-path PMTU, per RFC 9000 §14: 1280 (IPv6 min MTU) 40 (IPv6
// header) 8 (UDP header) = 1232. DPLPMTUD may probe larger sizes later.
QUICInitialPacketSize = 1232
)