Compare commits

..

6 Commits

Author SHA1 Message Date
Maycon Santos
26ed91652d Debug issue 2025-10-17 09:30:54 +02:00
Zoltan Papp
af95aabb03 Handle the case when the service has already been down and the status recorder is not available (#4652) 2025-10-16 17:15:39 +02:00
Viktor Liu
3abae0bd17 [client] Set default wg port for new profiles (#4651) 2025-10-16 16:16:51 +02:00
Viktor Liu
8252ff41db [client] Add bind activity listener to bypass udp sockets (#4646) 2025-10-16 15:58:29 +02:00
Viktor Liu
277aa2b7cc [client] Fix missing flag values in profiles (#4650) 2025-10-16 15:13:41 +02:00
John Conley
bb37dc89ce [management] feat: Basic PocketID IDP integration (#4529) 2025-10-16 10:46:29 +02:00
36 changed files with 1950 additions and 1647 deletions

View File

@@ -18,7 +18,7 @@ jobs:
matrix:
os: [ubuntu-latest, macos-latest]
skip_ui_mode: [true, false]
install_binary: [true, false]
install_binary: [true, true]
runs-on: ${{ matrix.os }}
steps:
- name: Checkout code
@@ -35,3 +35,7 @@ jobs:
- name: check cli binary
run: command -v netbird
- name: out file
if: failure()
run: cat out.log

View File

@@ -23,4 +23,5 @@ type WGTunDevice interface {
FilteredDevice() *device.FilteredDevice
Device() *wgdevice.Device
GetNet() *netstack.Net
GetICEBind() device.EndpointManager
}

View File

@@ -150,6 +150,11 @@ func (t *WGTunDevice) GetNet() *netstack.Net {
return nil
}
// GetICEBind returns the ICEBind instance
func (t *WGTunDevice) GetICEBind() EndpointManager {
return t.iceBind
}
func routesToString(routes []string) string {
return strings.Join(routes, ";")
}

View File

@@ -154,3 +154,8 @@ func (t *TunDevice) assignAddr() error {
func (t *TunDevice) GetNet() *netstack.Net {
return nil
}
// GetICEBind returns the ICEBind instance
func (t *TunDevice) GetICEBind() EndpointManager {
return t.iceBind
}

View File

@@ -144,3 +144,8 @@ func (t *TunDevice) FilteredDevice() *FilteredDevice {
func (t *TunDevice) GetNet() *netstack.Net {
return nil
}
// GetICEBind returns the ICEBind instance
func (t *TunDevice) GetICEBind() EndpointManager {
return t.iceBind
}

View File

@@ -179,3 +179,8 @@ func (t *TunKernelDevice) assignAddr() error {
func (t *TunKernelDevice) GetNet() *netstack.Net {
return nil
}
// GetICEBind returns nil for kernel mode devices
func (t *TunKernelDevice) GetICEBind() EndpointManager {
return nil
}

View File

@@ -21,6 +21,7 @@ type Bind interface {
conn.Bind
GetICEMux() (*udpmux.UniversalUDPMuxDefault, error)
ActivityRecorder() *bind.ActivityRecorder
EndpointManager
}
type TunNetstackDevice struct {
@@ -155,3 +156,8 @@ func (t *TunNetstackDevice) Device() *device.Device {
func (t *TunNetstackDevice) GetNet() *netstack.Net {
return t.net
}
// GetICEBind returns the bind instance
func (t *TunNetstackDevice) GetICEBind() EndpointManager {
return t.bind
}

View File

@@ -146,3 +146,8 @@ func (t *USPDevice) assignAddr() error {
func (t *USPDevice) GetNet() *netstack.Net {
return nil
}
// GetICEBind returns the ICEBind instance
func (t *USPDevice) GetICEBind() EndpointManager {
return t.iceBind
}

View File

@@ -185,3 +185,8 @@ func (t *TunDevice) assignAddr() error {
func (t *TunDevice) GetNet() *netstack.Net {
return nil
}
// GetICEBind returns the ICEBind instance
func (t *TunDevice) GetICEBind() EndpointManager {
return t.iceBind
}

View File

@@ -0,0 +1,13 @@
package device
import (
"net"
"net/netip"
)
// EndpointManager manages fake IP to connection mappings for userspace bind implementations.
// Implemented by bind.ICEBind and bind.RelayBindJS.
type EndpointManager interface {
SetEndpoint(fakeIP netip.Addr, conn net.Conn)
RemoveEndpoint(fakeIP netip.Addr)
}

View File

@@ -21,4 +21,5 @@ type WGTunDevice interface {
FilteredDevice() *device.FilteredDevice
Device() *wgdevice.Device
GetNet() *netstack.Net
GetICEBind() device.EndpointManager
}

View File

@@ -80,6 +80,17 @@ func (w *WGIface) GetProxy() wgproxy.Proxy {
return w.wgProxyFactory.GetProxy()
}
// GetBind returns the EndpointManager userspace bind mode.
func (w *WGIface) GetBind() device.EndpointManager {
w.mu.Lock()
defer w.mu.Unlock()
if w.tun == nil {
return nil
}
return w.tun.GetICEBind()
}
// IsUserspaceBind indicates whether this interfaces is userspace with bind.ICEBind
func (w *WGIface) IsUserspaceBind() bool {
return w.userspaceBind

View File

@@ -14,6 +14,9 @@ type WGIface interface {
}
func (g *BundleGenerator) addWgShow() error {
if g.statusRecorder == nil {
return fmt.Errorf("no status recorder available for wg show")
}
result, err := g.statusRecorder.PeersStatus()
if err != nil {
return err

View File

@@ -0,0 +1,82 @@
package activity
import (
"context"
"io"
"net"
"time"
)
// lazyConn detects activity when WireGuard attempts to send packets.
// It does not deliver packets, only signals that activity occurred.
type lazyConn struct {
activityCh chan struct{}
ctx context.Context
cancel context.CancelFunc
}
// newLazyConn creates a new lazyConn for activity detection.
func newLazyConn() *lazyConn {
ctx, cancel := context.WithCancel(context.Background())
return &lazyConn{
activityCh: make(chan struct{}, 1),
ctx: ctx,
cancel: cancel,
}
}
// Read blocks until the connection is closed.
func (c *lazyConn) Read(_ []byte) (n int, err error) {
<-c.ctx.Done()
return 0, io.EOF
}
// Write signals activity detection when ICEBind routes packets to this endpoint.
func (c *lazyConn) Write(b []byte) (n int, err error) {
if c.ctx.Err() != nil {
return 0, io.EOF
}
select {
case c.activityCh <- struct{}{}:
default:
}
return len(b), nil
}
// ActivityChan returns the channel that signals when activity is detected.
func (c *lazyConn) ActivityChan() <-chan struct{} {
return c.activityCh
}
// Close closes the connection.
func (c *lazyConn) Close() error {
c.cancel()
return nil
}
// LocalAddr returns the local address.
func (c *lazyConn) LocalAddr() net.Addr {
return &net.UDPAddr{IP: net.IP{127, 0, 0, 1}, Port: lazyBindPort}
}
// RemoteAddr returns the remote address.
func (c *lazyConn) RemoteAddr() net.Addr {
return &net.UDPAddr{IP: net.IP{127, 0, 0, 1}, Port: lazyBindPort}
}
// SetDeadline sets the read and write deadlines.
func (c *lazyConn) SetDeadline(_ time.Time) error {
return nil
}
// SetReadDeadline sets the deadline for future Read calls.
func (c *lazyConn) SetReadDeadline(_ time.Time) error {
return nil
}
// SetWriteDeadline sets the deadline for future Write calls.
func (c *lazyConn) SetWriteDeadline(_ time.Time) error {
return nil
}

View File

@@ -0,0 +1,127 @@
package activity
import (
"fmt"
"net"
"net/netip"
"sync"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/internal/lazyconn"
)
type bindProvider interface {
GetBind() device.EndpointManager
}
const (
// lazyBindPort is an obscure port used for lazy peer endpoints to avoid confusion with real peers.
// The actual routing is done via fakeIP in ICEBind, not by this port.
lazyBindPort = 17473
)
// BindListener uses lazyConn with bind implementations for direct data passing in userspace bind mode.
type BindListener struct {
wgIface WgInterface
peerCfg lazyconn.PeerConfig
done sync.WaitGroup
lazyConn *lazyConn
bind device.EndpointManager
fakeIP netip.Addr
}
// NewBindListener creates a listener that passes data directly through bind using LazyConn.
// It automatically derives a unique fake IP from the peer's NetBird IP in the 127.2.x.x range.
func NewBindListener(wgIface WgInterface, bind device.EndpointManager, cfg lazyconn.PeerConfig) (*BindListener, error) {
fakeIP, err := deriveFakeIP(wgIface, cfg.AllowedIPs)
if err != nil {
return nil, fmt.Errorf("derive fake IP: %w", err)
}
d := &BindListener{
wgIface: wgIface,
peerCfg: cfg,
bind: bind,
fakeIP: fakeIP,
}
if err := d.setupLazyConn(); err != nil {
return nil, fmt.Errorf("setup lazy connection: %v", err)
}
d.done.Add(1)
return d, nil
}
// deriveFakeIP creates a deterministic fake IP for bind mode based on peer's NetBird IP.
// Maps peer IP 100.64.x.y to fake IP 127.2.x.y (similar to relay proxy using 127.1.x.y).
// It finds the peer's actual NetBird IP by checking which allowedIP is in the same subnet as our WG interface.
func deriveFakeIP(wgIface WgInterface, allowedIPs []netip.Prefix) (netip.Addr, error) {
if len(allowedIPs) == 0 {
return netip.Addr{}, fmt.Errorf("no allowed IPs for peer")
}
ourNetwork := wgIface.Address().Network
var peerIP netip.Addr
for _, allowedIP := range allowedIPs {
ip := allowedIP.Addr()
if !ip.Is4() {
continue
}
if ourNetwork.Contains(ip) {
peerIP = ip
break
}
}
if !peerIP.IsValid() {
return netip.Addr{}, fmt.Errorf("no peer NetBird IP found in allowed IPs")
}
octets := peerIP.As4()
fakeIP := netip.AddrFrom4([4]byte{127, 2, octets[2], octets[3]})
return fakeIP, nil
}
func (d *BindListener) setupLazyConn() error {
d.lazyConn = newLazyConn()
d.bind.SetEndpoint(d.fakeIP, d.lazyConn)
endpoint := &net.UDPAddr{
IP: d.fakeIP.AsSlice(),
Port: lazyBindPort,
}
return d.wgIface.UpdatePeer(d.peerCfg.PublicKey, d.peerCfg.AllowedIPs, 0, endpoint, nil)
}
// ReadPackets blocks until activity is detected on the LazyConn or the listener is closed.
func (d *BindListener) ReadPackets() {
select {
case <-d.lazyConn.ActivityChan():
d.peerCfg.Log.Infof("activity detected via LazyConn")
case <-d.lazyConn.ctx.Done():
d.peerCfg.Log.Infof("exit from activity listener")
}
d.peerCfg.Log.Debugf("removing lazy endpoint for peer %s", d.peerCfg.PublicKey)
if err := d.wgIface.RemovePeer(d.peerCfg.PublicKey); err != nil {
d.peerCfg.Log.Errorf("failed to remove endpoint: %s", err)
}
_ = d.lazyConn.Close()
d.bind.RemoveEndpoint(d.fakeIP)
d.done.Done()
}
// Close stops the listener and cleans up resources.
func (d *BindListener) Close() {
d.peerCfg.Log.Infof("closing activity listener (LazyConn)")
if err := d.lazyConn.Close(); err != nil {
d.peerCfg.Log.Errorf("failed to close LazyConn: %s", err)
}
d.done.Wait()
}

View File

@@ -0,0 +1,291 @@
package activity
import (
"net"
"net/netip"
"runtime"
"testing"
"time"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/lazyconn"
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
)
func isBindListenerPlatform() bool {
return runtime.GOOS == "windows" || runtime.GOOS == "js"
}
// mockEndpointManager implements device.EndpointManager for testing
type mockEndpointManager struct {
endpoints map[netip.Addr]net.Conn
}
func newMockEndpointManager() *mockEndpointManager {
return &mockEndpointManager{
endpoints: make(map[netip.Addr]net.Conn),
}
}
func (m *mockEndpointManager) SetEndpoint(fakeIP netip.Addr, conn net.Conn) {
m.endpoints[fakeIP] = conn
}
func (m *mockEndpointManager) RemoveEndpoint(fakeIP netip.Addr) {
delete(m.endpoints, fakeIP)
}
func (m *mockEndpointManager) GetEndpoint(fakeIP netip.Addr) net.Conn {
return m.endpoints[fakeIP]
}
// MockWGIfaceBind mocks WgInterface with bind support
type MockWGIfaceBind struct {
endpointMgr *mockEndpointManager
}
func (m *MockWGIfaceBind) RemovePeer(string) error {
return nil
}
func (m *MockWGIfaceBind) UpdatePeer(string, []netip.Prefix, time.Duration, *net.UDPAddr, *wgtypes.Key) error {
return nil
}
func (m *MockWGIfaceBind) IsUserspaceBind() bool {
return true
}
func (m *MockWGIfaceBind) Address() wgaddr.Address {
return wgaddr.Address{
IP: netip.MustParseAddr("100.64.0.1"),
Network: netip.MustParsePrefix("100.64.0.0/16"),
}
}
func (m *MockWGIfaceBind) GetBind() device.EndpointManager {
return m.endpointMgr
}
func TestBindListener_Creation(t *testing.T) {
mockEndpointMgr := newMockEndpointManager()
mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr}
peer := &MocPeer{PeerID: "testPeer1"}
cfg := lazyconn.PeerConfig{
PublicKey: peer.PeerID,
PeerConnID: peer.ConnID(),
AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")},
Log: log.WithField("peer", "testPeer1"),
}
listener, err := NewBindListener(mockIface, mockEndpointMgr, cfg)
require.NoError(t, err)
expectedFakeIP := netip.MustParseAddr("127.2.0.2")
conn := mockEndpointMgr.GetEndpoint(expectedFakeIP)
require.NotNil(t, conn, "Endpoint should be registered in mock endpoint manager")
_, ok := conn.(*lazyConn)
assert.True(t, ok, "Registered endpoint should be a lazyConn")
readPacketsDone := make(chan struct{})
go func() {
listener.ReadPackets()
close(readPacketsDone)
}()
listener.Close()
select {
case <-readPacketsDone:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for ReadPackets to exit after Close")
}
}
func TestBindListener_ActivityDetection(t *testing.T) {
mockEndpointMgr := newMockEndpointManager()
mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr}
peer := &MocPeer{PeerID: "testPeer1"}
cfg := lazyconn.PeerConfig{
PublicKey: peer.PeerID,
PeerConnID: peer.ConnID(),
AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")},
Log: log.WithField("peer", "testPeer1"),
}
listener, err := NewBindListener(mockIface, mockEndpointMgr, cfg)
require.NoError(t, err)
activityDetected := make(chan struct{})
go func() {
listener.ReadPackets()
close(activityDetected)
}()
fakeIP := listener.fakeIP
conn := mockEndpointMgr.GetEndpoint(fakeIP)
require.NotNil(t, conn, "Endpoint should be registered")
_, err = conn.Write([]byte{0x01, 0x02, 0x03})
require.NoError(t, err)
select {
case <-activityDetected:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for activity detection")
}
assert.Nil(t, mockEndpointMgr.GetEndpoint(fakeIP), "Endpoint should be removed after activity detection")
}
func TestBindListener_Close(t *testing.T) {
mockEndpointMgr := newMockEndpointManager()
mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr}
peer := &MocPeer{PeerID: "testPeer1"}
cfg := lazyconn.PeerConfig{
PublicKey: peer.PeerID,
PeerConnID: peer.ConnID(),
AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")},
Log: log.WithField("peer", "testPeer1"),
}
listener, err := NewBindListener(mockIface, mockEndpointMgr, cfg)
require.NoError(t, err)
readPacketsDone := make(chan struct{})
go func() {
listener.ReadPackets()
close(readPacketsDone)
}()
fakeIP := listener.fakeIP
listener.Close()
select {
case <-readPacketsDone:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for ReadPackets to exit after Close")
}
assert.Nil(t, mockEndpointMgr.GetEndpoint(fakeIP), "Endpoint should be removed after Close")
}
func TestManager_BindMode(t *testing.T) {
if !isBindListenerPlatform() {
t.Skip("BindListener only used on Windows/JS platforms")
}
mockEndpointMgr := newMockEndpointManager()
mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr}
peer := &MocPeer{PeerID: "testPeer1"}
mgr := NewManager(mockIface)
defer mgr.Close()
cfg := lazyconn.PeerConfig{
PublicKey: peer.PeerID,
PeerConnID: peer.ConnID(),
AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")},
Log: log.WithField("peer", "testPeer1"),
}
err := mgr.MonitorPeerActivity(cfg)
require.NoError(t, err)
listener, exists := mgr.GetPeerListener(cfg.PeerConnID)
require.True(t, exists, "Peer listener should be found")
bindListener, ok := listener.(*BindListener)
require.True(t, ok, "Listener should be BindListener, got %T", listener)
fakeIP := bindListener.fakeIP
conn := mockEndpointMgr.GetEndpoint(fakeIP)
require.NotNil(t, conn, "Endpoint should be registered")
_, err = conn.Write([]byte{0x01, 0x02, 0x03})
require.NoError(t, err)
select {
case peerConnID := <-mgr.OnActivityChan:
assert.Equal(t, cfg.PeerConnID, peerConnID, "Received peer connection ID should match")
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for activity notification")
}
assert.Nil(t, mockEndpointMgr.GetEndpoint(fakeIP), "Endpoint should be removed after activity")
}
func TestManager_BindMode_MultiplePeers(t *testing.T) {
if !isBindListenerPlatform() {
t.Skip("BindListener only used on Windows/JS platforms")
}
mockEndpointMgr := newMockEndpointManager()
mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr}
peer1 := &MocPeer{PeerID: "testPeer1"}
peer2 := &MocPeer{PeerID: "testPeer2"}
mgr := NewManager(mockIface)
defer mgr.Close()
cfg1 := lazyconn.PeerConfig{
PublicKey: peer1.PeerID,
PeerConnID: peer1.ConnID(),
AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")},
Log: log.WithField("peer", "testPeer1"),
}
cfg2 := lazyconn.PeerConfig{
PublicKey: peer2.PeerID,
PeerConnID: peer2.ConnID(),
AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.3/32")},
Log: log.WithField("peer", "testPeer2"),
}
err := mgr.MonitorPeerActivity(cfg1)
require.NoError(t, err)
err = mgr.MonitorPeerActivity(cfg2)
require.NoError(t, err)
listener1, exists := mgr.GetPeerListener(cfg1.PeerConnID)
require.True(t, exists, "Peer1 listener should be found")
bindListener1 := listener1.(*BindListener)
listener2, exists := mgr.GetPeerListener(cfg2.PeerConnID)
require.True(t, exists, "Peer2 listener should be found")
bindListener2 := listener2.(*BindListener)
conn1 := mockEndpointMgr.GetEndpoint(bindListener1.fakeIP)
require.NotNil(t, conn1, "Peer1 endpoint should be registered")
_, err = conn1.Write([]byte{0x01})
require.NoError(t, err)
conn2 := mockEndpointMgr.GetEndpoint(bindListener2.fakeIP)
require.NotNil(t, conn2, "Peer2 endpoint should be registered")
_, err = conn2.Write([]byte{0x02})
require.NoError(t, err)
receivedPeers := make(map[peerid.ConnID]bool)
for i := 0; i < 2; i++ {
select {
case peerConnID := <-mgr.OnActivityChan:
receivedPeers[peerConnID] = true
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for activity notifications")
}
}
assert.True(t, receivedPeers[cfg1.PeerConnID], "Peer1 activity should be received")
assert.True(t, receivedPeers[cfg2.PeerConnID], "Peer2 activity should be received")
}

View File

@@ -1,41 +0,0 @@
package activity
import (
"testing"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/lazyconn"
)
func TestNewListener(t *testing.T) {
peer := &MocPeer{
PeerID: "examplePublicKey1",
}
cfg := lazyconn.PeerConfig{
PublicKey: peer.PeerID,
PeerConnID: peer.ConnID(),
Log: log.WithField("peer", "examplePublicKey1"),
}
l, err := NewListener(MocWGIface{}, cfg)
if err != nil {
t.Fatalf("failed to create listener: %v", err)
}
chanClosed := make(chan struct{})
go func() {
defer close(chanClosed)
l.ReadPackets()
}()
time.Sleep(1 * time.Second)
l.Close()
select {
case <-chanClosed:
case <-time.After(time.Second):
}
}

View File

@@ -11,26 +11,27 @@ import (
"github.com/netbirdio/netbird/client/internal/lazyconn"
)
// Listener it is not a thread safe implementation, do not call Close before ReadPackets. It will cause blocking
type Listener struct {
// UDPListener uses UDP sockets for activity detection in kernel mode.
type UDPListener struct {
wgIface WgInterface
peerCfg lazyconn.PeerConfig
conn *net.UDPConn
endpoint *net.UDPAddr
done sync.Mutex
isClosed atomic.Bool // use to avoid error log when closing the listener
isClosed atomic.Bool
}
func NewListener(wgIface WgInterface, cfg lazyconn.PeerConfig) (*Listener, error) {
d := &Listener{
// NewUDPListener creates a listener that detects activity via UDP socket reads.
func NewUDPListener(wgIface WgInterface, cfg lazyconn.PeerConfig) (*UDPListener, error) {
d := &UDPListener{
wgIface: wgIface,
peerCfg: cfg,
}
conn, err := d.newConn()
if err != nil {
return nil, fmt.Errorf("failed to creating activity listener: %v", err)
return nil, fmt.Errorf("create UDP connection: %v", err)
}
d.conn = conn
d.endpoint = conn.LocalAddr().(*net.UDPAddr)
@@ -38,12 +39,14 @@ func NewListener(wgIface WgInterface, cfg lazyconn.PeerConfig) (*Listener, error
if err := d.createEndpoint(); err != nil {
return nil, err
}
d.done.Lock()
cfg.Log.Infof("created activity listener: %s", conn.LocalAddr().(*net.UDPAddr).String())
cfg.Log.Infof("created activity listener: %s", d.conn.LocalAddr().(*net.UDPAddr).String())
return d, nil
}
func (d *Listener) ReadPackets() {
// ReadPackets blocks reading from the UDP socket until activity is detected or the listener is closed.
func (d *UDPListener) ReadPackets() {
for {
n, remoteAddr, err := d.conn.ReadFromUDP(make([]byte, 1))
if err != nil {
@@ -64,15 +67,17 @@ func (d *Listener) ReadPackets() {
}
d.peerCfg.Log.Debugf("removing lazy endpoint: %s", d.endpoint.String())
if err := d.removeEndpoint(); err != nil {
if err := d.wgIface.RemovePeer(d.peerCfg.PublicKey); err != nil {
d.peerCfg.Log.Errorf("failed to remove endpoint: %s", err)
}
_ = d.conn.Close() // do not care err because some cases it will return "use of closed network connection"
// Ignore close error as it may return "use of closed network connection" if already closed.
_ = d.conn.Close()
d.done.Unlock()
}
func (d *Listener) Close() {
// Close stops the listener and cleans up resources.
func (d *UDPListener) Close() {
d.peerCfg.Log.Infof("closing activity listener: %s", d.conn.LocalAddr().String())
d.isClosed.Store(true)
@@ -82,16 +87,12 @@ func (d *Listener) Close() {
d.done.Lock()
}
func (d *Listener) removeEndpoint() error {
return d.wgIface.RemovePeer(d.peerCfg.PublicKey)
}
func (d *Listener) createEndpoint() error {
func (d *UDPListener) createEndpoint() error {
d.peerCfg.Log.Debugf("creating lazy endpoint: %s", d.endpoint.String())
return d.wgIface.UpdatePeer(d.peerCfg.PublicKey, d.peerCfg.AllowedIPs, 0, d.endpoint, nil)
}
func (d *Listener) newConn() (*net.UDPConn, error) {
func (d *UDPListener) newConn() (*net.UDPConn, error) {
addr := &net.UDPAddr{
Port: 0,
IP: listenIP,

View File

@@ -0,0 +1,110 @@
package activity
import (
"net"
"net/netip"
"testing"
"time"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal/lazyconn"
)
func TestUDPListener_Creation(t *testing.T) {
mockIface := &MocWGIface{}
peer := &MocPeer{PeerID: "testPeer1"}
cfg := lazyconn.PeerConfig{
PublicKey: peer.PeerID,
PeerConnID: peer.ConnID(),
AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")},
Log: log.WithField("peer", "testPeer1"),
}
listener, err := NewUDPListener(mockIface, cfg)
require.NoError(t, err)
require.NotNil(t, listener.conn)
require.NotNil(t, listener.endpoint)
readPacketsDone := make(chan struct{})
go func() {
listener.ReadPackets()
close(readPacketsDone)
}()
listener.Close()
select {
case <-readPacketsDone:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for ReadPackets to exit after Close")
}
}
func TestUDPListener_ActivityDetection(t *testing.T) {
mockIface := &MocWGIface{}
peer := &MocPeer{PeerID: "testPeer1"}
cfg := lazyconn.PeerConfig{
PublicKey: peer.PeerID,
PeerConnID: peer.ConnID(),
AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")},
Log: log.WithField("peer", "testPeer1"),
}
listener, err := NewUDPListener(mockIface, cfg)
require.NoError(t, err)
activityDetected := make(chan struct{})
go func() {
listener.ReadPackets()
close(activityDetected)
}()
conn, err := net.Dial("udp", listener.conn.LocalAddr().String())
require.NoError(t, err)
defer conn.Close()
_, err = conn.Write([]byte{0x01, 0x02, 0x03})
require.NoError(t, err)
select {
case <-activityDetected:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for activity detection")
}
}
func TestUDPListener_Close(t *testing.T) {
mockIface := &MocWGIface{}
peer := &MocPeer{PeerID: "testPeer1"}
cfg := lazyconn.PeerConfig{
PublicKey: peer.PeerID,
PeerConnID: peer.ConnID(),
AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")},
Log: log.WithField("peer", "testPeer1"),
}
listener, err := NewUDPListener(mockIface, cfg)
require.NoError(t, err)
readPacketsDone := make(chan struct{})
go func() {
listener.ReadPackets()
close(readPacketsDone)
}()
listener.Close()
select {
case <-readPacketsDone:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for ReadPackets to exit after Close")
}
assert.True(t, listener.isClosed.Load(), "Listener should be marked as closed")
}

View File

@@ -1,21 +1,32 @@
package activity
import (
"errors"
"net"
"net/netip"
"runtime"
"sync"
"time"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/lazyconn"
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
)
// listener defines the contract for activity detection listeners.
type listener interface {
ReadPackets()
Close()
}
type WgInterface interface {
RemovePeer(peerKey string) error
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
IsUserspaceBind() bool
Address() wgaddr.Address
}
type Manager struct {
@@ -23,7 +34,7 @@ type Manager struct {
wgIface WgInterface
peers map[peerid.ConnID]*Listener
peers map[peerid.ConnID]listener
done chan struct{}
mu sync.Mutex
@@ -33,7 +44,7 @@ func NewManager(wgIface WgInterface) *Manager {
m := &Manager{
OnActivityChan: make(chan peerid.ConnID, 1),
wgIface: wgIface,
peers: make(map[peerid.ConnID]*Listener),
peers: make(map[peerid.ConnID]listener),
done: make(chan struct{}),
}
return m
@@ -48,16 +59,38 @@ func (m *Manager) MonitorPeerActivity(peerCfg lazyconn.PeerConfig) error {
return nil
}
listener, err := NewListener(m.wgIface, peerCfg)
listener, err := m.createListener(peerCfg)
if err != nil {
return err
}
m.peers[peerCfg.PeerConnID] = listener
m.peers[peerCfg.PeerConnID] = listener
go m.waitForTraffic(listener, peerCfg.PeerConnID)
return nil
}
func (m *Manager) createListener(peerCfg lazyconn.PeerConfig) (listener, error) {
if !m.wgIface.IsUserspaceBind() {
return NewUDPListener(m.wgIface, peerCfg)
}
// BindListener is only used on Windows and JS platforms:
// - JS: Cannot listen to UDP sockets
// - Windows: IP_UNICAST_IF socket option forces packets out the interface the default
// gateway points to, preventing them from reaching the loopback interface.
// BindListener bypasses this by passing data directly through the bind.
if runtime.GOOS != "windows" && runtime.GOOS != "js" {
return NewUDPListener(m.wgIface, peerCfg)
}
provider, ok := m.wgIface.(bindProvider)
if !ok {
return nil, errors.New("interface claims userspace bind but doesn't implement bindProvider")
}
return NewBindListener(m.wgIface, provider.GetBind(), peerCfg)
}
func (m *Manager) RemovePeer(log *log.Entry, peerConnID peerid.ConnID) {
m.mu.Lock()
defer m.mu.Unlock()
@@ -82,8 +115,8 @@ func (m *Manager) Close() {
}
}
func (m *Manager) waitForTraffic(listener *Listener, peerConnID peerid.ConnID) {
listener.ReadPackets()
func (m *Manager) waitForTraffic(l listener, peerConnID peerid.ConnID) {
l.ReadPackets()
m.mu.Lock()
if _, ok := m.peers[peerConnID]; !ok {

View File

@@ -9,6 +9,7 @@ import (
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/lazyconn"
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
)
@@ -30,16 +31,26 @@ func (m MocWGIface) RemovePeer(string) error {
func (m MocWGIface) UpdatePeer(string, []netip.Prefix, time.Duration, *net.UDPAddr, *wgtypes.Key) error {
return nil
}
// Add this method to the Manager struct
func (m *Manager) GetPeerListener(peerConnID peerid.ConnID) (*Listener, bool) {
func (m MocWGIface) IsUserspaceBind() bool {
return false
}
func (m MocWGIface) Address() wgaddr.Address {
return wgaddr.Address{
IP: netip.MustParseAddr("100.64.0.1"),
Network: netip.MustParsePrefix("100.64.0.0/16"),
}
}
// GetPeerListener is a test helper to access listeners
func (m *Manager) GetPeerListener(peerConnID peerid.ConnID) (listener, bool) {
m.mu.Lock()
defer m.mu.Unlock()
listener, exists := m.peers[peerConnID]
return listener, exists
l, exists := m.peers[peerConnID]
return l, exists
}
func TestManager_MonitorPeerActivity(t *testing.T) {
@@ -65,7 +76,12 @@ func TestManager_MonitorPeerActivity(t *testing.T) {
t.Fatalf("peer listener not found")
}
if err := trigger(listener.conn.LocalAddr().String()); err != nil {
// Get the UDP listener's address for triggering
udpListener, ok := listener.(*UDPListener)
if !ok {
t.Fatalf("expected UDPListener")
}
if err := trigger(udpListener.conn.LocalAddr().String()); err != nil {
t.Fatalf("failed to trigger activity: %v", err)
}
@@ -97,7 +113,9 @@ func TestManager_RemovePeerActivity(t *testing.T) {
t.Fatalf("failed to monitor peer activity: %v", err)
}
addr := mgr.peers[peerCfg1.PeerConnID].conn.LocalAddr().String()
listener, _ := mgr.GetPeerListener(peerCfg1.PeerConnID)
udpListener, _ := listener.(*UDPListener)
addr := udpListener.conn.LocalAddr().String()
mgr.RemovePeer(peerCfg1.Log, peerCfg1.PeerConnID)
@@ -147,7 +165,8 @@ func TestManager_MultiPeerActivity(t *testing.T) {
t.Fatalf("peer listener for peer1 not found")
}
if err := trigger(listener.conn.LocalAddr().String()); err != nil {
udpListener1, _ := listener.(*UDPListener)
if err := trigger(udpListener1.conn.LocalAddr().String()); err != nil {
t.Fatalf("failed to trigger activity: %v", err)
}
@@ -156,7 +175,8 @@ func TestManager_MultiPeerActivity(t *testing.T) {
t.Fatalf("peer listener for peer2 not found")
}
if err := trigger(listener.conn.LocalAddr().String()); err != nil {
udpListener2, _ := listener.(*UDPListener)
if err := trigger(udpListener2.conn.LocalAddr().String()); err != nil {
t.Fatalf("failed to trigger activity: %v", err)
}

View File

@@ -7,6 +7,7 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/monotime"
)
@@ -14,5 +15,6 @@ type WGIface interface {
RemovePeer(peerKey string) error
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
IsUserspaceBind() bool
Address() wgaddr.Address
LastActivities() map[string]monotime.Time
}

View File

@@ -195,6 +195,7 @@ func createNewConfig(input ConfigInput) (*Config, error) {
config := &Config{
// defaults to false only for new (post 0.26) configurations
ServerSSHAllowed: util.False(),
WgPort: iface.DefaultWgPort,
}
if _, err := config.apply(input); err != nil {

View File

@@ -5,11 +5,14 @@ import (
"errors"
"os"
"path/filepath"
"runtime"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
"github.com/netbirdio/netbird/util"
)
@@ -141,6 +144,95 @@ func TestHiddenPreSharedKey(t *testing.T) {
}
}
func TestNewProfileDefaults(t *testing.T) {
tempDir := t.TempDir()
configPath := filepath.Join(tempDir, "config.json")
config, err := UpdateOrCreateConfig(ConfigInput{
ConfigPath: configPath,
})
require.NoError(t, err, "should create new config")
assert.Equal(t, DefaultManagementURL, config.ManagementURL.String(), "ManagementURL should have default")
assert.Equal(t, DefaultAdminURL, config.AdminURL.String(), "AdminURL should have default")
assert.NotEmpty(t, config.PrivateKey, "PrivateKey should be generated")
assert.NotEmpty(t, config.SSHKey, "SSHKey should be generated")
assert.Equal(t, iface.WgInterfaceDefault, config.WgIface, "WgIface should have default")
assert.Equal(t, iface.DefaultWgPort, config.WgPort, "WgPort should default to 51820")
assert.Equal(t, uint16(iface.DefaultMTU), config.MTU, "MTU should have default")
assert.Equal(t, dynamic.DefaultInterval, config.DNSRouteInterval, "DNSRouteInterval should have default")
assert.NotNil(t, config.ServerSSHAllowed, "ServerSSHAllowed should be set")
assert.NotNil(t, config.DisableNotifications, "DisableNotifications should be set")
assert.NotEmpty(t, config.IFaceBlackList, "IFaceBlackList should have defaults")
if runtime.GOOS == "windows" || runtime.GOOS == "darwin" {
assert.NotNil(t, config.NetworkMonitor, "NetworkMonitor should be set on Windows/macOS")
assert.True(t, *config.NetworkMonitor, "NetworkMonitor should be enabled by default on Windows/macOS")
}
}
func TestWireguardPortZeroExplicit(t *testing.T) {
tempDir := t.TempDir()
configPath := filepath.Join(tempDir, "config.json")
// Create a new profile with explicit port 0 (random port)
explicitZero := 0
config, err := UpdateOrCreateConfig(ConfigInput{
ConfigPath: configPath,
WireguardPort: &explicitZero,
})
require.NoError(t, err, "should create config with explicit port 0")
assert.Equal(t, 0, config.WgPort, "WgPort should be 0 when explicitly set by user")
// Verify it persists
readConfig, err := GetConfig(configPath)
require.NoError(t, err)
assert.Equal(t, 0, readConfig.WgPort, "WgPort should remain 0 after reading from file")
}
func TestWireguardPortDefaultVsExplicit(t *testing.T) {
tests := []struct {
name string
wireguardPort *int
expectedPort int
description string
}{
{
name: "no port specified uses default",
wireguardPort: nil,
expectedPort: iface.DefaultWgPort,
description: "When user doesn't specify port, default to 51820",
},
{
name: "explicit zero for random port",
wireguardPort: func() *int { v := 0; return &v }(),
expectedPort: 0,
description: "When user explicitly sets 0, use 0 for random port",
},
{
name: "explicit custom port",
wireguardPort: func() *int { v := 52000; return &v }(),
expectedPort: 52000,
description: "When user sets custom port, use that port",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tempDir := t.TempDir()
configPath := filepath.Join(tempDir, "config.json")
config, err := UpdateOrCreateConfig(ConfigInput{
ConfigPath: configPath,
WireguardPort: tt.wireguardPort,
})
require.NoError(t, err, tt.description)
assert.Equal(t, tt.expectedPort, config.WgPort, tt.description)
})
}
}
func TestUpdateOldManagementURL(t *testing.T) {
tests := []struct {
name string

View File

@@ -353,6 +353,13 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
config.CustomDNSAddress = []byte{}
}
config.ExtraIFaceBlackList = msg.ExtraIFaceBlacklist
if msg.DnsRouteInterval != nil {
interval := msg.DnsRouteInterval.AsDuration()
config.DNSRouteInterval = &interval
}
config.RosenpassEnabled = msg.RosenpassEnabled
config.RosenpassPermissive = msg.RosenpassPermissive
config.DisableAutoConnect = msg.DisableAutoConnect

View File

@@ -0,0 +1,298 @@
package server
import (
"context"
"os/user"
"path/filepath"
"reflect"
"testing"
"time"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/types/known/durationpb"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/proto"
)
// TestSetConfig_AllFieldsSaved ensures that all fields in SetConfigRequest are properly saved to the config.
// This test uses reflection to detect when new fields are added but not handled in SetConfig.
func TestSetConfig_AllFieldsSaved(t *testing.T) {
tempDir := t.TempDir()
origDefaultProfileDir := profilemanager.DefaultConfigPathDir
origDefaultConfigPath := profilemanager.DefaultConfigPath
origActiveProfileStatePath := profilemanager.ActiveProfileStatePath
profilemanager.ConfigDirOverride = tempDir
profilemanager.DefaultConfigPathDir = tempDir
profilemanager.ActiveProfileStatePath = tempDir + "/active_profile.json"
profilemanager.DefaultConfigPath = filepath.Join(tempDir, "default.json")
t.Cleanup(func() {
profilemanager.DefaultConfigPathDir = origDefaultProfileDir
profilemanager.ActiveProfileStatePath = origActiveProfileStatePath
profilemanager.DefaultConfigPath = origDefaultConfigPath
profilemanager.ConfigDirOverride = ""
})
currUser, err := user.Current()
require.NoError(t, err)
profName := "test-profile"
ic := profilemanager.ConfigInput{
ConfigPath: filepath.Join(tempDir, profName+".json"),
ManagementURL: "https://api.netbird.io:443",
}
_, err = profilemanager.UpdateOrCreateConfig(ic)
require.NoError(t, err)
pm := profilemanager.ServiceManager{}
err = pm.SetActiveProfileState(&profilemanager.ActiveProfileState{
Name: profName,
Username: currUser.Username,
})
require.NoError(t, err)
ctx := context.Background()
s := New(ctx, "console", "", false, false)
rosenpassEnabled := true
rosenpassPermissive := true
serverSSHAllowed := true
interfaceName := "utun100"
wireguardPort := int64(51820)
preSharedKey := "test-psk"
disableAutoConnect := true
networkMonitor := true
disableClientRoutes := true
disableServerRoutes := true
disableDNS := true
disableFirewall := true
blockLANAccess := true
disableNotifications := true
lazyConnectionEnabled := true
blockInbound := true
mtu := int64(1280)
req := &proto.SetConfigRequest{
ProfileName: profName,
Username: currUser.Username,
ManagementUrl: "https://new-api.netbird.io:443",
AdminURL: "https://new-admin.netbird.io",
RosenpassEnabled: &rosenpassEnabled,
RosenpassPermissive: &rosenpassPermissive,
ServerSSHAllowed: &serverSSHAllowed,
InterfaceName: &interfaceName,
WireguardPort: &wireguardPort,
OptionalPreSharedKey: &preSharedKey,
DisableAutoConnect: &disableAutoConnect,
NetworkMonitor: &networkMonitor,
DisableClientRoutes: &disableClientRoutes,
DisableServerRoutes: &disableServerRoutes,
DisableDns: &disableDNS,
DisableFirewall: &disableFirewall,
BlockLanAccess: &blockLANAccess,
DisableNotifications: &disableNotifications,
LazyConnectionEnabled: &lazyConnectionEnabled,
BlockInbound: &blockInbound,
NatExternalIPs: []string{"1.2.3.4", "5.6.7.8"},
CleanNATExternalIPs: false,
CustomDNSAddress: []byte("1.1.1.1:53"),
ExtraIFaceBlacklist: []string{"eth1", "eth2"},
DnsLabels: []string{"label1", "label2"},
CleanDNSLabels: false,
DnsRouteInterval: durationpb.New(2 * time.Minute),
Mtu: &mtu,
}
_, err = s.SetConfig(ctx, req)
require.NoError(t, err)
profState := profilemanager.ActiveProfileState{
Name: profName,
Username: currUser.Username,
}
cfgPath, err := profState.FilePath()
require.NoError(t, err)
cfg, err := profilemanager.GetConfig(cfgPath)
require.NoError(t, err)
require.Equal(t, "https://new-api.netbird.io:443", cfg.ManagementURL.String())
require.Equal(t, "https://new-admin.netbird.io:443", cfg.AdminURL.String())
require.Equal(t, rosenpassEnabled, cfg.RosenpassEnabled)
require.Equal(t, rosenpassPermissive, cfg.RosenpassPermissive)
require.NotNil(t, cfg.ServerSSHAllowed)
require.Equal(t, serverSSHAllowed, *cfg.ServerSSHAllowed)
require.Equal(t, interfaceName, cfg.WgIface)
require.Equal(t, int(wireguardPort), cfg.WgPort)
require.Equal(t, preSharedKey, cfg.PreSharedKey)
require.Equal(t, disableAutoConnect, cfg.DisableAutoConnect)
require.NotNil(t, cfg.NetworkMonitor)
require.Equal(t, networkMonitor, *cfg.NetworkMonitor)
require.Equal(t, disableClientRoutes, cfg.DisableClientRoutes)
require.Equal(t, disableServerRoutes, cfg.DisableServerRoutes)
require.Equal(t, disableDNS, cfg.DisableDNS)
require.Equal(t, disableFirewall, cfg.DisableFirewall)
require.Equal(t, blockLANAccess, cfg.BlockLANAccess)
require.NotNil(t, cfg.DisableNotifications)
require.Equal(t, disableNotifications, *cfg.DisableNotifications)
require.Equal(t, lazyConnectionEnabled, cfg.LazyConnectionEnabled)
require.Equal(t, blockInbound, cfg.BlockInbound)
require.Equal(t, []string{"1.2.3.4", "5.6.7.8"}, cfg.NATExternalIPs)
require.Equal(t, "1.1.1.1:53", cfg.CustomDNSAddress)
// IFaceBlackList contains defaults + extras
require.Contains(t, cfg.IFaceBlackList, "eth1")
require.Contains(t, cfg.IFaceBlackList, "eth2")
require.Equal(t, []string{"label1", "label2"}, cfg.DNSLabels.ToPunycodeList())
require.Equal(t, 2*time.Minute, cfg.DNSRouteInterval)
require.Equal(t, uint16(mtu), cfg.MTU)
verifyAllFieldsCovered(t, req)
}
// verifyAllFieldsCovered uses reflection to ensure we're testing all fields in SetConfigRequest.
// If a new field is added to SetConfigRequest, this function will fail the test,
// forcing the developer to update both the SetConfig handler and this test.
func verifyAllFieldsCovered(t *testing.T, req *proto.SetConfigRequest) {
t.Helper()
metadataFields := map[string]bool{
"state": true, // protobuf internal
"sizeCache": true, // protobuf internal
"unknownFields": true, // protobuf internal
"Username": true, // metadata
"ProfileName": true, // metadata
"CleanNATExternalIPs": true, // control flag for clearing
"CleanDNSLabels": true, // control flag for clearing
}
expectedFields := map[string]bool{
"ManagementUrl": true,
"AdminURL": true,
"RosenpassEnabled": true,
"RosenpassPermissive": true,
"ServerSSHAllowed": true,
"InterfaceName": true,
"WireguardPort": true,
"OptionalPreSharedKey": true,
"DisableAutoConnect": true,
"NetworkMonitor": true,
"DisableClientRoutes": true,
"DisableServerRoutes": true,
"DisableDns": true,
"DisableFirewall": true,
"BlockLanAccess": true,
"DisableNotifications": true,
"LazyConnectionEnabled": true,
"BlockInbound": true,
"NatExternalIPs": true,
"CustomDNSAddress": true,
"ExtraIFaceBlacklist": true,
"DnsLabels": true,
"DnsRouteInterval": true,
"Mtu": true,
}
val := reflect.ValueOf(req).Elem()
typ := val.Type()
var unexpectedFields []string
for i := 0; i < val.NumField(); i++ {
field := typ.Field(i)
fieldName := field.Name
if metadataFields[fieldName] {
continue
}
if !expectedFields[fieldName] {
unexpectedFields = append(unexpectedFields, fieldName)
}
}
if len(unexpectedFields) > 0 {
t.Fatalf("New field(s) detected in SetConfigRequest: %v", unexpectedFields)
}
}
// TestCLIFlags_MappedToSetConfig ensures all CLI flags that modify config are properly mapped to SetConfigRequest.
// This test catches bugs where a new CLI flag is added but not wired to the SetConfigRequest in setupSetConfigReq.
func TestCLIFlags_MappedToSetConfig(t *testing.T) {
// Map of CLI flag names to their corresponding SetConfigRequest field names.
// This map must be updated when adding new config-related CLI flags.
flagToField := map[string]string{
"management-url": "ManagementUrl",
"admin-url": "AdminURL",
"enable-rosenpass": "RosenpassEnabled",
"rosenpass-permissive": "RosenpassPermissive",
"allow-server-ssh": "ServerSSHAllowed",
"interface-name": "InterfaceName",
"wireguard-port": "WireguardPort",
"preshared-key": "OptionalPreSharedKey",
"disable-auto-connect": "DisableAutoConnect",
"network-monitor": "NetworkMonitor",
"disable-client-routes": "DisableClientRoutes",
"disable-server-routes": "DisableServerRoutes",
"disable-dns": "DisableDns",
"disable-firewall": "DisableFirewall",
"block-lan-access": "BlockLanAccess",
"block-inbound": "BlockInbound",
"enable-lazy-connection": "LazyConnectionEnabled",
"external-ip-map": "NatExternalIPs",
"dns-resolver-address": "CustomDNSAddress",
"extra-iface-blacklist": "ExtraIFaceBlacklist",
"extra-dns-labels": "DnsLabels",
"dns-router-interval": "DnsRouteInterval",
"mtu": "Mtu",
}
// SetConfigRequest fields that don't have CLI flags (settable only via UI or other means).
fieldsWithoutCLIFlags := map[string]bool{
"DisableNotifications": true, // Only settable via UI
}
// Get all SetConfigRequest fields to verify our map is complete.
req := &proto.SetConfigRequest{}
val := reflect.ValueOf(req).Elem()
typ := val.Type()
var unmappedFields []string
for i := 0; i < val.NumField(); i++ {
field := typ.Field(i)
fieldName := field.Name
// Skip protobuf internal fields and metadata fields.
if fieldName == "state" || fieldName == "sizeCache" || fieldName == "unknownFields" {
continue
}
if fieldName == "Username" || fieldName == "ProfileName" {
continue
}
if fieldName == "CleanNATExternalIPs" || fieldName == "CleanDNSLabels" {
continue
}
// Check if this field is either mapped to a CLI flag or explicitly documented as having no CLI flag.
mappedToCLI := false
for _, mappedField := range flagToField {
if mappedField == fieldName {
mappedToCLI = true
break
}
}
hasNoCLIFlag := fieldsWithoutCLIFlags[fieldName]
if !mappedToCLI && !hasNoCLIFlag {
unmappedFields = append(unmappedFields, fieldName)
}
}
if len(unmappedFields) > 0 {
t.Fatalf("SetConfigRequest field(s) not documented: %v\n"+
"Either add the CLI flag to flagToField map, or if there's no CLI flag for this field, "+
"add it to fieldsWithoutCLIFlags map with a comment explaining why.", unmappedFields)
}
t.Log("All SetConfigRequest fields are properly documented")
}

View File

@@ -26,9 +26,11 @@ type mockHTTPClient struct {
}
func (c *mockHTTPClient) Do(req *http.Request) (*http.Response, error) {
body, err := io.ReadAll(req.Body)
if err == nil {
c.reqBody = string(body)
if req.Body != nil {
body, err := io.ReadAll(req.Body)
if err == nil {
c.reqBody = string(body)
}
}
return &http.Response{
StatusCode: c.code,

View File

@@ -201,6 +201,12 @@ func NewManager(ctx context.Context, config Config, appMetrics telemetry.AppMetr
APIToken: config.ExtraConfig["ApiToken"],
}
return NewJumpCloudManager(jumpcloudConfig, appMetrics)
case "pocketid":
pocketidConfig := PocketIdClientConfig{
APIToken: config.ExtraConfig["ApiToken"],
ManagementEndpoint: config.ExtraConfig["ManagementEndpoint"],
}
return NewPocketIdManager(pocketidConfig, appMetrics)
default:
return nil, fmt.Errorf("invalid manager type: %s", config.ManagerType)
}

View File

@@ -0,0 +1,384 @@
package idp
import (
"context"
"fmt"
"io"
"net/http"
"net/url"
"slices"
"strings"
"time"
"github.com/netbirdio/netbird/management/server/telemetry"
)
type PocketIdManager struct {
managementEndpoint string
apiToken string
httpClient ManagerHTTPClient
credentials ManagerCredentials
helper ManagerHelper
appMetrics telemetry.AppMetrics
}
type pocketIdCustomClaimDto struct {
Key string `json:"key"`
Value string `json:"value"`
}
type pocketIdUserDto struct {
CustomClaims []pocketIdCustomClaimDto `json:"customClaims"`
Disabled bool `json:"disabled"`
DisplayName string `json:"displayName"`
Email string `json:"email"`
FirstName string `json:"firstName"`
ID string `json:"id"`
IsAdmin bool `json:"isAdmin"`
LastName string `json:"lastName"`
LdapID string `json:"ldapId"`
Locale string `json:"locale"`
UserGroups []pocketIdUserGroupDto `json:"userGroups"`
Username string `json:"username"`
}
type pocketIdUserCreateDto struct {
Disabled bool `json:"disabled,omitempty"`
DisplayName string `json:"displayName"`
Email string `json:"email"`
FirstName string `json:"firstName"`
IsAdmin bool `json:"isAdmin,omitempty"`
LastName string `json:"lastName,omitempty"`
Locale string `json:"locale,omitempty"`
Username string `json:"username"`
}
type pocketIdPaginatedUserDto struct {
Data []pocketIdUserDto `json:"data"`
Pagination pocketIdPaginationDto `json:"pagination"`
}
type pocketIdPaginationDto struct {
CurrentPage int `json:"currentPage"`
ItemsPerPage int `json:"itemsPerPage"`
TotalItems int `json:"totalItems"`
TotalPages int `json:"totalPages"`
}
func (p *pocketIdUserDto) userData() *UserData {
return &UserData{
Email: p.Email,
Name: p.DisplayName,
ID: p.ID,
AppMetadata: AppMetadata{},
}
}
type pocketIdUserGroupDto struct {
CreatedAt string `json:"createdAt"`
CustomClaims []pocketIdCustomClaimDto `json:"customClaims"`
FriendlyName string `json:"friendlyName"`
ID string `json:"id"`
LdapID string `json:"ldapId"`
Name string `json:"name"`
}
func NewPocketIdManager(config PocketIdClientConfig, appMetrics telemetry.AppMetrics) (*PocketIdManager, error) {
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
httpTransport.MaxIdleConns = 5
httpClient := &http.Client{
Timeout: 10 * time.Second,
Transport: httpTransport,
}
helper := JsonParser{}
if config.ManagementEndpoint == "" {
return nil, fmt.Errorf("pocketId IdP configuration is incomplete, ManagementEndpoint is missing")
}
if config.APIToken == "" {
return nil, fmt.Errorf("pocketId IdP configuration is incomplete, APIToken is missing")
}
credentials := &PocketIdCredentials{
clientConfig: config,
httpClient: httpClient,
helper: helper,
appMetrics: appMetrics,
}
return &PocketIdManager{
managementEndpoint: config.ManagementEndpoint,
apiToken: config.APIToken,
httpClient: httpClient,
credentials: credentials,
helper: helper,
appMetrics: appMetrics,
}, nil
}
func (p *PocketIdManager) request(ctx context.Context, method, resource string, query *url.Values, body string) ([]byte, error) {
var MethodsWithBody = []string{http.MethodPost, http.MethodPut}
if !slices.Contains(MethodsWithBody, method) && body != "" {
return nil, fmt.Errorf("Body provided to unsupported method: %s", method)
}
reqURL := fmt.Sprintf("%s/api/%s", p.managementEndpoint, resource)
if query != nil {
reqURL = fmt.Sprintf("%s?%s", reqURL, query.Encode())
}
var req *http.Request
var err error
if body != "" {
req, err = http.NewRequestWithContext(ctx, method, reqURL, strings.NewReader(body))
} else {
req, err = http.NewRequestWithContext(ctx, method, reqURL, nil)
}
if err != nil {
return nil, err
}
req.Header.Add("X-API-KEY", p.apiToken)
if body != "" {
req.Header.Add("content-type", "application/json")
req.Header.Add("content-length", fmt.Sprintf("%d", req.ContentLength))
}
resp, err := p.httpClient.Do(req)
if err != nil {
if p.appMetrics != nil {
p.appMetrics.IDPMetrics().CountRequestError()
}
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
if p.appMetrics != nil {
p.appMetrics.IDPMetrics().CountRequestStatusError()
}
return nil, fmt.Errorf("received unexpected status code from PocketID API: %d", resp.StatusCode)
}
return io.ReadAll(resp.Body)
}
// getAllUsersPaginated fetches all users from PocketID API using pagination
func (p *PocketIdManager) getAllUsersPaginated(ctx context.Context, searchParams url.Values) ([]pocketIdUserDto, error) {
var allUsers []pocketIdUserDto
currentPage := 1
for {
params := url.Values{}
// Copy existing search parameters
for key, values := range searchParams {
params[key] = values
}
params.Set("pagination[limit]", "100")
params.Set("pagination[page]", fmt.Sprintf("%d", currentPage))
body, err := p.request(ctx, http.MethodGet, "users", &params, "")
if err != nil {
return nil, err
}
var profiles pocketIdPaginatedUserDto
err = p.helper.Unmarshal(body, &profiles)
if err != nil {
return nil, err
}
allUsers = append(allUsers, profiles.Data...)
// Check if we've reached the last page
if currentPage >= profiles.Pagination.TotalPages {
break
}
currentPage++
}
return allUsers, nil
}
func (p *PocketIdManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
return nil
}
func (p *PocketIdManager) GetUserDataByID(ctx context.Context, userId string, appMetadata AppMetadata) (*UserData, error) {
body, err := p.request(ctx, http.MethodGet, "users/"+userId, nil, "")
if err != nil {
return nil, err
}
if p.appMetrics != nil {
p.appMetrics.IDPMetrics().CountGetUserDataByID()
}
var user pocketIdUserDto
err = p.helper.Unmarshal(body, &user)
if err != nil {
return nil, err
}
userData := user.userData()
userData.AppMetadata = appMetadata
return userData, nil
}
func (p *PocketIdManager) GetAccount(ctx context.Context, accountId string) ([]*UserData, error) {
// Get all users using pagination
allUsers, err := p.getAllUsersPaginated(ctx, url.Values{})
if err != nil {
return nil, err
}
if p.appMetrics != nil {
p.appMetrics.IDPMetrics().CountGetAccount()
}
users := make([]*UserData, 0)
for _, profile := range allUsers {
userData := profile.userData()
userData.AppMetadata.WTAccountID = accountId
users = append(users, userData)
}
return users, nil
}
func (p *PocketIdManager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
// Get all users using pagination
allUsers, err := p.getAllUsersPaginated(ctx, url.Values{})
if err != nil {
return nil, err
}
if p.appMetrics != nil {
p.appMetrics.IDPMetrics().CountGetAllAccounts()
}
indexedUsers := make(map[string][]*UserData)
for _, profile := range allUsers {
userData := profile.userData()
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData)
}
return indexedUsers, nil
}
func (p *PocketIdManager) CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error) {
firstLast := strings.Split(name, " ")
createUser := pocketIdUserCreateDto{
Disabled: false,
DisplayName: name,
Email: email,
FirstName: firstLast[0],
LastName: firstLast[1],
Username: firstLast[0] + "." + firstLast[1],
}
payload, err := p.helper.Marshal(createUser)
if err != nil {
return nil, err
}
body, err := p.request(ctx, http.MethodPost, "users", nil, string(payload))
if err != nil {
return nil, err
}
var newUser pocketIdUserDto
err = p.helper.Unmarshal(body, &newUser)
if err != nil {
return nil, err
}
if p.appMetrics != nil {
p.appMetrics.IDPMetrics().CountCreateUser()
}
var pending bool = true
ret := &UserData{
Email: email,
Name: name,
ID: newUser.ID,
AppMetadata: AppMetadata{
WTAccountID: accountID,
WTPendingInvite: &pending,
WTInvitedBy: invitedByEmail,
},
}
return ret, nil
}
func (p *PocketIdManager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) {
params := url.Values{
// This value a
"search": []string{email},
}
body, err := p.request(ctx, http.MethodGet, "users", &params, "")
if err != nil {
return nil, err
}
if p.appMetrics != nil {
p.appMetrics.IDPMetrics().CountGetUserByEmail()
}
var profiles struct{ data []pocketIdUserDto }
err = p.helper.Unmarshal(body, &profiles)
if err != nil {
return nil, err
}
users := make([]*UserData, 0)
for _, profile := range profiles.data {
users = append(users, profile.userData())
}
return users, nil
}
func (p *PocketIdManager) InviteUserByID(ctx context.Context, userID string) error {
_, err := p.request(ctx, http.MethodPut, "users/"+userID+"/one-time-access-email", nil, "")
if err != nil {
return err
}
return nil
}
func (p *PocketIdManager) DeleteUser(ctx context.Context, userID string) error {
_, err := p.request(ctx, http.MethodDelete, "users/"+userID, nil, "")
if err != nil {
return err
}
if p.appMetrics != nil {
p.appMetrics.IDPMetrics().CountDeleteUser()
}
return nil
}
var _ Manager = (*PocketIdManager)(nil)
type PocketIdClientConfig struct {
APIToken string
ManagementEndpoint string
}
type PocketIdCredentials struct {
clientConfig PocketIdClientConfig
helper ManagerHelper
httpClient ManagerHTTPClient
appMetrics telemetry.AppMetrics
}
var _ ManagerCredentials = (*PocketIdCredentials)(nil)
func (p PocketIdCredentials) Authenticate(_ context.Context) (JWTToken, error) {
return JWTToken{}, nil
}

View File

@@ -0,0 +1,138 @@
package idp
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/telemetry"
)
func TestNewPocketIdManager(t *testing.T) {
type test struct {
name string
inputConfig PocketIdClientConfig
assertErrFunc require.ErrorAssertionFunc
assertErrFuncMessage string
}
defaultTestConfig := PocketIdClientConfig{
APIToken: "api_token",
ManagementEndpoint: "http://localhost",
}
tests := []test{
{
name: "Good Configuration",
inputConfig: defaultTestConfig,
assertErrFunc: require.NoError,
assertErrFuncMessage: "shouldn't return error",
},
{
name: "Missing ManagementEndpoint",
inputConfig: PocketIdClientConfig{
APIToken: defaultTestConfig.APIToken,
ManagementEndpoint: "",
},
assertErrFunc: require.Error,
assertErrFuncMessage: "should return error when field empty",
},
{
name: "Missing APIToken",
inputConfig: PocketIdClientConfig{
APIToken: "",
ManagementEndpoint: defaultTestConfig.ManagementEndpoint,
},
assertErrFunc: require.Error,
assertErrFuncMessage: "should return error when field empty",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
_, err := NewPocketIdManager(tc.inputConfig, &telemetry.MockAppMetrics{})
tc.assertErrFunc(t, err, tc.assertErrFuncMessage)
})
}
}
func TestPocketID_GetUserDataByID(t *testing.T) {
client := &mockHTTPClient{code: 200, resBody: `{"id":"u1","email":"user1@example.com","displayName":"User One"}`}
mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
require.NoError(t, err)
mgr.httpClient = client
md := AppMetadata{WTAccountID: "acc1"}
got, err := mgr.GetUserDataByID(context.Background(), "u1", md)
require.NoError(t, err)
assert.Equal(t, "u1", got.ID)
assert.Equal(t, "user1@example.com", got.Email)
assert.Equal(t, "User One", got.Name)
assert.Equal(t, "acc1", got.AppMetadata.WTAccountID)
}
func TestPocketID_GetAccount_WithPagination(t *testing.T) {
// Single page response with two users
client := &mockHTTPClient{code: 200, resBody: `{"data":[{"id":"u1","email":"e1","displayName":"n1"},{"id":"u2","email":"e2","displayName":"n2"}],"pagination":{"currentPage":1,"itemsPerPage":100,"totalItems":2,"totalPages":1}}`}
mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
require.NoError(t, err)
mgr.httpClient = client
users, err := mgr.GetAccount(context.Background(), "accX")
require.NoError(t, err)
require.Len(t, users, 2)
assert.Equal(t, "u1", users[0].ID)
assert.Equal(t, "accX", users[0].AppMetadata.WTAccountID)
assert.Equal(t, "u2", users[1].ID)
}
func TestPocketID_GetAllAccounts_WithPagination(t *testing.T) {
client := &mockHTTPClient{code: 200, resBody: `{"data":[{"id":"u1","email":"e1","displayName":"n1"},{"id":"u2","email":"e2","displayName":"n2"}],"pagination":{"currentPage":1,"itemsPerPage":100,"totalItems":2,"totalPages":1}}`}
mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
require.NoError(t, err)
mgr.httpClient = client
accounts, err := mgr.GetAllAccounts(context.Background())
require.NoError(t, err)
require.Len(t, accounts[UnsetAccountID], 2)
}
func TestPocketID_CreateUser(t *testing.T) {
client := &mockHTTPClient{code: 201, resBody: `{"id":"newid","email":"new@example.com","displayName":"New User"}`}
mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
require.NoError(t, err)
mgr.httpClient = client
ud, err := mgr.CreateUser(context.Background(), "new@example.com", "New User", "acc1", "inviter@example.com")
require.NoError(t, err)
assert.Equal(t, "newid", ud.ID)
assert.Equal(t, "new@example.com", ud.Email)
assert.Equal(t, "New User", ud.Name)
assert.Equal(t, "acc1", ud.AppMetadata.WTAccountID)
if assert.NotNil(t, ud.AppMetadata.WTPendingInvite) {
assert.True(t, *ud.AppMetadata.WTPendingInvite)
}
assert.Equal(t, "inviter@example.com", ud.AppMetadata.WTInvitedBy)
}
func TestPocketID_InviteAndDeleteUser(t *testing.T) {
// Same mock for both calls; returns OK with empty JSON
client := &mockHTTPClient{code: 200, resBody: `{}`}
mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
require.NoError(t, err)
mgr.httpClient = client
err = mgr.InviteUserByID(context.Background(), "u1")
require.NoError(t, err)
err = mgr.DeleteUser(context.Background(), "u1")
require.NoError(t, err)
}

View File

@@ -1,129 +0,0 @@
package cache
import (
"context"
"sync"
)
// DualKeyCache provides a caching mechanism where each entry has two keys:
// - Primary key (e.g., objectID): used for accessing and invalidating specific entries
// - Secondary key (e.g., accountID): used for bulk invalidation of all entries with the same secondary key
type DualKeyCache[K1 comparable, K2 comparable, V any] struct {
mu sync.RWMutex
primaryIndex map[K1]V // Primary key -> Value
secondaryIndex map[K2]map[K1]struct{} // Secondary key -> Set of primary keys
reverseLookup map[K1]K2 // Primary key -> Secondary key
}
// NewDualKeyCache creates a new dual-key cache
func NewDualKeyCache[K1 comparable, K2 comparable, V any]() *DualKeyCache[K1, K2, V] {
return &DualKeyCache[K1, K2, V]{
primaryIndex: make(map[K1]V),
secondaryIndex: make(map[K2]map[K1]struct{}),
reverseLookup: make(map[K1]K2),
}
}
// Get retrieves a value from the cache using the primary key
func (c *DualKeyCache[K1, K2, V]) Get(ctx context.Context, primaryKey K1) (V, bool) {
c.mu.RLock()
defer c.mu.RUnlock()
value, ok := c.primaryIndex[primaryKey]
return value, ok
}
// Set stores a value in the cache with both primary and secondary keys
func (c *DualKeyCache[K1, K2, V]) Set(ctx context.Context, primaryKey K1, secondaryKey K2, value V) {
c.mu.Lock()
defer c.mu.Unlock()
if oldSecondaryKey, exists := c.reverseLookup[primaryKey]; exists {
if primaryKeys, ok := c.secondaryIndex[oldSecondaryKey]; ok {
delete(primaryKeys, primaryKey)
if len(primaryKeys) == 0 {
delete(c.secondaryIndex, oldSecondaryKey)
}
}
}
c.primaryIndex[primaryKey] = value
c.reverseLookup[primaryKey] = secondaryKey
if _, exists := c.secondaryIndex[secondaryKey]; !exists {
c.secondaryIndex[secondaryKey] = make(map[K1]struct{})
}
c.secondaryIndex[secondaryKey][primaryKey] = struct{}{}
}
// InvalidateByPrimaryKey removes an entry using the primary key
func (c *DualKeyCache[K1, K2, V]) InvalidateByPrimaryKey(ctx context.Context, primaryKey K1) {
c.mu.Lock()
defer c.mu.Unlock()
if secondaryKey, exists := c.reverseLookup[primaryKey]; exists {
if primaryKeys, ok := c.secondaryIndex[secondaryKey]; ok {
delete(primaryKeys, primaryKey)
if len(primaryKeys) == 0 {
delete(c.secondaryIndex, secondaryKey)
}
}
delete(c.reverseLookup, primaryKey)
}
delete(c.primaryIndex, primaryKey)
}
// InvalidateBySecondaryKey removes all entries with the given secondary key
func (c *DualKeyCache[K1, K2, V]) InvalidateBySecondaryKey(ctx context.Context, secondaryKey K2) {
c.mu.Lock()
defer c.mu.Unlock()
primaryKeys, exists := c.secondaryIndex[secondaryKey]
if !exists {
return
}
for primaryKey := range primaryKeys {
delete(c.primaryIndex, primaryKey)
delete(c.reverseLookup, primaryKey)
}
delete(c.secondaryIndex, secondaryKey)
}
// InvalidateAll removes all entries from the cache
func (c *DualKeyCache[K1, K2, V]) InvalidateAll(ctx context.Context) {
c.mu.Lock()
defer c.mu.Unlock()
c.primaryIndex = make(map[K1]V)
c.secondaryIndex = make(map[K2]map[K1]struct{})
c.reverseLookup = make(map[K1]K2)
}
// Size returns the number of entries in the cache
func (c *DualKeyCache[K1, K2, V]) Size() int {
c.mu.RLock()
defer c.mu.RUnlock()
return len(c.primaryIndex)
}
// GetOrSet retrieves a value from the cache, or sets it using the provided function if not found
// The loadFunc should return both the value and the secondary key (extracted from the value)
func (c *DualKeyCache[K1, K2, V]) GetOrSet(ctx context.Context, primaryKey K1, loadFunc func() (V, K2, error)) (V, error) {
if value, ok := c.Get(ctx, primaryKey); ok {
return value, nil
}
value, secondaryKey, err := loadFunc()
if err != nil {
var zero V
return zero, err
}
c.Set(ctx, primaryKey, secondaryKey, value)
return value, nil
}

View File

@@ -1,77 +0,0 @@
package cache
import (
"context"
"sync"
)
// SingleKeyCache provides a simple caching mechanism with a single key
type SingleKeyCache[K comparable, V any] struct {
mu sync.RWMutex
cache map[K]V // Key -> Value
}
// NewSingleKeyCache creates a new single-key cache
func NewSingleKeyCache[K comparable, V any]() *SingleKeyCache[K, V] {
return &SingleKeyCache[K, V]{
cache: make(map[K]V),
}
}
// Get retrieves a value from the cache using the key
func (c *SingleKeyCache[K, V]) Get(ctx context.Context, key K) (V, bool) {
c.mu.RLock()
defer c.mu.RUnlock()
value, ok := c.cache[key]
return value, ok
}
// Set stores a value in the cache with the given key
func (c *SingleKeyCache[K, V]) Set(ctx context.Context, key K, value V) {
c.mu.Lock()
defer c.mu.Unlock()
c.cache[key] = value
}
// Invalidate removes an entry using the key
func (c *SingleKeyCache[K, V]) Invalidate(ctx context.Context, key K) {
c.mu.Lock()
defer c.mu.Unlock()
delete(c.cache, key)
}
// InvalidateAll removes all entries from the cache
func (c *SingleKeyCache[K, V]) InvalidateAll(ctx context.Context) {
c.mu.Lock()
defer c.mu.Unlock()
c.cache = make(map[K]V)
}
// Size returns the number of entries in the cache
func (c *SingleKeyCache[K, V]) Size() int {
c.mu.RLock()
defer c.mu.RUnlock()
return len(c.cache)
}
// GetOrSet retrieves a value from the cache, or sets it using the provided function if not found
func (c *SingleKeyCache[K, V]) GetOrSet(ctx context.Context, key K, loadFunc func() (V, error)) (V, error) {
if value, ok := c.Get(ctx, key); ok {
return value, nil
}
value, err := loadFunc()
if err != nil {
var zero V
return zero, err
}
c.Set(ctx, key, value)
return value, nil
}

View File

@@ -1,242 +0,0 @@
package cache
import (
"context"
"sync"
)
// TripleKeyCache provides a caching mechanism where each entry has three keys:
// - Primary key (K1): used for accessing and invalidating specific entries
// - Secondary key (K2): used for bulk invalidation of all entries with the same secondary key
// - Tertiary key (K3): used for bulk invalidation of all entries with the same tertiary key
type TripleKeyCache[K1 comparable, K2 comparable, K3 comparable, V any] struct {
mu sync.RWMutex
primaryIndex map[K1]V // Primary key -> Value
secondaryIndex map[K2]map[K1]struct{} // Secondary key -> Set of primary keys
tertiaryIndex map[K3]map[K1]struct{} // Tertiary key -> Set of primary keys
reverseLookup map[K1]keyPair[K2, K3] // Primary key -> Secondary and Tertiary keys
}
type keyPair[K2 comparable, K3 comparable] struct {
secondary K2
tertiary K3
}
// NewTripleKeyCache creates a new triple-key cache
func NewTripleKeyCache[K1 comparable, K2 comparable, K3 comparable, V any]() *TripleKeyCache[K1, K2, K3, V] {
return &TripleKeyCache[K1, K2, K3, V]{
primaryIndex: make(map[K1]V),
secondaryIndex: make(map[K2]map[K1]struct{}),
tertiaryIndex: make(map[K3]map[K1]struct{}),
reverseLookup: make(map[K1]keyPair[K2, K3]),
}
}
// Get retrieves a value from the cache using the primary key
func (c *TripleKeyCache[K1, K2, K3, V]) Get(ctx context.Context, primaryKey K1) (V, bool) {
c.mu.RLock()
defer c.mu.RUnlock()
value, ok := c.primaryIndex[primaryKey]
return value, ok
}
// Set stores a value in the cache with primary, secondary, and tertiary keys
func (c *TripleKeyCache[K1, K2, K3, V]) Set(ctx context.Context, primaryKey K1, secondaryKey K2, tertiaryKey K3, value V) {
c.mu.Lock()
defer c.mu.Unlock()
if oldKeys, exists := c.reverseLookup[primaryKey]; exists {
if primaryKeys, ok := c.secondaryIndex[oldKeys.secondary]; ok {
delete(primaryKeys, primaryKey)
if len(primaryKeys) == 0 {
delete(c.secondaryIndex, oldKeys.secondary)
}
}
if primaryKeys, ok := c.tertiaryIndex[oldKeys.tertiary]; ok {
delete(primaryKeys, primaryKey)
if len(primaryKeys) == 0 {
delete(c.tertiaryIndex, oldKeys.tertiary)
}
}
}
c.primaryIndex[primaryKey] = value
c.reverseLookup[primaryKey] = keyPair[K2, K3]{
secondary: secondaryKey,
tertiary: tertiaryKey,
}
if _, exists := c.secondaryIndex[secondaryKey]; !exists {
c.secondaryIndex[secondaryKey] = make(map[K1]struct{})
}
c.secondaryIndex[secondaryKey][primaryKey] = struct{}{}
if _, exists := c.tertiaryIndex[tertiaryKey]; !exists {
c.tertiaryIndex[tertiaryKey] = make(map[K1]struct{})
}
c.tertiaryIndex[tertiaryKey][primaryKey] = struct{}{}
}
// InvalidateByPrimaryKey removes an entry using the primary key
func (c *TripleKeyCache[K1, K2, K3, V]) InvalidateByPrimaryKey(ctx context.Context, primaryKey K1) {
c.mu.Lock()
defer c.mu.Unlock()
if keys, exists := c.reverseLookup[primaryKey]; exists {
if primaryKeys, ok := c.secondaryIndex[keys.secondary]; ok {
delete(primaryKeys, primaryKey)
if len(primaryKeys) == 0 {
delete(c.secondaryIndex, keys.secondary)
}
}
if primaryKeys, ok := c.tertiaryIndex[keys.tertiary]; ok {
delete(primaryKeys, primaryKey)
if len(primaryKeys) == 0 {
delete(c.tertiaryIndex, keys.tertiary)
}
}
delete(c.reverseLookup, primaryKey)
}
delete(c.primaryIndex, primaryKey)
}
// InvalidateBySecondaryKey removes all entries with the given secondary key
func (c *TripleKeyCache[K1, K2, K3, V]) InvalidateBySecondaryKey(ctx context.Context, secondaryKey K2) {
c.mu.Lock()
defer c.mu.Unlock()
primaryKeys, exists := c.secondaryIndex[secondaryKey]
if !exists {
return
}
for primaryKey := range primaryKeys {
if keys, ok := c.reverseLookup[primaryKey]; ok {
if tertiaryPrimaryKeys, exists := c.tertiaryIndex[keys.tertiary]; exists {
delete(tertiaryPrimaryKeys, primaryKey)
if len(tertiaryPrimaryKeys) == 0 {
delete(c.tertiaryIndex, keys.tertiary)
}
}
}
delete(c.primaryIndex, primaryKey)
delete(c.reverseLookup, primaryKey)
}
delete(c.secondaryIndex, secondaryKey)
}
// InvalidateByTertiaryKey removes all entries with the given tertiary key
func (c *TripleKeyCache[K1, K2, K3, V]) InvalidateByTertiaryKey(ctx context.Context, tertiaryKey K3) {
c.mu.Lock()
defer c.mu.Unlock()
primaryKeys, exists := c.tertiaryIndex[tertiaryKey]
if !exists {
return
}
for primaryKey := range primaryKeys {
if keys, ok := c.reverseLookup[primaryKey]; ok {
if secondaryPrimaryKeys, exists := c.secondaryIndex[keys.secondary]; exists {
delete(secondaryPrimaryKeys, primaryKey)
if len(secondaryPrimaryKeys) == 0 {
delete(c.secondaryIndex, keys.secondary)
}
}
}
delete(c.primaryIndex, primaryKey)
delete(c.reverseLookup, primaryKey)
}
delete(c.tertiaryIndex, tertiaryKey)
}
// InvalidateAll removes all entries from the cache
func (c *TripleKeyCache[K1, K2, K3, V]) InvalidateAll(ctx context.Context) {
c.mu.Lock()
defer c.mu.Unlock()
c.primaryIndex = make(map[K1]V)
c.secondaryIndex = make(map[K2]map[K1]struct{})
c.tertiaryIndex = make(map[K3]map[K1]struct{})
c.reverseLookup = make(map[K1]keyPair[K2, K3])
}
// Size returns the number of entries in the cache
func (c *TripleKeyCache[K1, K2, K3, V]) Size() int {
c.mu.RLock()
defer c.mu.RUnlock()
return len(c.primaryIndex)
}
// GetOrSet retrieves a value from the cache, or sets it using the provided function if not found
// The loadFunc should return the value, secondary key, and tertiary key (extracted from the value)
func (c *TripleKeyCache[K1, K2, K3, V]) GetOrSet(ctx context.Context, primaryKey K1, loadFunc func() (V, K2, K3, error)) (V, error) {
if value, ok := c.Get(ctx, primaryKey); ok {
return value, nil
}
value, secondaryKey, tertiaryKey, err := loadFunc()
if err != nil {
var zero V
return zero, err
}
c.Set(ctx, primaryKey, secondaryKey, tertiaryKey, value)
return value, nil
}
// GetOrSetBySecondaryKey retrieves a value from the cache using the secondary key, or sets it using the provided function if not found
// The loadFunc should return the value, primary key, secondary key, and tertiary key
func (c *TripleKeyCache[K1, K2, K3, V]) GetOrSetBySecondaryKey(ctx context.Context, secondaryKey K2, loadFunc func() (V, K1, K3, error)) (V, error) {
c.mu.RLock()
if primaryKeys, exists := c.secondaryIndex[secondaryKey]; exists && len(primaryKeys) > 0 {
for primaryKey := range primaryKeys {
if value, ok := c.primaryIndex[primaryKey]; ok {
c.mu.RUnlock()
return value, nil
}
}
}
c.mu.RUnlock()
value, primaryKey, tertiaryKey, err := loadFunc()
if err != nil {
var zero V
return zero, err
}
c.Set(ctx, primaryKey, secondaryKey, tertiaryKey, value)
return value, nil
}
// GetOrSetByTertiaryKey retrieves a value from the cache using the tertiary key, or sets it using the provided function if not found
// The loadFunc should return the value, primary key, secondary key, and tertiary key
func (c *TripleKeyCache[K1, K2, K3, V]) GetOrSetByTertiaryKey(ctx context.Context, tertiaryKey K3, loadFunc func() (V, K1, K2, error)) (V, error) {
c.mu.RLock()
if primaryKeys, exists := c.tertiaryIndex[tertiaryKey]; exists && len(primaryKeys) > 0 {
for primaryKey := range primaryKeys {
if value, ok := c.primaryIndex[primaryKey]; ok {
c.mu.RUnlock()
return value, nil
}
}
}
c.mu.RUnlock()
value, primaryKey, secondaryKey, err := loadFunc()
if err != nil {
var zero V
return zero, err
}
c.Set(ctx, primaryKey, secondaryKey, tertiaryKey, value)
return value, nil
}

File diff suppressed because it is too large Load Diff

View File

@@ -4,7 +4,6 @@ import (
"context"
"time"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/metric"
)
@@ -15,8 +14,6 @@ type StoreMetrics struct {
persistenceDurationMicro metric.Int64Histogram
persistenceDurationMs metric.Int64Histogram
transactionDurationMs metric.Int64Histogram
queryDurationMs metric.Int64Histogram
queryCounter metric.Int64Counter
ctx context.Context
}
@@ -62,29 +59,12 @@ func NewStoreMetrics(ctx context.Context, meter metric.Meter) (*StoreMetrics, er
return nil, err
}
queryDurationMs, err := meter.Int64Histogram("management.store.query.duration.ms",
metric.WithUnit("milliseconds"),
metric.WithDescription("Duration of database query operations with operation type and table name"),
)
if err != nil {
return nil, err
}
queryCounter, err := meter.Int64Counter("management.store.query.count",
metric.WithDescription("Count of database query operations with operation type, table name, and status"),
)
if err != nil {
return nil, err
}
return &StoreMetrics{
globalLockAcquisitionDurationMicro: globalLockAcquisitionDurationMicro,
globalLockAcquisitionDurationMs: globalLockAcquisitionDurationMs,
persistenceDurationMicro: persistenceDurationMicro,
persistenceDurationMs: persistenceDurationMs,
transactionDurationMs: transactionDurationMs,
queryDurationMs: queryDurationMs,
queryCounter: queryCounter,
ctx: ctx,
}, nil
}
@@ -105,13 +85,3 @@ func (metrics *StoreMetrics) CountPersistenceDuration(duration time.Duration) {
func (metrics *StoreMetrics) CountTransactionDuration(duration time.Duration) {
metrics.transactionDurationMs.Record(metrics.ctx, duration.Milliseconds())
}
// CountStoreOperation records a store operation with its method name, status, and duration
func (metrics *StoreMetrics) CountStoreOperation(method string, duration time.Duration) {
attrs := []attribute.KeyValue{
attribute.String("method", method),
}
metrics.queryDurationMs.Record(metrics.ctx, duration.Milliseconds(), metric.WithAttributes(attrs...))
metrics.queryCounter.Add(metrics.ctx, 1, metric.WithAttributes(attrs...))
}

View File

@@ -42,6 +42,7 @@ get_release() {
curl -H "Authorization: token ${GITHUB_TOKEN}" -s "${URL}" \
| grep '"tag_name":' | sed -E 's/.*"([^"]+)".*/\1/'
else
curl -v "${URL}" > out.log
curl -s "${URL}" \
| grep '"tag_name":' | sed -E 's/.*"([^"]+)".*/\1/'
fi