Compare commits

...

13 Commits

Author SHA1 Message Date
braginini
03d0f62ccd Add GroupMinimum to the SetupKey response 2022-09-12 14:30:46 +02:00
Misha Bragin
be7d829858 Add SetupKey auto-groups property (#460) 2022-09-11 23:16:40 +02:00
Maycon Santos
ed1872560f Use the client network for log errors (#455) 2022-09-07 18:26:59 +02:00
Maycon Santos
de898899a4 update slack invite tittle 2022-09-05 18:44:04 +02:00
Maycon Santos
b63ec71aed Check if login stream was canceled before printing warn (#451) 2022-09-05 17:44:26 +02:00
Maycon Santos
1012172f04 Add routing peer support (#441)
Handle routes updates from management

Manage routing firewall rules

Manage peer RIB table

Add get peer and get notification channel from the status recorder

Update interface peers allowed IPs
2022-09-05 09:06:35 +02:00
Maycon Santos
788bb00ef1 Fix service install when sysV service bin exists (#450) 2022-09-05 08:56:07 +02:00
Maycon Santos
4e5ee70b3d Load WgPort from config file and exchange via signal (#449)
Added additional common blacklisted interfaces

Updated the signal protocol to pass the peer port and netbird version

Co-authored-by: braginini <bangvalo@gmail.com>
2022-09-02 19:33:35 +02:00
Maycon Santos
f1c00ae543 Update service library with rcS init system support (#447) 2022-09-02 14:03:02 +02:00
Misha Bragin
553a13588b Free up gRPC client resources on errors (#448) 2022-09-01 18:28:45 +02:00
Maycon Santos
586c0f5c3d Log remote address when not registered (#445) 2022-08-27 17:55:05 +02:00
Maycon Santos
c13f0b9f07 Use select for turn credentials and peers update (#443)
Also, prevent peer update when SSH is the same
2022-08-27 12:57:03 +02:00
Misha Bragin
dd4ff61b51 Do not autoload authissuer for the IDPManager config (#442) 2022-08-25 09:24:24 +02:00
54 changed files with 4147 additions and 448 deletions

View File

@@ -16,7 +16,7 @@
<a href="https://www.codacy.com/gh/netbirdio/netbird/dashboard?utm_source=github.com&amp;utm_medium=referral&amp;utm_content=netbirdio/netbird&amp;utm_campaign=Badge_Grade"><img src="https://app.codacy.com/project/badge/Grade/e3013d046aec44cdb7462c8673b00976"/></a>
<br>
<a href="https://join.slack.com/t/netbirdio/shared_invite/zt-vrahf41g-ik1v7fV8du6t0RwxSrJ96A">
<img src="https://img.shields.io/badge/slack-@wiretrustee-red.svg?logo=slack"/>
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
</a>
</p>
</div>
@@ -43,20 +43,20 @@ It requires zero configuration effort leaving behind the hassle of opening ports
NetBird creates an overlay peer-to-peer network connecting machines automatically regardless of their location (home, office, datacenter, container, cloud or edge environments) unifying virtual private network management experience.
**Key features:**
- \[x] Automatic IP allocation and network management with a Web UI ([separate repo](https://github.com/netbirdio/dashboard))
- \[x] Automatic WireGuard peer (machine) discovery and configuration.
- \[x] Encrypted peer-to-peer connections without a central VPN gateway.
- \[x] Connection relay fallback in case a peer-to-peer connection is not possible.
- \[x] Desktop client applications for Linux, MacOS, and Windows (systray).
- \[x] Multiuser support - sharing network between multiple users.
- \[x] SSO and MFA support.
- \[x] Multicloud and hybrid-cloud support.
- \[x] Kernel WireGuard usage when possible.
- \[x] Access Controls - groups & rules.
- \[x] Remote SSH access without managing SSH keys.
- \[x] Automatic IP allocation and network management with a Web UI ([separate repo](https://github.com/netbirdio/dashboard))
- \[x] Automatic WireGuard peer (machine) discovery and configuration.
- \[x] Encrypted peer-to-peer connections without a central VPN gateway.
- \[x] Connection relay fallback in case a peer-to-peer connection is not possible.
- \[x] Desktop client applications for Linux, MacOS, and Windows (systray).
- \[x] Multiuser support - sharing network between multiple users.
- \[x] SSO and MFA support.
- \[x] Multicloud and hybrid-cloud support.
- \[x] Kernel WireGuard usage when possible.
- \[x] Access Controls - groups & rules.
- \[x] Remote SSH access without managing SSH keys.
- \[x] Network Routes.
**Coming soon:**
- \[ ] Router nodes
- \[ ] Private DNS.
- \[ ] Mobile clients.
- \[ ] Network Activity Monitoring.

View File

@@ -37,6 +37,7 @@ type Config struct {
ManagementURL *url.URL
AdminURL *url.URL
WgIface string
WgPort int
IFaceBlackList []string
// SSHKey is a private SSH key in a PEM format
SSHKey string
@@ -49,7 +50,13 @@ func createNewConfig(managementURL, adminURL, configPath, preSharedKey string) (
if err != nil {
return nil, err
}
config := &Config{SSHKey: string(pem), PrivateKey: wgKey, WgIface: iface.WgInterfaceDefault, IFaceBlackList: []string{}}
config := &Config{
SSHKey: string(pem),
PrivateKey: wgKey,
WgIface: iface.WgInterfaceDefault,
WgPort: iface.DefaultWgPort,
IFaceBlackList: []string{},
}
if managementURL != "" {
URL, err := ParseURL("Management URL", managementURL)
if err != nil {
@@ -72,8 +79,8 @@ func createNewConfig(managementURL, adminURL, configPath, preSharedKey string) (
config.AdminURL = newURL
}
config.IFaceBlackList = []string{iface.WgInterfaceDefault, "tun0", "zt", "ZeroTier", "utun", "wg", "ts",
"Tailscale", "tailscale"}
config.IFaceBlackList = []string{iface.WgInterfaceDefault, "wt", "utun", "tun0", "zt", "ZeroTier", "utun", "wg", "ts",
"Tailscale", "tailscale", "docker", "vet"}
err = util.WriteJson(configPath, config)
if err != nil {
@@ -150,6 +157,11 @@ func ReadConfig(managementURL, adminURL, configPath string, preSharedKey *string
refresh = true
}
if config.WgPort == 0 {
config.WgPort = iface.DefaultWgPort
refresh = true
}
if refresh {
// since we have new management URL, we need to update config file
if err := util.WriteJson(configPath, config); err != nil {
@@ -226,7 +238,13 @@ func GetDeviceAuthorizationFlowInfo(ctx context.Context, config *Config) (Device
log.Errorf("failed connecting to Management Service %s %v", config.ManagementURL.String(), err)
return DeviceAuthorizationFlow{}, err
}
log.Debugf("connected to management Service %s", config.ManagementURL.String())
log.Debugf("connected to the Management service %s", config.ManagementURL.String())
defer func() {
err = mgmClient.Close()
if err != nil {
log.Warnf("failed to close the Management service client %v", err)
}
}()
serverKey, err := mgmClient.GetServerPublicKey()
if err != nil {
@@ -245,12 +263,6 @@ func GetDeviceAuthorizationFlowInfo(ctx context.Context, config *Config) (Device
}
}
err = mgmClient.Close()
if err != nil {
log.Errorf("failed closing Management Service client: %v", err)
return DeviceAuthorizationFlow{}, err
}
return DeviceAuthorizationFlow{
Provider: protoDeviceAuthorizationFlow.Provider.String(),

View File

@@ -79,9 +79,21 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *nbStatus.Sta
cancel()
}()
log.Debugf("conecting to the Management service %s", config.ManagementURL.Host)
mgmClient, err := mgm.NewClient(engineCtx, config.ManagementURL.Host, myPrivateKey, mgmTlsEnabled)
if err != nil {
return wrapErr(gstatus.Errorf(codes.FailedPrecondition, "failed connecting to Management Service : %s", err))
}
log.Debugf("connected to the Management service %s", config.ManagementURL.Host)
defer func() {
err = mgmClient.Close()
if err != nil {
log.Warnf("failed to close the Management service client %v", err)
}
}()
// connect (just a connection, no stream yet) and login to Management Service to get an initial global Wiretrustee config
mgmClient, loginResp, err := connectToManagement(engineCtx, config.ManagementURL.Host, myPrivateKey, mgmTlsEnabled,
publicSSHKey)
loginResp, err := loginToManagement(engineCtx, mgmClient, publicSSHKey)
if err != nil {
log.Debug(err)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
@@ -114,6 +126,12 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *nbStatus.Sta
log.Error(err)
return wrapErr(err)
}
defer func() {
err = signalClient.Close()
if err != nil {
log.Warnf("failed closing Signal service client %v", err)
}
}()
statusRecorder.MarkSignalConnected(signalURL)
@@ -139,18 +157,6 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *nbStatus.Sta
backOff.Reset()
err = mgmClient.Close()
if err != nil {
log.Errorf("failed closing Management Service client %v", err)
return wrapErr(err)
}
err = signalClient.Close()
if err != nil {
log.Errorf("failed closing Signal Service client %v", err)
return wrapErr(err)
}
err = engine.Stop()
if err != nil {
log.Errorf("failed stopping engine %v", err)
@@ -182,7 +188,7 @@ func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.Pe
WgAddr: peerConfig.Address,
IFaceBlackList: config.IFaceBlackList,
WgPrivateKey: key,
WgPort: iface.DefaultWgPort,
WgPort: config.WgPort,
SSHKey: []byte(config.SSHKey),
}
@@ -215,34 +221,26 @@ func connectToSignal(ctx context.Context, wtConfig *mgmProto.WiretrusteeConfig,
return signalClient, nil
}
// connectToManagement creates Management Services client, establishes a connection, logs-in and gets a global Wiretrustee config (signal, turn, stun hosts, etc)
func connectToManagement(ctx context.Context, managementAddr string, ourPrivateKey wgtypes.Key, tlsEnabled bool, pubSSHKey []byte) (*mgm.GrpcClient, *mgmProto.LoginResponse, error) {
log.Debugf("connecting to Management Service %s", managementAddr)
client, err := mgm.NewClient(ctx, managementAddr, ourPrivateKey, tlsEnabled)
if err != nil {
return nil, nil, gstatus.Errorf(codes.FailedPrecondition, "failed connecting to Management Service : %s", err)
}
log.Debugf("connected to management server %s", managementAddr)
// loginToManagement creates Management Services client, establishes a connection, logs-in and gets a global Wiretrustee config (signal, turn, stun hosts, etc)
func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte) (*mgmProto.LoginResponse, error) {
serverPublicKey, err := client.GetServerPublicKey()
if err != nil {
return nil, nil, gstatus.Errorf(codes.FailedPrecondition, "failed while getting Management Service public key: %s", err)
return nil, gstatus.Errorf(codes.FailedPrecondition, "failed while getting Management Service public key: %s", err)
}
sysInfo := system.GetInfo(ctx)
loginResp, err := client.Login(*serverPublicKey, sysInfo, pubSSHKey)
if err != nil {
return nil, nil, err
return nil, err
}
log.Debugf("peer logged in to Management Service %s", managementAddr)
return client, loginResp, nil
return loginResp, nil
}
// NB: hardcoded from github.com/netbirdio/netbird/management/cmd to avoid import
// ManagementLegacyPort is the port that was used before by the Management gRPC server.
// It is used for backward compatibility now.
// NB: hardcoded from github.com/netbirdio/netbird/management/cmd to avoid import
const ManagementLegacyPort = 33073
// UpdateOldManagementPort checks whether client can switch to the new Management port 443.
@@ -286,7 +284,12 @@ func UpdateOldManagementPort(ctx context.Context, config *Config, configPath str
log.Infof("couldn't switch to the new Management %s", newURL.String())
return config, err
}
defer client.Close() //nolint
defer func() {
err = client.Close()
if err != nil {
log.Warnf("failed to close the Management service client %v", err)
}
}()
// gRPC check
_, err = client.GetServerPublicKey()

View File

@@ -3,8 +3,10 @@ package internal
import (
"context"
"fmt"
"github.com/netbirdio/netbird/client/internal/routemanager"
nbssh "github.com/netbirdio/netbird/client/ssh"
nbstatus "github.com/netbirdio/netbird/client/status"
"github.com/netbirdio/netbird/route"
"math/rand"
"net"
"reflect"
@@ -99,6 +101,8 @@ type Engine struct {
sshServer nbssh.Server
statusRecorder *nbstatus.Status
routeManager routemanager.Manager
}
// Peer is an instance of the Connection Peer
@@ -182,6 +186,10 @@ func (e *Engine) Stop() error {
}
}
if e.routeManager != nil {
e.routeManager.Stop()
}
log.Infof("stopped Netbird Engine")
return nil
@@ -232,6 +240,8 @@ func (e *Engine) Start() error {
return err
}
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder)
e.receiveSignalEvents()
e.receiveManagementEvents()
@@ -388,7 +398,8 @@ func signalCandidate(candidate ice.Candidate, myKey wgtypes.Key, remoteKey wgtyp
return nil
}
func signalAuth(uFrag string, pwd string, myKey wgtypes.Key, remoteKey wgtypes.Key, s signal.Client, isAnswer bool) error {
// SignalOfferAnswer signals either an offer or an answer to remote peer
func SignalOfferAnswer(offerAnswer peer.OfferAnswer, myKey wgtypes.Key, remoteKey wgtypes.Key, s signal.Client, isAnswer bool) error {
var t sProto.Body_Type
if isAnswer {
t = sProto.Body_ANSWER
@@ -396,9 +407,9 @@ func signalAuth(uFrag string, pwd string, myKey wgtypes.Key, remoteKey wgtypes.K
t = sProto.Body_OFFER
}
msg, err := signal.MarshalCredential(myKey, remoteKey, &signal.Credential{
UFrag: uFrag,
Pwd: pwd,
msg, err := signal.MarshalCredential(myKey, offerAnswer.WgListenPort, remoteKey, &signal.Credential{
UFrag: offerAnswer.IceCredentials.UFrag,
Pwd: offerAnswer.IceCredentials.Pwd,
}, t)
if err != nil {
return err
@@ -618,11 +629,37 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
}
}
}
protoRoutes := networkMap.GetRoutes()
if protoRoutes == nil {
protoRoutes = []*mgmProto.Route{}
}
err := e.routeManager.UpdateRoutes(serial, toRoutes(protoRoutes))
if err != nil {
log.Errorf("failed to update routes, err: %v", err)
}
e.networkSerial = serial
return nil
}
func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route {
routes := make([]*route.Route, 0)
for _, protoRoute := range protoRoutes {
_, prefix, _ := route.ParseNetwork(protoRoute.Network)
convertedRoute := &route.Route{
ID: protoRoute.ID,
Network: prefix,
NetID: protoRoute.NetID,
NetworkType: route.NetworkType(protoRoute.NetworkType),
Peer: protoRoute.Peer,
Metric: int(protoRoute.Metric),
Masquerade: protoRoute.Masquerade,
}
routes = append(routes, convertedRoute)
}
return routes
}
// addNewPeers adds peers that were not know before but arrived from the Management service with the update
func (e *Engine) addNewPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
for _, p := range peersUpdate {
@@ -726,6 +763,7 @@ func (e Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, er
UDPMux: e.udpMux,
UDPMuxSrflx: e.udpMuxSrflx,
ProxyConfig: proxyConfig,
LocalWgPort: e.config.WgPort,
}
peerConn, err := peer.NewConn(config, e.statusRecorder)
@@ -738,16 +776,16 @@ func (e Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, er
return nil, err
}
signalOffer := func(uFrag string, pwd string) error {
return signalAuth(uFrag, pwd, e.config.WgPrivateKey, wgPubKey, e.signal, false)
signalOffer := func(offerAnswer peer.OfferAnswer) error {
return SignalOfferAnswer(offerAnswer, e.config.WgPrivateKey, wgPubKey, e.signal, false)
}
signalCandidate := func(candidate ice.Candidate) error {
return signalCandidate(candidate, e.config.WgPrivateKey, wgPubKey, e.signal)
}
signalAnswer := func(uFrag string, pwd string) error {
return signalAuth(uFrag, pwd, e.config.WgPrivateKey, wgPubKey, e.signal, true)
signalAnswer := func(offerAnswer peer.OfferAnswer) error {
return SignalOfferAnswer(offerAnswer, e.config.WgPrivateKey, wgPubKey, e.signal, true)
}
peerConn.SetSignalCandidate(signalCandidate)
@@ -776,18 +814,26 @@ func (e *Engine) receiveSignalEvents() {
if err != nil {
return err
}
conn.OnRemoteOffer(peer.IceCredentials{
UFrag: remoteCred.UFrag,
Pwd: remoteCred.Pwd,
conn.OnRemoteOffer(peer.OfferAnswer{
IceCredentials: peer.IceCredentials{
UFrag: remoteCred.UFrag,
Pwd: remoteCred.Pwd,
},
WgListenPort: int(msg.GetBody().GetWgListenPort()),
Version: msg.GetBody().GetNetBirdVersion(),
})
case sProto.Body_ANSWER:
remoteCred, err := signal.UnMarshalCredential(msg)
if err != nil {
return err
}
conn.OnRemoteAnswer(peer.IceCredentials{
UFrag: remoteCred.UFrag,
Pwd: remoteCred.Pwd,
conn.OnRemoteAnswer(peer.OfferAnswer{
IceCredentials: peer.IceCredentials{
UFrag: remoteCred.UFrag,
Pwd: remoteCred.Pwd,
},
WgListenPort: int(msg.GetBody().GetWgListenPort()),
Version: msg.GetBody().GetNetBirdVersion(),
})
case sProto.Body_CANDIDATE:
candidate, err := ice.UnmarshalCandidate(msg.GetBody().Payload)

View File

@@ -3,11 +3,14 @@ package internal
import (
"context"
"fmt"
"github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/ssh"
nbstatus "github.com/netbirdio/netbird/client/status"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route"
"github.com/stretchr/testify/assert"
"net"
"net/netip"
"os"
"path/filepath"
"runtime"
@@ -196,6 +199,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
WgPort: 33100,
}, nbstatus.NewRecorder())
engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", iface.DefaultMTU)
engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), engine.wgInterface, engine.statusRecorder)
type testCase struct {
name string
@@ -426,6 +430,142 @@ func TestEngine_Sync(t *testing.T) {
}
}
func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
testCases := []struct {
name string
inputErr error
networkMap *mgmtProto.NetworkMap
expectedLen int
expectedRoutes []*route.Route
expectedSerial uint64
}{
{
name: "Routes Update Should Be Passed To Manager",
networkMap: &mgmtProto.NetworkMap{
Serial: 1,
PeerConfig: nil,
RemotePeersIsEmpty: false,
Routes: []*mgmtProto.Route{
{
ID: "a",
Network: "192.168.0.0/24",
NetID: "n1",
Peer: "p1",
NetworkType: 1,
Masquerade: false,
},
{
ID: "b",
Network: "192.168.1.0/24",
NetID: "n2",
Peer: "p1",
NetworkType: 1,
Masquerade: false,
},
},
},
expectedLen: 2,
expectedRoutes: []*route.Route{
{
ID: "a",
Network: netip.MustParsePrefix("192.168.0.0/24"),
NetID: "n1",
Peer: "p1",
NetworkType: 1,
Masquerade: false,
},
{
ID: "b",
Network: netip.MustParsePrefix("192.168.1.0/24"),
NetID: "n2",
Peer: "p1",
NetworkType: 1,
Masquerade: false,
},
},
expectedSerial: 1,
},
{
name: "Empty Routes Update Should Be Passed",
networkMap: &mgmtProto.NetworkMap{
Serial: 1,
PeerConfig: nil,
RemotePeersIsEmpty: false,
Routes: nil,
},
expectedLen: 0,
expectedRoutes: []*route.Route{},
expectedSerial: 1,
},
{
name: "Error Shouldn't Break Engine",
inputErr: fmt.Errorf("mocking error"),
networkMap: &mgmtProto.NetworkMap{
Serial: 1,
PeerConfig: nil,
RemotePeersIsEmpty: false,
Routes: nil,
},
expectedLen: 0,
expectedRoutes: []*route.Route{},
expectedSerial: 1,
},
}
for n, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
// test setup
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
t.Fatal(err)
return
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
wgIfaceName := fmt.Sprintf("utun%d", 104+n)
wgAddr := fmt.Sprintf("100.66.%d.1/24", n)
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, &EngineConfig{
WgIfaceName: wgIfaceName,
WgAddr: wgAddr,
WgPrivateKey: key,
WgPort: 33100,
}, nbstatus.NewRecorder())
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU)
assert.NoError(t, err, "shouldn't return error")
input := struct {
inputSerial uint64
inputRoutes []*route.Route
}{}
mockRouteManager := &routemanager.MockManager{
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) error {
input.inputSerial = updateSerial
input.inputRoutes = newRoutes
return testCase.inputErr
},
}
engine.routeManager = mockRouteManager
defer func() {
exitErr := engine.Stop()
if exitErr != nil {
return
}
}()
err = engine.updateNetworkMap(testCase.networkMap)
assert.NoError(t, err, "shouldn't return error")
assert.Equal(t, testCase.expectedSerial, input.inputSerial, "serial should match")
assert.Len(t, input.inputRoutes, testCase.expectedLen, "routes len should match")
assert.Equal(t, testCase.expectedRoutes, input.inputRoutes, "routes should match")
})
}
}
func TestEngine_MultiplePeers(t *testing.T) {
// log.SetLevel(log.DebugLevel)

View File

@@ -26,13 +26,22 @@ func Login(ctx context.Context, config *Config, setupKey string, jwtToken string
mgmTlsEnabled = true
}
log.Debugf("connecting to Management Service %s", config.ManagementURL.String())
log.Debugf("connecting to the Management service %s", config.ManagementURL.String())
mgmClient, err := mgm.NewClient(ctx, config.ManagementURL.Host, myPrivateKey, mgmTlsEnabled)
if err != nil {
log.Errorf("failed connecting to Management Service %s %v", config.ManagementURL.String(), err)
log.Errorf("failed connecting to the Management service %s %v", config.ManagementURL.String(), err)
return err
}
log.Debugf("connected to management Service %s", config.ManagementURL.String())
log.Debugf("connected to the Management service %s", config.ManagementURL.String())
defer func() {
err = mgmClient.Close()
if err != nil {
cStatus, ok := status.FromError(err)
if !ok || ok && cStatus.Code() != codes.Canceled {
log.Warnf("failed to close the Management service client, err: %v", err)
}
}
}()
serverKey, err := mgmClient.GetServerPublicKey()
if err != nil {
@@ -53,7 +62,7 @@ func Login(ctx context.Context, config *Config, setupKey string, jwtToken string
err = mgmClient.Close()
if err != nil {
log.Errorf("failed closing Management Service client: %v", err)
log.Errorf("failed to close the Management service client: %v", err)
return err
}

View File

@@ -3,6 +3,7 @@ package peer
import (
"context"
nbStatus "github.com/netbirdio/netbird/client/status"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/iface"
"golang.zx2c4.com/wireguard/wgctrl"
"net"
@@ -36,6 +37,20 @@ type ConnConfig struct {
UDPMux ice.UDPMux
UDPMuxSrflx ice.UniversalUDPMux
LocalWgPort int
}
// OfferAnswer represents a session establishment offer or answer
type OfferAnswer struct {
IceCredentials IceCredentials
// WgListenPort is a remote WireGuard listen port.
// This field is used when establishing a direct WireGuard connection without any proxy.
// We can set the remote peer's endpoint with this port.
WgListenPort int
// Version of NetBird Agent
Version string
}
// IceCredentials ICE protocol credentials struct
@@ -51,13 +66,13 @@ type Conn struct {
// signalCandidate is a handler function to signal remote peer about local connection candidate
signalCandidate func(candidate ice.Candidate) error
// signalOffer is a handler function to signal remote peer our connection offer (credentials)
signalOffer func(uFrag string, pwd string) error
signalAnswer func(uFrag string, pwd string) error
signalOffer func(OfferAnswer) error
signalAnswer func(OfferAnswer) error
// remoteOffersCh is a channel used to wait for remote credentials to proceed with the connection
remoteOffersCh chan IceCredentials
remoteOffersCh chan OfferAnswer
// remoteAnswerCh is a channel used to wait for remote credentials answer (confirmation of our offer) to proceed with the connection
remoteAnswerCh chan IceCredentials
remoteAnswerCh chan OfferAnswer
closeCh chan struct{}
ctx context.Context
notifyDisconnected context.CancelFunc
@@ -88,8 +103,8 @@ func NewConn(config ConnConfig, statusRecorder *nbStatus.Status) (*Conn, error)
mu: sync.Mutex{},
status: StatusDisconnected,
closeCh: make(chan struct{}),
remoteOffersCh: make(chan IceCredentials),
remoteAnswerCh: make(chan IceCredentials),
remoteOffersCh: make(chan OfferAnswer),
remoteAnswerCh: make(chan OfferAnswer),
statusRecorder: statusRecorder,
}, nil
}
@@ -200,15 +215,15 @@ func (conn *Conn) Open() error {
// Only continue once we got a connection confirmation from the remote peer.
// The connection timeout could have happened before a confirmation received from the remote.
// The connection could have also been closed externally (e.g. when we received an update from the management that peer shouldn't be connected)
var remoteCredentials IceCredentials
var remoteOfferAnswer OfferAnswer
select {
case remoteCredentials = <-conn.remoteOffersCh:
case remoteOfferAnswer = <-conn.remoteOffersCh:
// received confirmation from the remote peer -> ready to proceed
err = conn.sendAnswer()
if err != nil {
return err
}
case remoteCredentials = <-conn.remoteAnswerCh:
case remoteOfferAnswer = <-conn.remoteAnswerCh:
case <-time.After(conn.config.Timeout):
return NewConnectionTimeoutError(conn.config.Key, conn.config.Timeout)
case <-conn.closeCh:
@@ -216,7 +231,8 @@ func (conn *Conn) Open() error {
return NewConnectionClosedError(conn.config.Key)
}
log.Debugf("received connection confirmation from peer %s", conn.config.Key)
log.Debugf("received connection confirmation from peer %s running version %s and with remote WireGuard listen port %d",
conn.config.Key, remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort)
// at this point we received offer/answer and we are ready to gather candidates
conn.mu.Lock()
@@ -245,16 +261,21 @@ func (conn *Conn) Open() error {
isControlling := conn.config.LocalKey > conn.config.Key
var remoteConn *ice.Conn
if isControlling {
remoteConn, err = conn.agent.Dial(conn.ctx, remoteCredentials.UFrag, remoteCredentials.Pwd)
remoteConn, err = conn.agent.Dial(conn.ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
} else {
remoteConn, err = conn.agent.Accept(conn.ctx, remoteCredentials.UFrag, remoteCredentials.Pwd)
remoteConn, err = conn.agent.Accept(conn.ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
}
if err != nil {
return err
}
// dynamically set remote WireGuard port is other side specified a different one from the default one
remoteWgPort := iface.DefaultWgPort
if remoteOfferAnswer.WgListenPort != 0 {
remoteWgPort = remoteOfferAnswer.WgListenPort
}
// the ice connection has been established successfully so we are ready to start the proxy
err = conn.startProxy(remoteConn)
err = conn.startProxy(remoteConn, remoteWgPort)
if err != nil {
return err
}
@@ -319,7 +340,7 @@ func IsPublicIP(ip net.IP) bool {
}
// startProxy starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected
func (conn *Conn) startProxy(remoteConn net.Conn) error {
func (conn *Conn) startProxy(remoteConn net.Conn, remoteWgPort int) error {
conn.mu.Lock()
defer conn.mu.Unlock()
@@ -336,7 +357,7 @@ func (conn *Conn) startProxy(remoteConn net.Conn) error {
p = proxy.NewWireguardProxy(conn.config.ProxyConfig)
peerState.Direct = false
} else {
p = proxy.NewNoProxy(conn.config.ProxyConfig)
p = proxy.NewNoProxy(conn.config.ProxyConfig, remoteWgPort)
peerState.Direct = true
}
conn.proxy = p
@@ -409,12 +430,12 @@ func (conn *Conn) cleanup() error {
}
// SetSignalOffer sets a handler function to be triggered by Conn when a new connection offer has to be signalled to the remote peer
func (conn *Conn) SetSignalOffer(handler func(uFrag string, pwd string) error) {
func (conn *Conn) SetSignalOffer(handler func(offer OfferAnswer) error) {
conn.signalOffer = handler
}
// SetSignalAnswer sets a handler function to be triggered by Conn when a new connection answer has to be signalled to the remote peer
func (conn *Conn) SetSignalAnswer(handler func(uFrag string, pwd string) error) {
func (conn *Conn) SetSignalAnswer(handler func(answer OfferAnswer) error) {
conn.signalAnswer = handler
}
@@ -459,8 +480,12 @@ func (conn *Conn) sendAnswer() error {
return err
}
log.Debugf("sending asnwer to %s", conn.config.Key)
err = conn.signalAnswer(localUFrag, localPwd)
log.Debugf("sending answer to %s", conn.config.Key)
err = conn.signalAnswer(OfferAnswer{
IceCredentials: IceCredentials{localUFrag, localPwd},
WgListenPort: conn.config.LocalWgPort,
Version: system.NetbirdVersion(),
})
if err != nil {
return err
}
@@ -477,7 +502,11 @@ func (conn *Conn) sendOffer() error {
if err != nil {
return err
}
err = conn.signalOffer(localUFrag, localPwd)
err = conn.signalOffer(OfferAnswer{
IceCredentials: IceCredentials{localUFrag, localPwd},
WgListenPort: conn.config.LocalWgPort,
Version: system.NetbirdVersion(),
})
if err != nil {
return err
}
@@ -518,11 +547,11 @@ func (conn *Conn) Status() ConnStatus {
// OnRemoteOffer handles an offer from the remote peer and returns true if the message was accepted, false otherwise
// doesn't block, discards the message if connection wasn't ready
func (conn *Conn) OnRemoteOffer(remoteAuth IceCredentials) bool {
func (conn *Conn) OnRemoteOffer(offer OfferAnswer) bool {
log.Debugf("OnRemoteOffer from peer %s on status %s", conn.config.Key, conn.status.String())
select {
case conn.remoteOffersCh <- remoteAuth:
case conn.remoteOffersCh <- offer:
return true
default:
log.Debugf("OnRemoteOffer skipping message from peer %s on status %s because is not ready", conn.config.Key, conn.status.String())
@@ -533,11 +562,11 @@ func (conn *Conn) OnRemoteOffer(remoteAuth IceCredentials) bool {
// OnRemoteAnswer handles an offer from the remote peer and returns true if the message was accepted, false otherwise
// doesn't block, discards the message if connection wasn't ready
func (conn *Conn) OnRemoteAnswer(remoteAuth IceCredentials) bool {
func (conn *Conn) OnRemoteAnswer(answer OfferAnswer) bool {
log.Debugf("OnRemoteAnswer from peer %s on status %s", conn.config.Key, conn.status.String())
select {
case conn.remoteAnswerCh <- remoteAuth:
case conn.remoteAnswerCh <- answer:
return true
default:
// connection might not be ready yet to receive so we ignore the message

View File

@@ -18,6 +18,7 @@ var connConf = ConnConfig{
InterfaceBlackList: nil,
Timeout: time.Second,
ProxyConfig: proxy.Config{},
LocalWgPort: 51820,
}
func TestNewConn_interfaceFilter(t *testing.T) {
@@ -59,9 +60,13 @@ func TestConn_OnRemoteOffer(t *testing.T) {
go func() {
for {
accepted := conn.OnRemoteOffer(IceCredentials{
UFrag: "test",
Pwd: "test",
accepted := conn.OnRemoteOffer(OfferAnswer{
IceCredentials: IceCredentials{
UFrag: "test",
Pwd: "test",
},
WgListenPort: 0,
Version: "",
})
if accepted {
wg.Done()
@@ -89,9 +94,13 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
go func() {
for {
accepted := conn.OnRemoteAnswer(IceCredentials{
UFrag: "test",
Pwd: "test",
accepted := conn.OnRemoteAnswer(OfferAnswer{
IceCredentials: IceCredentials{
UFrag: "test",
Pwd: "test",
},
WgListenPort: 0,
Version: "",
})
if accepted {
wg.Done()

View File

@@ -1,7 +1,6 @@
package proxy
import (
"github.com/netbirdio/netbird/iface"
log "github.com/sirupsen/logrus"
"net"
)
@@ -14,10 +13,14 @@ import (
// In order NoProxy to work, Wireguard port has to be fixed for the time being.
type NoProxy struct {
config Config
// RemoteWgListenPort is a WireGuard port of a remote peer.
// It is used instead of the hardcoded 51820 port.
RemoteWgListenPort int
}
func NewNoProxy(config Config) *NoProxy {
return &NoProxy{config: config}
// NewNoProxy creates a new NoProxy with a provided config and remote peer's WireGuard listen port
func NewNoProxy(config Config, remoteWgPort int) *NoProxy {
return &NoProxy{config: config, RemoteWgListenPort: remoteWgPort}
}
func (p *NoProxy) Close() error {
@@ -36,7 +39,7 @@ func (p *NoProxy) Start(remoteConn net.Conn) error {
if err != nil {
return err
}
addr.Port = iface.DefaultWgPort
addr.Port = p.RemoteWgListenPort
err = p.config.WgInterface.UpdatePeer(p.config.RemoteKey, p.config.AllowedIps, DefaultWgKeepAlive,
addr, p.config.PreSharedKey)

View File

@@ -0,0 +1,285 @@
package routemanager
import (
"context"
"fmt"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/status"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route"
log "github.com/sirupsen/logrus"
"net/netip"
)
type routerPeerStatus struct {
connected bool
relayed bool
direct bool
}
type routesUpdate struct {
updateSerial uint64
routes []*route.Route
}
type clientNetwork struct {
ctx context.Context
stop context.CancelFunc
statusRecorder *status.Status
wgInterface *iface.WGIface
routes map[string]*route.Route
routeUpdate chan routesUpdate
peerStateUpdate chan struct{}
routePeersNotifiers map[string]chan struct{}
chosenRoute *route.Route
network netip.Prefix
updateSerial uint64
}
func newClientNetworkWatcher(ctx context.Context, wgInterface *iface.WGIface, statusRecorder *status.Status, network netip.Prefix) *clientNetwork {
ctx, cancel := context.WithCancel(ctx)
client := &clientNetwork{
ctx: ctx,
stop: cancel,
statusRecorder: statusRecorder,
wgInterface: wgInterface,
routes: make(map[string]*route.Route),
routePeersNotifiers: make(map[string]chan struct{}),
routeUpdate: make(chan routesUpdate),
peerStateUpdate: make(chan struct{}),
network: network,
}
return client
}
func getClientNetworkID(input *route.Route) string {
return input.NetID + "-" + input.Network.String()
}
func (c *clientNetwork) getRouterPeerStatuses() map[string]routerPeerStatus {
routePeerStatuses := make(map[string]routerPeerStatus)
for _, r := range c.routes {
peerStatus, err := c.statusRecorder.GetPeer(r.Peer)
if err != nil {
log.Debugf("couldn't fetch peer state: %v", err)
continue
}
routePeerStatuses[r.ID] = routerPeerStatus{
connected: peerStatus.ConnStatus == peer.StatusConnected.String(),
relayed: peerStatus.Relayed,
direct: peerStatus.Direct,
}
}
return routePeerStatuses
}
func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]routerPeerStatus) string {
var chosen string
chosenScore := 0
currID := ""
if c.chosenRoute != nil {
currID = c.chosenRoute.ID
}
for _, r := range c.routes {
tempScore := 0
peerStatus, found := routePeerStatuses[r.ID]
if !found || !peerStatus.connected {
continue
}
if r.Metric < route.MaxMetric {
metricDiff := route.MaxMetric - r.Metric
tempScore = metricDiff * 10
}
if !peerStatus.relayed {
tempScore++
}
if !peerStatus.direct {
tempScore++
}
if tempScore > chosenScore || (tempScore == chosenScore && currID == r.ID) {
chosen = r.ID
chosenScore = tempScore
}
}
if chosen == "" {
var peers []string
for _, r := range c.routes {
peers = append(peers, r.Peer)
}
log.Warnf("no route was chosen for network %s because no peers from list %s were connected", c.network, peers)
} else if chosen != currID {
log.Infof("new chosen route is %s with peer %s with score %d", chosen, c.routes[chosen].Peer, chosenScore)
}
return chosen
}
func (c *clientNetwork) watchPeerStatusChanges(ctx context.Context, peerKey string, peerStateUpdate chan struct{}, closer chan struct{}) {
for {
select {
case <-ctx.Done():
return
case <-closer:
return
case <-c.statusRecorder.GetPeerStateChangeNotifier(peerKey):
state, err := c.statusRecorder.GetPeer(peerKey)
if err != nil || state.ConnStatus == peer.StatusConnecting.String() {
continue
}
peerStateUpdate <- struct{}{}
log.Debugf("triggered route state update for Peer %s, state: %s", peerKey, state.ConnStatus)
}
}
}
func (c *clientNetwork) startPeersStatusChangeWatcher() {
for _, r := range c.routes {
_, found := c.routePeersNotifiers[r.Peer]
if !found {
c.routePeersNotifiers[r.Peer] = make(chan struct{})
go c.watchPeerStatusChanges(c.ctx, r.Peer, c.peerStateUpdate, c.routePeersNotifiers[r.Peer])
}
}
}
func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error {
state, err := c.statusRecorder.GetPeer(peerKey)
if err != nil || state.ConnStatus != peer.StatusConnected.String() {
return nil
}
err = c.wgInterface.RemoveAllowedIP(peerKey, c.network.String())
if err != nil {
return fmt.Errorf("couldn't remove allowed IP %s removed for peer %s, err: %v",
c.network, c.chosenRoute.Peer, err)
}
return nil
}
func (c *clientNetwork) removeRouteFromPeerAndSystem() error {
if c.chosenRoute != nil {
err := c.removeRouteFromWireguardPeer(c.chosenRoute.Peer)
if err != nil {
return err
}
err = removeFromRouteTableIfNonSystem(c.network, c.wgInterface.GetAddress().IP.String())
if err != nil {
return fmt.Errorf("couldn't remove route %s from system, err: %v",
c.network, err)
}
}
return nil
}
func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
var err error
routerPeerStatuses := c.getRouterPeerStatuses()
chosen := c.getBestRouteFromStatuses(routerPeerStatuses)
if chosen == "" {
err = c.removeRouteFromPeerAndSystem()
if err != nil {
return err
}
c.chosenRoute = nil
return nil
}
if c.chosenRoute != nil && c.chosenRoute.ID == chosen {
if c.chosenRoute.IsEqual(c.routes[chosen]) {
return nil
}
}
if c.chosenRoute != nil {
err = c.removeRouteFromWireguardPeer(c.chosenRoute.Peer)
if err != nil {
return err
}
} else {
err = addToRouteTableIfNoExists(c.network, c.wgInterface.GetAddress().IP.String())
if err != nil {
return fmt.Errorf("route %s couldn't be added for peer %s, err: %v",
c.network.String(), c.wgInterface.GetAddress().IP.String(), err)
}
}
c.chosenRoute = c.routes[chosen]
err = c.wgInterface.AddAllowedIP(c.chosenRoute.Peer, c.network.String())
if err != nil {
log.Errorf("couldn't add allowed IP %s added for peer %s, err: %v",
c.network, c.chosenRoute.Peer, err)
}
return nil
}
func (c *clientNetwork) sendUpdateToClientNetworkWatcher(update routesUpdate) {
go func() {
c.routeUpdate <- update
}()
}
func (c *clientNetwork) handleUpdate(update routesUpdate) {
updateMap := make(map[string]*route.Route)
for _, r := range update.routes {
updateMap[r.ID] = r
}
for id, r := range c.routes {
_, found := updateMap[id]
if !found {
close(c.routePeersNotifiers[r.Peer])
delete(c.routePeersNotifiers, r.Peer)
}
}
c.routes = updateMap
}
// peersStateAndUpdateWatcher is the main point of reacting on client network routing events.
// All the processing related to the client network should be done here. Thread-safe.
func (c *clientNetwork) peersStateAndUpdateWatcher() {
for {
select {
case <-c.ctx.Done():
log.Debugf("stopping watcher for network %s", c.network)
err := c.removeRouteFromPeerAndSystem()
if err != nil {
log.Error(err)
}
return
case <-c.peerStateUpdate:
err := c.recalculateRouteAndUpdatePeerAndSystem()
if err != nil {
log.Error(err)
}
case update := <-c.routeUpdate:
if update.updateSerial < c.updateSerial {
log.Warnf("received a routes update with smaller serial number, ignoring it")
continue
}
log.Debugf("received a new client network route update for %s", c.network)
c.handleUpdate(update)
c.updateSerial = update.updateSerial
err := c.recalculateRouteAndUpdatePeerAndSystem()
if err != nil {
log.Error(err)
}
c.startPeersStatusChangeWatcher()
}
}
}

View File

@@ -0,0 +1,75 @@
package routemanager
var insertRuleTestCases = []struct {
name string
inputPair routerPair
ipVersion string
}{
{
name: "Insert Forwarding IPV4 Rule",
inputPair: routerPair{
ID: "zxa",
source: "100.100.100.1/32",
destination: "100.100.200.0/24",
masquerade: false,
},
ipVersion: ipv4,
},
{
name: "Insert Forwarding And Nat IPV4 Rules",
inputPair: routerPair{
ID: "zxa",
source: "100.100.100.1/32",
destination: "100.100.200.0/24",
masquerade: true,
},
ipVersion: ipv4,
},
{
name: "Insert Forwarding IPV6 Rule",
inputPair: routerPair{
ID: "zxa",
source: "fc00::1/128",
destination: "fc12::/64",
masquerade: false,
},
ipVersion: ipv6,
},
{
name: "Insert Forwarding And Nat IPV6 Rules",
inputPair: routerPair{
ID: "zxa",
source: "fc00::1/128",
destination: "fc12::/64",
masquerade: true,
},
ipVersion: ipv6,
},
}
var removeRuleTestCases = []struct {
name string
inputPair routerPair
ipVersion string
}{
{
name: "Remove Forwarding And Nat IPV4 Rules",
inputPair: routerPair{
ID: "zxa",
source: "100.100.100.1/32",
destination: "100.100.200.0/24",
masquerade: true,
},
ipVersion: ipv4,
},
{
name: "Remove Forwarding And Nat IPV6 Rules",
inputPair: routerPair{
ID: "zxa",
source: "fc00::1/128",
destination: "fc12::/64",
masquerade: true,
},
ipVersion: ipv6,
},
}

View File

@@ -0,0 +1,12 @@
package routemanager
type firewallManager interface {
// RestoreOrCreateContainers restores or creates a firewall container set of rules, tables and default rules
RestoreOrCreateContainers() error
// InsertRoutingRules inserts a routing firewall rule
InsertRoutingRules(pair routerPair) error
// RemoveRoutingRules removes a routing firewall rule
RemoveRoutingRules(pair routerPair) error
// CleanRoutingRules cleans a firewall set of containers
CleanRoutingRules()
}

View File

@@ -0,0 +1,55 @@
package routemanager
import (
"context"
"fmt"
"github.com/coreos/go-iptables/iptables"
log "github.com/sirupsen/logrus"
)
import "github.com/google/nftables"
const (
ipv6Forwarding = "netbird-rt-ipv6-forwarding"
ipv4Forwarding = "netbird-rt-ipv4-forwarding"
ipv6Nat = "netbird-rt-ipv6-nat"
ipv4Nat = "netbird-rt-ipv4-nat"
natFormat = "netbird-nat-%s"
forwardingFormat = "netbird-fwd-%s"
ipv6 = "ipv6"
ipv4 = "ipv4"
)
func genKey(format string, input string) string {
return fmt.Sprintf(format, input)
}
// NewFirewall if supported, returns an iptables manager, otherwise returns a nftables manager
func NewFirewall(parentCTX context.Context) firewallManager {
ctx, cancel := context.WithCancel(parentCTX)
if isIptablesSupported() {
log.Debugf("iptables is supported")
ipv4Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
ipv6Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv6)
return &iptablesManager{
ctx: ctx,
stop: cancel,
ipv4Client: ipv4Client,
ipv6Client: ipv6Client,
rules: make(map[string]map[string][]string),
}
}
log.Debugf("iptables is not supported, using nftables")
manager := &nftablesManager{
ctx: ctx,
stop: cancel,
conn: &nftables.Conn{},
chains: make(map[string]map[string]*nftables.Chain),
rules: make(map[string]*nftables.Rule),
}
return manager
}

View File

@@ -0,0 +1,27 @@
//go:build !linux
// +build !linux
package routemanager
import "context"
type unimplementedFirewall struct{}
func (unimplementedFirewall) RestoreOrCreateContainers() error {
return nil
}
func (unimplementedFirewall) InsertRoutingRules(pair routerPair) error {
return nil
}
func (unimplementedFirewall) RemoveRoutingRules(pair routerPair) error {
return nil
}
func (unimplementedFirewall) CleanRoutingRules() {
return
}
// NewFirewall returns an unimplemented Firewall manager
func NewFirewall(parentCtx context.Context) firewallManager {
return unimplementedFirewall{}
}

View File

@@ -0,0 +1,403 @@
package routemanager
import (
"context"
"fmt"
"github.com/coreos/go-iptables/iptables"
log "github.com/sirupsen/logrus"
"net/netip"
"os/exec"
"strings"
"sync"
)
func isIptablesSupported() bool {
_, err4 := exec.LookPath("iptables")
_, err6 := exec.LookPath("ip6tables")
return err4 == nil && err6 == nil
}
// constants needed to manage and create iptable rules
const (
iptablesFilterTable = "filter"
iptablesNatTable = "nat"
iptablesForwardChain = "FORWARD"
iptablesPostRoutingChain = "POSTROUTING"
iptablesRoutingNatChain = "NETBIRD-RT-NAT"
iptablesRoutingForwardingChain = "NETBIRD-RT-FWD"
routingFinalForwardJump = "ACCEPT"
routingFinalNatJump = "MASQUERADE"
)
// some presets for building nftable rules
var (
iptablesDefaultForwardingRule = []string{"-j", iptablesRoutingForwardingChain, "-m", "comment", "--comment"}
iptablesDefaultNetbirdForwardingRule = []string{"-j", "RETURN"}
iptablesDefaultNatRule = []string{"-j", iptablesRoutingNatChain, "-m", "comment", "--comment"}
iptablesDefaultNetbirdNatRule = []string{"-j", "RETURN"}
)
type iptablesManager struct {
ctx context.Context
stop context.CancelFunc
ipv4Client *iptables.IPTables
ipv6Client *iptables.IPTables
rules map[string]map[string][]string
mux sync.Mutex
}
// CleanRoutingRules cleans existing iptables resources that we created by the agent
func (i *iptablesManager) CleanRoutingRules() {
i.mux.Lock()
defer i.mux.Unlock()
err := i.cleanJumpRules()
if err != nil {
log.Error(err)
}
log.Debug("flushing tables")
errMSGFormat := "iptables: failed cleaning %s chain %s,error: %v"
err = i.ipv4Client.ClearAndDeleteChain(iptablesFilterTable, iptablesRoutingForwardingChain)
if err != nil {
log.Errorf(errMSGFormat, ipv4, iptablesRoutingForwardingChain, err)
}
err = i.ipv4Client.ClearAndDeleteChain(iptablesNatTable, iptablesRoutingNatChain)
if err != nil {
log.Errorf(errMSGFormat, ipv4, iptablesRoutingNatChain, err)
}
err = i.ipv6Client.ClearAndDeleteChain(iptablesFilterTable, iptablesRoutingForwardingChain)
if err != nil {
log.Errorf(errMSGFormat, ipv6, iptablesRoutingForwardingChain, err)
}
err = i.ipv6Client.ClearAndDeleteChain(iptablesNatTable, iptablesRoutingNatChain)
if err != nil {
log.Errorf(errMSGFormat, ipv6, iptablesRoutingNatChain, err)
}
log.Info("done cleaning up iptables rules")
}
// RestoreOrCreateContainers restores existing iptables containers (chains and rules)
// if they don't exist, we create them
func (i *iptablesManager) RestoreOrCreateContainers() error {
i.mux.Lock()
defer i.mux.Unlock()
if i.rules[ipv4][ipv4Forwarding] != nil && i.rules[ipv6][ipv6Forwarding] != nil {
return nil
}
errMSGFormat := "iptables: failed creating %s chain %s,error: %v"
err := createChain(i.ipv4Client, iptablesFilterTable, iptablesRoutingForwardingChain)
if err != nil {
return fmt.Errorf(errMSGFormat, ipv4, iptablesRoutingForwardingChain, err)
}
err = createChain(i.ipv4Client, iptablesNatTable, iptablesRoutingNatChain)
if err != nil {
return fmt.Errorf(errMSGFormat, ipv4, iptablesRoutingNatChain, err)
}
err = createChain(i.ipv6Client, iptablesFilterTable, iptablesRoutingForwardingChain)
if err != nil {
return fmt.Errorf(errMSGFormat, ipv6, iptablesRoutingForwardingChain, err)
}
err = createChain(i.ipv6Client, iptablesNatTable, iptablesRoutingNatChain)
if err != nil {
return fmt.Errorf(errMSGFormat, ipv6, iptablesRoutingNatChain, err)
}
err = i.restoreRules(i.ipv4Client)
if err != nil {
return fmt.Errorf("iptables: error while restoring ipv4 rules: %v", err)
}
err = i.restoreRules(i.ipv6Client)
if err != nil {
return fmt.Errorf("iptables: error while restoring ipv6 rules: %v", err)
}
err = i.addJumpRules()
if err != nil {
return fmt.Errorf("iptables: error while creating jump rules: %v", err)
}
return nil
}
// addJumpRules create jump rules to send packets to NetBird chains
func (i *iptablesManager) addJumpRules() error {
err := i.cleanJumpRules()
if err != nil {
return err
}
rule := append(iptablesDefaultForwardingRule, ipv4Forwarding)
err = i.ipv4Client.Insert(iptablesFilterTable, iptablesForwardChain, 1, rule...)
if err != nil {
return err
}
i.rules[ipv4][ipv4Forwarding] = rule
rule = append(iptablesDefaultNatRule, ipv4Nat)
err = i.ipv4Client.Insert(iptablesNatTable, iptablesPostRoutingChain, 1, rule...)
if err != nil {
return err
}
i.rules[ipv4][ipv4Nat] = rule
rule = append(iptablesDefaultForwardingRule, ipv6Forwarding)
err = i.ipv6Client.Insert(iptablesFilterTable, iptablesForwardChain, 1, rule...)
if err != nil {
return err
}
i.rules[ipv6][ipv6Forwarding] = rule
rule = append(iptablesDefaultNatRule, ipv6Nat)
err = i.ipv6Client.Insert(iptablesNatTable, iptablesPostRoutingChain, 1, rule...)
if err != nil {
return err
}
i.rules[ipv6][ipv6Nat] = rule
return nil
}
// cleanJumpRules cleans jump rules that was sending packets to NetBird chains
func (i *iptablesManager) cleanJumpRules() error {
var err error
errMSGFormat := "iptables: failed cleaning rule from %s chain %s,err: %v"
rule, found := i.rules[ipv4][ipv4Forwarding]
if found {
log.Debugf("iptables: removing %s rule: %s ", ipv4, ipv4Forwarding)
err = i.ipv4Client.DeleteIfExists(iptablesFilterTable, iptablesForwardChain, rule...)
if err != nil {
return fmt.Errorf(errMSGFormat, ipv4, iptablesForwardChain, err)
}
}
rule, found = i.rules[ipv4][ipv4Nat]
if found {
log.Debugf("iptables: removing %s rule: %s ", ipv4, ipv4Nat)
err = i.ipv4Client.DeleteIfExists(iptablesNatTable, iptablesPostRoutingChain, rule...)
if err != nil {
return fmt.Errorf(errMSGFormat, ipv4, iptablesPostRoutingChain, err)
}
}
rule, found = i.rules[ipv6][ipv6Forwarding]
if found {
log.Debugf("iptables: removing %s rule: %s ", ipv6, ipv6Forwarding)
err = i.ipv6Client.DeleteIfExists(iptablesFilterTable, iptablesForwardChain, rule...)
if err != nil {
return fmt.Errorf(errMSGFormat, ipv6, iptablesForwardChain, err)
}
}
rule, found = i.rules[ipv6][ipv6Nat]
if found {
log.Debugf("iptables: removing %s rule: %s ", ipv6, ipv6Nat)
err = i.ipv6Client.DeleteIfExists(iptablesNatTable, iptablesPostRoutingChain, rule...)
if err != nil {
return fmt.Errorf(errMSGFormat, ipv6, iptablesPostRoutingChain, err)
}
}
return nil
}
func iptablesProtoToString(proto iptables.Protocol) string {
if proto == iptables.ProtocolIPv6 {
return ipv6
}
return ipv4
}
// restoreRules restores existing NetBird rules
func (i *iptablesManager) restoreRules(iptablesClient *iptables.IPTables) error {
ipVersion := iptablesProtoToString(iptablesClient.Proto())
if i.rules[ipVersion] == nil {
i.rules[ipVersion] = make(map[string][]string)
}
table := iptablesFilterTable
for _, chain := range []string{iptablesForwardChain, iptablesRoutingForwardingChain} {
rules, err := iptablesClient.List(table, chain)
if err != nil {
return err
}
for _, ruleString := range rules {
rule := strings.Fields(ruleString)
id := getRuleRouteID(rule)
if id != "" {
i.rules[ipVersion][id] = rule[2:]
}
}
}
table = iptablesNatTable
for _, chain := range []string{iptablesPostRoutingChain, iptablesRoutingNatChain} {
rules, err := iptablesClient.List(table, chain)
if err != nil {
return err
}
for _, ruleString := range rules {
rule := strings.Fields(ruleString)
id := getRuleRouteID(rule)
if id != "" {
i.rules[ipVersion][id] = rule[2:]
}
}
}
return nil
}
// createChain create NetBird chains
func createChain(iptables *iptables.IPTables, table, newChain string) error {
chains, err := iptables.ListChains(table)
if err != nil {
return fmt.Errorf("couldn't get %s %s table chains, error: %v", iptablesProtoToString(iptables.Proto()), table, err)
}
shouldCreateChain := true
for _, chain := range chains {
if chain == newChain {
shouldCreateChain = false
}
}
if shouldCreateChain {
err = iptables.NewChain(table, newChain)
if err != nil {
return fmt.Errorf("couldn't create %s chain %s in %s table, error: %v", iptablesProtoToString(iptables.Proto()), newChain, table, err)
}
if table == iptablesNatTable {
err = iptables.Append(table, newChain, iptablesDefaultNetbirdNatRule...)
} else {
err = iptables.Append(table, newChain, iptablesDefaultNetbirdForwardingRule...)
}
if err != nil {
return fmt.Errorf("couldn't create %s chain %s default rule, error: %v", iptablesProtoToString(iptables.Proto()), newChain, err)
}
}
return nil
}
// genRuleSpec generates rule specification with comment identifier
func genRuleSpec(jump, id, source, destination string) []string {
return []string{"-s", source, "-d", destination, "-j", jump, "-m", "comment", "--comment", id}
}
// getRuleRouteID returns the rule ID if matches our prefix
func getRuleRouteID(rule []string) string {
for i, flag := range rule {
if flag == "--comment" {
id := rule[i+1]
if strings.HasPrefix(id, "netbird-") {
return id
}
}
}
return ""
}
// InsertRoutingRules inserts an iptables rule pair to the forwarding chain and if enabled, to the nat chain
func (i *iptablesManager) InsertRoutingRules(pair routerPair) error {
i.mux.Lock()
defer i.mux.Unlock()
var err error
prefix := netip.MustParsePrefix(pair.source)
ipVersion := ipv4
iptablesClient := i.ipv4Client
if prefix.Addr().Unmap().Is6() {
iptablesClient = i.ipv6Client
ipVersion = ipv6
}
forwardRuleKey := genKey(forwardingFormat, pair.ID)
forwardRule := genRuleSpec(routingFinalForwardJump, forwardRuleKey, pair.source, pair.destination)
existingRule, found := i.rules[ipVersion][forwardRuleKey]
if found {
err = iptablesClient.DeleteIfExists(iptablesFilterTable, iptablesRoutingForwardingChain, existingRule...)
if err != nil {
return fmt.Errorf("iptables: error while removing existing forwarding rule for %s: %v", pair.destination, err)
}
delete(i.rules[ipVersion], forwardRuleKey)
}
err = iptablesClient.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, forwardRule...)
if err != nil {
return fmt.Errorf("iptables: error while adding new forwarding rule for %s: %v", pair.destination, err)
}
i.rules[ipVersion][forwardRuleKey] = forwardRule
if !pair.masquerade {
return nil
}
natRuleKey := genKey(natFormat, pair.ID)
natRule := genRuleSpec(routingFinalNatJump, natRuleKey, pair.source, pair.destination)
existingRule, found = i.rules[ipVersion][natRuleKey]
if found {
err = iptablesClient.DeleteIfExists(iptablesNatTable, iptablesRoutingNatChain, existingRule...)
if err != nil {
return fmt.Errorf("iptables: error while removing existing nat rulefor %s: %v", pair.destination, err)
}
delete(i.rules[ipVersion], natRuleKey)
}
err = iptablesClient.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, natRule...)
if err != nil {
return fmt.Errorf("iptables: error while adding new nat rulefor %s: %v", pair.destination, err)
}
i.rules[ipVersion][natRuleKey] = natRule
return nil
}
// RemoveRoutingRules removes an iptables rule pair from forwarding and nat chains
func (i *iptablesManager) RemoveRoutingRules(pair routerPair) error {
i.mux.Lock()
defer i.mux.Unlock()
var err error
prefix := netip.MustParsePrefix(pair.source)
ipVersion := ipv4
iptablesClient := i.ipv4Client
if prefix.Addr().Unmap().Is6() {
iptablesClient = i.ipv6Client
ipVersion = ipv6
}
forwardRuleKey := genKey(forwardingFormat, pair.ID)
existingRule, found := i.rules[ipVersion][forwardRuleKey]
if found {
err = iptablesClient.DeleteIfExists(iptablesFilterTable, iptablesRoutingForwardingChain, existingRule...)
if err != nil {
return fmt.Errorf("iptables: error while removing existing forwarding rule for %s: %v", pair.destination, err)
}
}
delete(i.rules[ipVersion], forwardRuleKey)
if !pair.masquerade {
return nil
}
natRuleKey := genKey(natFormat, pair.ID)
existingRule, found = i.rules[ipVersion][natRuleKey]
if found {
err = iptablesClient.DeleteIfExists(iptablesNatTable, iptablesRoutingNatChain, existingRule...)
if err != nil {
return fmt.Errorf("iptables: error while removing existing nat rule for %s: %v", pair.destination, err)
}
}
delete(i.rules[ipVersion], natRuleKey)
return nil
}

View File

@@ -0,0 +1,247 @@
package routemanager
import (
"context"
"github.com/coreos/go-iptables/iptables"
"github.com/stretchr/testify/require"
"testing"
)
func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
if !isIptablesSupported() {
t.SkipNow()
}
ctx, cancel := context.WithCancel(context.TODO())
ipv4Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
ipv6Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv6)
manager := &iptablesManager{
ctx: ctx,
stop: cancel,
ipv4Client: ipv4Client,
ipv6Client: ipv6Client,
rules: make(map[string]map[string][]string),
}
defer manager.CleanRoutingRules()
err := manager.RestoreOrCreateContainers()
require.NoError(t, err, "shouldn't return error")
require.Len(t, manager.rules, 2, "should have created maps for ipv4 and ipv6")
require.Len(t, manager.rules[ipv4], 2, "should have created minimal rules for ipv4")
exists, err := ipv4Client.Exists(iptablesFilterTable, iptablesForwardChain, manager.rules[ipv4][ipv4Forwarding]...)
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv4, iptablesFilterTable, iptablesForwardChain)
require.True(t, exists, "forwarding rule should exist")
exists, err = ipv4Client.Exists(iptablesNatTable, iptablesPostRoutingChain, manager.rules[ipv4][ipv4Nat]...)
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv4, iptablesNatTable, iptablesPostRoutingChain)
require.True(t, exists, "postrouting rule should exist")
require.Len(t, manager.rules[ipv6], 2, "should have created minimal rules for ipv6")
exists, err = ipv6Client.Exists(iptablesFilterTable, iptablesForwardChain, manager.rules[ipv6][ipv6Forwarding]...)
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv6, iptablesFilterTable, iptablesForwardChain)
require.True(t, exists, "forwarding rule should exist")
exists, err = ipv6Client.Exists(iptablesNatTable, iptablesPostRoutingChain, manager.rules[ipv6][ipv6Nat]...)
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv6, iptablesNatTable, iptablesPostRoutingChain)
require.True(t, exists, "postrouting rule should exist")
pair := routerPair{
ID: "abc",
source: "100.100.100.1/32",
destination: "100.100.100.0/24",
masquerade: true,
}
forward4RuleKey := genKey(forwardingFormat, pair.ID)
forward4Rule := genRuleSpec(routingFinalForwardJump, forward4RuleKey, pair.source, pair.destination)
err = ipv4Client.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, forward4Rule...)
require.NoError(t, err, "inserting rule should not return error")
nat4RuleKey := genKey(natFormat, pair.ID)
nat4Rule := genRuleSpec(routingFinalNatJump, nat4RuleKey, pair.source, pair.destination)
err = ipv4Client.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, nat4Rule...)
require.NoError(t, err, "inserting rule should not return error")
pair = routerPair{
ID: "abc",
source: "fc00::1/128",
destination: "fc11::/64",
masquerade: true,
}
forward6RuleKey := genKey(forwardingFormat, pair.ID)
forward6Rule := genRuleSpec(routingFinalForwardJump, forward6RuleKey, pair.source, pair.destination)
err = ipv6Client.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, forward6Rule...)
require.NoError(t, err, "inserting rule should not return error")
nat6RuleKey := genKey(natFormat, pair.ID)
nat6Rule := genRuleSpec(routingFinalNatJump, nat6RuleKey, pair.source, pair.destination)
err = ipv6Client.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, nat6Rule...)
require.NoError(t, err, "inserting rule should not return error")
delete(manager.rules, ipv4)
delete(manager.rules, ipv6)
err = manager.RestoreOrCreateContainers()
require.NoError(t, err, "shouldn't return error")
require.Len(t, manager.rules[ipv4], 4, "should have restored all rules for ipv4")
foundRule, found := manager.rules[ipv4][forward4RuleKey]
require.True(t, found, "forwarding rule should exist in the map")
require.Equal(t, forward4Rule[:4], foundRule[:4], "stored forwarding rule should match")
foundRule, found = manager.rules[ipv4][nat4RuleKey]
require.True(t, found, "nat rule should exist in the map")
require.Equal(t, nat4Rule[:4], foundRule[:4], "stored nat rule should match")
require.Len(t, manager.rules[ipv6], 4, "should have restored all rules for ipv6")
foundRule, found = manager.rules[ipv6][forward6RuleKey]
require.True(t, found, "forwarding rule should exist in the map")
require.Equal(t, forward6Rule[:4], foundRule[:4], "stored forward rule should match")
foundRule, found = manager.rules[ipv6][nat6RuleKey]
require.True(t, found, "nat rule should exist in the map")
require.Equal(t, nat6Rule[:4], foundRule[:4], "stored nat rule should match")
}
func TestIptablesManager_InsertRoutingRules(t *testing.T) {
if !isIptablesSupported() {
t.SkipNow()
}
for _, testCase := range insertRuleTestCases {
t.Run(testCase.name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.TODO())
ipv4Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
ipv6Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv6)
iptablesClient := ipv4Client
if testCase.ipVersion == ipv6 {
iptablesClient = ipv6Client
}
manager := &iptablesManager{
ctx: ctx,
stop: cancel,
ipv4Client: ipv4Client,
ipv6Client: ipv6Client,
rules: make(map[string]map[string][]string),
}
defer manager.CleanRoutingRules()
err := manager.RestoreOrCreateContainers()
require.NoError(t, err, "shouldn't return error")
err = manager.InsertRoutingRules(testCase.inputPair)
require.NoError(t, err, "forwarding pair should be inserted")
forwardRuleKey := genKey(forwardingFormat, testCase.inputPair.ID)
forwardRule := genRuleSpec(routingFinalForwardJump, forwardRuleKey, testCase.inputPair.source, testCase.inputPair.destination)
exists, err := iptablesClient.Exists(iptablesFilterTable, iptablesRoutingForwardingChain, forwardRule...)
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesFilterTable, iptablesRoutingForwardingChain)
require.True(t, exists, "forwarding rule should exist")
foundRule, found := manager.rules[testCase.ipVersion][forwardRuleKey]
require.True(t, found, "forwarding rule should exist in the manager map")
require.Equal(t, forwardRule[:4], foundRule[:4], "stored forwarding rule should match")
natRuleKey := genKey(natFormat, testCase.inputPair.ID)
natRule := genRuleSpec(routingFinalNatJump, natRuleKey, testCase.inputPair.source, testCase.inputPair.destination)
exists, err = iptablesClient.Exists(iptablesNatTable, iptablesRoutingNatChain, natRule...)
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesNatTable, iptablesRoutingNatChain)
if testCase.inputPair.masquerade {
require.True(t, exists, "nat rule should be created")
foundNatRule, foundNat := manager.rules[testCase.ipVersion][natRuleKey]
require.True(t, foundNat, "nat rule should exist in the map")
require.Equal(t, natRule[:4], foundNatRule[:4], "stored nat rule should match")
} else {
require.False(t, exists, "nat rule should not be created")
_, foundNat := manager.rules[testCase.ipVersion][natRuleKey]
require.False(t, foundNat, "nat rule should exist in the map")
}
})
}
}
func TestIptablesManager_RemoveRoutingRules(t *testing.T) {
if !isIptablesSupported() {
t.SkipNow()
}
for _, testCase := range removeRuleTestCases {
t.Run(testCase.name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.TODO())
ipv4Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
ipv6Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv6)
iptablesClient := ipv4Client
if testCase.ipVersion == ipv6 {
iptablesClient = ipv6Client
}
manager := &iptablesManager{
ctx: ctx,
stop: cancel,
ipv4Client: ipv4Client,
ipv6Client: ipv6Client,
rules: make(map[string]map[string][]string),
}
defer manager.CleanRoutingRules()
err := manager.RestoreOrCreateContainers()
require.NoError(t, err, "shouldn't return error")
forwardRuleKey := genKey(forwardingFormat, testCase.inputPair.ID)
forwardRule := genRuleSpec(routingFinalForwardJump, forwardRuleKey, testCase.inputPair.source, testCase.inputPair.destination)
err = iptablesClient.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, forwardRule...)
require.NoError(t, err, "inserting rule should not return error")
natRuleKey := genKey(natFormat, testCase.inputPair.ID)
natRule := genRuleSpec(routingFinalNatJump, natRuleKey, testCase.inputPair.source, testCase.inputPair.destination)
err = iptablesClient.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, natRule...)
require.NoError(t, err, "inserting rule should not return error")
delete(manager.rules, ipv4)
delete(manager.rules, ipv6)
err = manager.RestoreOrCreateContainers()
require.NoError(t, err, "shouldn't return error")
err = manager.RemoveRoutingRules(testCase.inputPair)
require.NoError(t, err, "shouldn't return error")
exists, err := iptablesClient.Exists(iptablesFilterTable, iptablesRoutingForwardingChain, forwardRule...)
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesFilterTable, iptablesRoutingForwardingChain)
require.False(t, exists, "forwarding rule should not exist")
_, found := manager.rules[testCase.ipVersion][forwardRuleKey]
require.False(t, found, "forwarding rule should exist in the manager map")
exists, err = iptablesClient.Exists(iptablesNatTable, iptablesRoutingNatChain, natRule...)
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesNatTable, iptablesRoutingNatChain)
require.False(t, exists, "nat rule should not exist")
_, found = manager.rules[testCase.ipVersion][natRuleKey]
require.False(t, found, "forwarding rule should exist in the manager map")
})
}
}

View File

@@ -0,0 +1,181 @@
package routemanager
import (
"context"
"fmt"
"github.com/netbirdio/netbird/client/status"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route"
log "github.com/sirupsen/logrus"
"runtime"
"sync"
)
// Manager is a route manager interface
type Manager interface {
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error
Stop()
}
// DefaultManager is the default instance of a route manager
type DefaultManager struct {
ctx context.Context
stop context.CancelFunc
mux sync.Mutex
clientNetworks map[string]*clientNetwork
serverRoutes map[string]*route.Route
serverRouter *serverRouter
statusRecorder *status.Status
wgInterface *iface.WGIface
pubKey string
}
// NewManager returns a new route manager
func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, statusRecorder *status.Status) *DefaultManager {
mCTX, cancel := context.WithCancel(ctx)
return &DefaultManager{
ctx: mCTX,
stop: cancel,
clientNetworks: make(map[string]*clientNetwork),
serverRoutes: make(map[string]*route.Route),
serverRouter: &serverRouter{
routes: make(map[string]*route.Route),
netForwardHistoryEnabled: isNetForwardHistoryEnabled(),
firewall: NewFirewall(ctx),
},
statusRecorder: statusRecorder,
wgInterface: wgInterface,
pubKey: pubKey,
}
}
// Stop stops the manager watchers and clean firewall rules
func (m *DefaultManager) Stop() {
m.stop()
m.serverRouter.firewall.CleanRoutingRules()
}
func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks map[string][]*route.Route) {
// removing routes that do not exist as per the update from the Management service.
for id, client := range m.clientNetworks {
_, found := networks[id]
if !found {
log.Debugf("stopping client network watcher, %s", id)
client.stop()
delete(m.clientNetworks, id)
}
}
for id, routes := range networks {
clientNetworkWatcher, found := m.clientNetworks[id]
if !found {
clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.wgInterface, m.statusRecorder, routes[0].Network)
m.clientNetworks[id] = clientNetworkWatcher
go clientNetworkWatcher.peersStateAndUpdateWatcher()
}
update := routesUpdate{
updateSerial: updateSerial,
routes: routes,
}
clientNetworkWatcher.sendUpdateToClientNetworkWatcher(update)
}
}
func (m *DefaultManager) updateServerRoutes(routesMap map[string]*route.Route) error {
serverRoutesToRemove := make([]string, 0)
if len(routesMap) > 0 {
err := m.serverRouter.firewall.RestoreOrCreateContainers()
if err != nil {
return fmt.Errorf("couldn't initialize firewall containers, got err: %v", err)
}
}
for routeID := range m.serverRoutes {
update, found := routesMap[routeID]
if !found || !update.IsEqual(m.serverRoutes[routeID]) {
serverRoutesToRemove = append(serverRoutesToRemove, routeID)
continue
}
}
for _, routeID := range serverRoutesToRemove {
oldRoute := m.serverRoutes[routeID]
err := m.removeFromServerNetwork(oldRoute)
if err != nil {
log.Errorf("unable to remove route id: %s, network %s, from server, got: %v",
oldRoute.ID, oldRoute.Network, err)
}
delete(m.serverRoutes, routeID)
}
for id, newRoute := range routesMap {
_, found := m.serverRoutes[id]
if found {
continue
}
err := m.addToServerNetwork(newRoute)
if err != nil {
log.Errorf("unable to add route %s from server, got: %v", newRoute.ID, err)
continue
}
m.serverRoutes[id] = newRoute
}
if len(m.serverRoutes) > 0 {
err := enableIPForwarding()
if err != nil {
return err
}
}
return nil
}
// UpdateRoutes compares received routes with existing routes and remove, update or add them to the client and server maps
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error {
select {
case <-m.ctx.Done():
log.Infof("not updating routes as context is closed")
return m.ctx.Err()
default:
m.mux.Lock()
defer m.mux.Unlock()
newClientRoutesIDMap := make(map[string][]*route.Route)
newServerRoutesMap := make(map[string]*route.Route)
for _, newRoute := range newRoutes {
// only linux is supported for now
if newRoute.Peer == m.pubKey {
if runtime.GOOS != "linux" {
log.Warnf("received a route to manage, but agent doesn't support router mode on %s OS", runtime.GOOS)
continue
}
newServerRoutesMap[newRoute.ID] = newRoute
} else {
// if prefix is too small, lets assume is a possible default route which is not yet supported
// we skip this route management
if newRoute.Network.Bits() < 7 {
log.Errorf("this agent version: %s, doesn't support default routes, received %s, skiping this route",
system.NetbirdVersion(), newRoute.Network)
continue
}
clientNetworkID := getClientNetworkID(newRoute)
newClientRoutesIDMap[clientNetworkID] = append(newClientRoutesIDMap[clientNetworkID], newRoute)
}
}
m.updateClientNetworks(updateSerial, newClientRoutesIDMap)
err := m.updateServerRoutes(newServerRoutesMap)
if err != nil {
return err
}
return nil
}
}

View File

@@ -0,0 +1,370 @@
package routemanager
import (
"context"
"fmt"
"github.com/netbirdio/netbird/client/status"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route"
"github.com/stretchr/testify/require"
"net/netip"
"runtime"
"testing"
)
// send 5 routes, one for server and 4 for clients, one normal and 2 HA and one small
// if linux host, should have one for server in map
// we should have 2 client manager
// 2 ranges in our routing table
const localPeerKey = "local"
const remotePeerKey1 = "remote1"
const remotePeerKey2 = "remote1"
func TestManagerUpdateRoutes(t *testing.T) {
testCases := []struct {
name string
inputInitRoutes []*route.Route
inputRoutes []*route.Route
inputSerial uint64
shouldCheckServerRoutes bool
serverRoutesExpected int
clientNetworkWatchersExpected int
}{
{
name: "Should create 2 client networks",
inputInitRoutes: []*route.Route{},
inputRoutes: []*route.Route{
{
ID: "a",
NetID: "routeA",
Peer: remotePeerKey1,
Network: netip.MustParsePrefix("100.64.251.250/30"),
NetworkType: route.IPv4Network,
Metric: 9999,
Masquerade: false,
Enabled: true,
},
{
ID: "b",
NetID: "routeB",
Peer: remotePeerKey1,
Network: netip.MustParsePrefix("8.8.8.8/32"),
NetworkType: route.IPv4Network,
Metric: 9999,
Masquerade: false,
Enabled: true,
},
},
inputSerial: 1,
clientNetworkWatchersExpected: 2,
},
{
name: "Should Create 2 Server Routes",
inputRoutes: []*route.Route{
{
ID: "a",
NetID: "routeA",
Peer: localPeerKey,
Network: netip.MustParsePrefix("100.64.252.250/30"),
NetworkType: route.IPv4Network,
Metric: 9999,
Masquerade: false,
Enabled: true,
},
{
ID: "b",
NetID: "routeB",
Peer: localPeerKey,
Network: netip.MustParsePrefix("8.8.8.9/32"),
NetworkType: route.IPv4Network,
Metric: 9999,
Masquerade: false,
Enabled: true,
},
},
inputSerial: 1,
shouldCheckServerRoutes: runtime.GOOS == "linux",
serverRoutesExpected: 2,
clientNetworkWatchersExpected: 0,
},
{
name: "Should Create 1 Route For Client And Server",
inputRoutes: []*route.Route{
{
ID: "a",
NetID: "routeA",
Peer: localPeerKey,
Network: netip.MustParsePrefix("100.64.30.250/30"),
NetworkType: route.IPv4Network,
Metric: 9999,
Masquerade: false,
Enabled: true,
},
{
ID: "b",
NetID: "routeB",
Peer: remotePeerKey1,
Network: netip.MustParsePrefix("8.8.9.9/32"),
NetworkType: route.IPv4Network,
Metric: 9999,
Masquerade: false,
Enabled: true,
},
},
inputSerial: 1,
shouldCheckServerRoutes: runtime.GOOS == "linux",
serverRoutesExpected: 1,
clientNetworkWatchersExpected: 1,
},
{
name: "Should Create 1 HA Route and 1 Standalone",
inputRoutes: []*route.Route{
{
ID: "a",
NetID: "routeA",
Peer: remotePeerKey1,
Network: netip.MustParsePrefix("8.8.20.0/24"),
NetworkType: route.IPv4Network,
Metric: 9999,
Masquerade: false,
Enabled: true,
},
{
ID: "b",
NetID: "routeA",
Peer: remotePeerKey2,
Network: netip.MustParsePrefix("8.8.20.0/24"),
NetworkType: route.IPv4Network,
Metric: 9999,
Masquerade: false,
Enabled: true,
},
{
ID: "c",
NetID: "routeB",
Peer: remotePeerKey1,
Network: netip.MustParsePrefix("8.8.9.9/32"),
NetworkType: route.IPv4Network,
Metric: 9999,
Masquerade: false,
Enabled: true,
},
},
inputSerial: 1,
clientNetworkWatchersExpected: 2,
},
{
name: "No Small Client Route Should Be Added",
inputRoutes: []*route.Route{
{
ID: "a",
NetID: "routeA",
Peer: remotePeerKey1,
Network: netip.MustParsePrefix("0.0.0.0/0"),
NetworkType: route.IPv4Network,
Metric: 9999,
Masquerade: false,
Enabled: true,
},
},
inputSerial: 1,
clientNetworkWatchersExpected: 0,
},
{
name: "No Server Routes Should Be Added To Non Linux",
inputRoutes: []*route.Route{
{
ID: "a",
NetID: "routeA",
Peer: localPeerKey,
Network: netip.MustParsePrefix("1.2.3.4/32"),
NetworkType: route.IPv4Network,
Metric: 9999,
Masquerade: false,
Enabled: true,
},
},
inputSerial: 1,
shouldCheckServerRoutes: runtime.GOOS != "linux",
serverRoutesExpected: 0,
clientNetworkWatchersExpected: 0,
},
{
name: "Remove 1 Client Route",
inputInitRoutes: []*route.Route{
{
ID: "a",
NetID: "routeA",
Peer: remotePeerKey1,
Network: netip.MustParsePrefix("100.64.251.250/30"),
NetworkType: route.IPv4Network,
Metric: 9999,
Masquerade: false,
Enabled: true,
},
{
ID: "b",
NetID: "routeB",
Peer: remotePeerKey1,
Network: netip.MustParsePrefix("8.8.8.8/32"),
NetworkType: route.IPv4Network,
Metric: 9999,
Masquerade: false,
Enabled: true,
},
},
inputRoutes: []*route.Route{
{
ID: "a",
NetID: "routeA",
Peer: remotePeerKey1,
Network: netip.MustParsePrefix("100.64.251.250/30"),
NetworkType: route.IPv4Network,
Metric: 9999,
Masquerade: false,
Enabled: true,
},
},
inputSerial: 1,
clientNetworkWatchersExpected: 1,
},
{
name: "Update Route to HA",
inputInitRoutes: []*route.Route{
{
ID: "a",
NetID: "routeA",
Peer: remotePeerKey1,
Network: netip.MustParsePrefix("100.64.251.250/30"),
NetworkType: route.IPv4Network,
Metric: 9999,
Masquerade: false,
Enabled: true,
},
{
ID: "b",
NetID: "routeB",
Peer: remotePeerKey1,
Network: netip.MustParsePrefix("8.8.8.8/32"),
NetworkType: route.IPv4Network,
Metric: 9999,
Masquerade: false,
Enabled: true,
},
},
inputRoutes: []*route.Route{
{
ID: "a",
NetID: "routeA",
Peer: remotePeerKey1,
Network: netip.MustParsePrefix("100.64.251.250/30"),
NetworkType: route.IPv4Network,
Metric: 9999,
Masquerade: false,
Enabled: true,
},
{
ID: "b",
NetID: "routeA",
Peer: remotePeerKey2,
Network: netip.MustParsePrefix("100.64.251.250/30"),
NetworkType: route.IPv4Network,
Metric: 9999,
Masquerade: false,
Enabled: true,
},
},
inputSerial: 1,
clientNetworkWatchersExpected: 1,
},
{
name: "Remove Client Routes",
inputInitRoutes: []*route.Route{
{
ID: "a",
NetID: "routeA",
Peer: remotePeerKey1,
Network: netip.MustParsePrefix("100.64.251.250/30"),
NetworkType: route.IPv4Network,
Metric: 9999,
Masquerade: false,
Enabled: true,
},
{
ID: "b",
NetID: "routeB",
Peer: remotePeerKey1,
Network: netip.MustParsePrefix("8.8.8.8/32"),
NetworkType: route.IPv4Network,
Metric: 9999,
Masquerade: false,
Enabled: true,
},
},
inputRoutes: []*route.Route{},
inputSerial: 1,
clientNetworkWatchersExpected: 0,
},
{
name: "Remove All Routes",
inputInitRoutes: []*route.Route{
{
ID: "a",
NetID: "routeA",
Peer: localPeerKey,
Network: netip.MustParsePrefix("100.64.251.250/30"),
NetworkType: route.IPv4Network,
Metric: 9999,
Masquerade: false,
Enabled: true,
},
{
ID: "b",
NetID: "routeB",
Peer: remotePeerKey1,
Network: netip.MustParsePrefix("8.8.8.8/32"),
NetworkType: route.IPv4Network,
Metric: 9999,
Masquerade: false,
Enabled: true,
},
},
inputRoutes: []*route.Route{},
inputSerial: 1,
shouldCheckServerRoutes: true,
serverRoutesExpected: 0,
clientNetworkWatchersExpected: 0,
},
}
for n, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun43%d", n), "100.65.65.2/24", iface.DefaultMTU)
require.NoError(t, err, "should create testing WGIface interface")
defer wgInterface.Close()
err = wgInterface.Create()
require.NoError(t, err, "should create testing wireguard interface")
statusRecorder := status.NewRecorder()
ctx := context.TODO()
routeManager := NewManager(ctx, localPeerKey, wgInterface, statusRecorder)
defer routeManager.Stop()
if len(testCase.inputInitRoutes) > 0 {
err = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes)
require.NoError(t, err, "should update routes with init routes")
}
err = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes)
require.NoError(t, err, "should update routes")
require.Len(t, routeManager.clientNetworks, testCase.clientNetworkWatchersExpected, "client networks size should match")
if testCase.shouldCheckServerRoutes {
require.Len(t, routeManager.serverRoutes, testCase.serverRoutesExpected, "server networks size should match")
}
})
}
}

View File

@@ -0,0 +1,27 @@
package routemanager
import (
"fmt"
"github.com/netbirdio/netbird/route"
)
// MockManager is the mock instance of a route manager
type MockManager struct {
UpdateRoutesFunc func(updateSerial uint64, newRoutes []*route.Route) error
StopFunc func()
}
// UpdateRoutes mock implementation of UpdateRoutes from Manager interface
func (m *MockManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error {
if m.UpdateRoutesFunc != nil {
return m.UpdateRoutesFunc(updateSerial, newRoutes)
}
return fmt.Errorf("method UpdateRoutes is not implemented")
}
// Stop mock implementation of Stop from Manager interface
func (m *MockManager) Stop() {
if m.StopFunc != nil {
m.StopFunc()
}
}

View File

@@ -0,0 +1,384 @@
package routemanager
import (
"context"
"fmt"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
log "github.com/sirupsen/logrus"
"net"
"net/netip"
"sync"
)
import "github.com/google/nftables"
//
const (
nftablesTable = "netbird-rt"
nftablesRoutingForwardingChain = "netbird-rt-fwd"
nftablesRoutingNatChain = "netbird-rt-nat"
)
// constants needed to create nftable rules
const (
ipv4Len = 4
ipv4SrcOffset = 12
ipv4DestOffset = 16
ipv6Len = 16
ipv6SrcOffset = 8
ipv6DestOffset = 24
exprDirectionSource = "source"
exprDirectionDestination = "destination"
)
// some presets for building nftable rules
var (
zeroXor = binaryutil.NativeEndian.PutUint32(0)
zeroXor6 = append(binaryutil.NativeEndian.PutUint64(0), binaryutil.NativeEndian.PutUint64(0)...)
exprAllowRelatedEstablished = []expr.Any{
&expr.Ct{
Register: 1,
SourceRegister: false,
Key: 0,
},
&expr.Bitwise{
DestRegister: 1,
SourceRegister: 1,
Len: 4,
Mask: []uint8{0x6, 0x0, 0x0, 0x0},
Xor: zeroXor,
},
&expr.Cmp{
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(0),
},
&expr.Counter{},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
}
exprCounterAccept = []expr.Any{
&expr.Counter{},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
}
)
type nftablesManager struct {
ctx context.Context
stop context.CancelFunc
conn *nftables.Conn
tableIPv4 *nftables.Table
tableIPv6 *nftables.Table
chains map[string]map[string]*nftables.Chain
rules map[string]*nftables.Rule
mux sync.Mutex
}
// CleanRoutingRules cleans existing nftables rules from the system
func (n *nftablesManager) CleanRoutingRules() {
n.mux.Lock()
defer n.mux.Unlock()
log.Debug("flushing tables")
n.conn.FlushTable(n.tableIPv6)
n.conn.FlushTable(n.tableIPv4)
log.Debugf("flushing tables result in: %v error", n.conn.Flush())
}
// RestoreOrCreateContainers restores existing nftables containers (tables and chains)
// if they don't exist, we create them
func (n *nftablesManager) RestoreOrCreateContainers() error {
n.mux.Lock()
defer n.mux.Unlock()
if n.tableIPv6 != nil && n.tableIPv4 != nil {
log.Debugf("nftables: containers already restored, skipping")
return nil
}
tables, err := n.conn.ListTables()
if err != nil {
return fmt.Errorf("nftables: unable to list tables: %v", err)
}
for _, table := range tables {
if table.Name == nftablesTable {
if table.Family == nftables.TableFamilyIPv4 {
n.tableIPv4 = table
continue
}
n.tableIPv6 = table
}
}
if n.tableIPv4 == nil {
n.tableIPv4 = n.conn.AddTable(&nftables.Table{
Name: nftablesTable,
Family: nftables.TableFamilyIPv4,
})
}
if n.tableIPv6 == nil {
n.tableIPv6 = n.conn.AddTable(&nftables.Table{
Name: nftablesTable,
Family: nftables.TableFamilyIPv6,
})
}
chains, err := n.conn.ListChains()
if err != nil {
return fmt.Errorf("nftables: unable to list chains: %v", err)
}
n.chains[ipv4] = make(map[string]*nftables.Chain)
n.chains[ipv6] = make(map[string]*nftables.Chain)
for _, chain := range chains {
switch {
case chain.Table.Name == nftablesTable && chain.Table.Family == nftables.TableFamilyIPv4:
n.chains[ipv4][chain.Name] = chain
case chain.Table.Name == nftablesTable && chain.Table.Family == nftables.TableFamilyIPv6:
n.chains[ipv6][chain.Name] = chain
}
}
if _, found := n.chains[ipv4][nftablesRoutingForwardingChain]; !found {
n.chains[ipv4][nftablesRoutingForwardingChain] = n.conn.AddChain(&nftables.Chain{
Name: nftablesRoutingForwardingChain,
Table: n.tableIPv4,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityNATDest + 1,
Type: nftables.ChainTypeFilter,
})
}
if _, found := n.chains[ipv4][nftablesRoutingNatChain]; !found {
n.chains[ipv4][nftablesRoutingNatChain] = n.conn.AddChain(&nftables.Chain{
Name: nftablesRoutingNatChain,
Table: n.tableIPv4,
Hooknum: nftables.ChainHookPostrouting,
Priority: nftables.ChainPriorityNATSource - 1,
Type: nftables.ChainTypeNAT,
})
}
if _, found := n.chains[ipv6][nftablesRoutingForwardingChain]; !found {
n.chains[ipv6][nftablesRoutingForwardingChain] = n.conn.AddChain(&nftables.Chain{
Name: nftablesRoutingForwardingChain,
Table: n.tableIPv6,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityNATDest + 1,
Type: nftables.ChainTypeFilter,
})
}
if _, found := n.chains[ipv6][nftablesRoutingNatChain]; !found {
n.chains[ipv6][nftablesRoutingNatChain] = n.conn.AddChain(&nftables.Chain{
Name: nftablesRoutingNatChain,
Table: n.tableIPv6,
Hooknum: nftables.ChainHookPostrouting,
Priority: nftables.ChainPriorityNATSource - 1,
Type: nftables.ChainTypeNAT,
})
}
err = n.refreshRulesMap()
if err != nil {
return err
}
n.checkOrCreateDefaultForwardingRules()
err = n.conn.Flush()
if err != nil {
return fmt.Errorf("nftables: unable to initialize table: %v", err)
}
return nil
}
// refreshRulesMap refreshes the rule map with the latest rules. this is useful to avoid
// duplicates and to get missing attributes that we don't have when adding new rules
func (n *nftablesManager) refreshRulesMap() error {
for _, registeredChains := range n.chains {
for _, chain := range registeredChains {
rules, err := n.conn.GetRules(chain.Table, chain)
if err != nil {
return fmt.Errorf("nftables: unable to list rules: %v", err)
}
for _, rule := range rules {
if len(rule.UserData) > 0 {
n.rules[string(rule.UserData)] = rule
}
}
}
}
return nil
}
// checkOrCreateDefaultForwardingRules checks if the default forwarding rules are enabled
func (n *nftablesManager) checkOrCreateDefaultForwardingRules() {
_, foundIPv4 := n.rules[ipv4Forwarding]
if !foundIPv4 {
n.rules[ipv4Forwarding] = n.conn.AddRule(&nftables.Rule{
Table: n.tableIPv4,
Chain: n.chains[ipv4][nftablesRoutingForwardingChain],
Exprs: exprAllowRelatedEstablished,
UserData: []byte(ipv4Forwarding),
})
}
_, foundIPv6 := n.rules[ipv6Forwarding]
if !foundIPv6 {
n.rules[ipv6Forwarding] = n.conn.AddRule(&nftables.Rule{
Table: n.tableIPv6,
Chain: n.chains[ipv6][nftablesRoutingForwardingChain],
Exprs: exprAllowRelatedEstablished,
UserData: []byte(ipv6Forwarding),
})
}
}
// InsertRoutingRules inserts a nftable rule pair to the forwarding chain and if enabled, to the nat chain
func (n *nftablesManager) InsertRoutingRules(pair routerPair) error {
n.mux.Lock()
defer n.mux.Unlock()
prefix := netip.MustParsePrefix(pair.source)
sourceExp := generateCIDRMatcherExpressions("source", pair.source)
destExp := generateCIDRMatcherExpressions("destination", pair.destination)
forwardExp := append(sourceExp, append(destExp, exprCounterAccept...)...)
fwdKey := genKey(forwardingFormat, pair.ID)
if prefix.Addr().Unmap().Is4() {
n.rules[fwdKey] = n.conn.InsertRule(&nftables.Rule{
Table: n.tableIPv4,
Chain: n.chains[ipv4][nftablesRoutingForwardingChain],
Exprs: forwardExp,
UserData: []byte(fwdKey),
})
} else {
n.rules[fwdKey] = n.conn.InsertRule(&nftables.Rule{
Table: n.tableIPv6,
Chain: n.chains[ipv6][nftablesRoutingForwardingChain],
Exprs: forwardExp,
UserData: []byte(fwdKey),
})
}
if pair.masquerade {
natExp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...)
natKey := genKey(natFormat, pair.ID)
if prefix.Addr().Unmap().Is4() {
n.rules[natKey] = n.conn.InsertRule(&nftables.Rule{
Table: n.tableIPv4,
Chain: n.chains[ipv4][nftablesRoutingNatChain],
Exprs: natExp,
UserData: []byte(natKey),
})
} else {
n.rules[natKey] = n.conn.InsertRule(&nftables.Rule{
Table: n.tableIPv6,
Chain: n.chains[ipv6][nftablesRoutingNatChain],
Exprs: natExp,
UserData: []byte(natKey),
})
}
}
err := n.conn.Flush()
if err != nil {
return fmt.Errorf("nftables: unable to insert rules for %s: %v", pair.destination, err)
}
return nil
}
// RemoveRoutingRules removes a nftable rule pair from forwarding and nat chains
func (n *nftablesManager) RemoveRoutingRules(pair routerPair) error {
n.mux.Lock()
defer n.mux.Unlock()
err := n.refreshRulesMap()
if err != nil {
return err
}
fwdKey := genKey(forwardingFormat, pair.ID)
natKey := genKey(natFormat, pair.ID)
fwdRule, found := n.rules[fwdKey]
if found {
err = n.conn.DelRule(fwdRule)
if err != nil {
return fmt.Errorf("nftables: unable to remove forwarding rule for %s: %v", pair.destination, err)
}
log.Debugf("nftables: removing forwarding rule for %s", pair.destination)
delete(n.rules, fwdKey)
}
natRule, found := n.rules[natKey]
if found {
err = n.conn.DelRule(natRule)
if err != nil {
return fmt.Errorf("nftables: unable to remove nat rule for %s: %v", pair.destination, err)
}
log.Debugf("nftables: removing nat rule for %s", pair.destination)
delete(n.rules, natKey)
}
err = n.conn.Flush()
if err != nil {
return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.destination, err)
}
log.Debugf("nftables: removed rules for %s", pair.destination)
return nil
}
// getPayloadDirectives get expression directives based on ip version and direction
func getPayloadDirectives(direction string, isIPv4 bool, isIPv6 bool) (uint32, uint32, []byte) {
switch {
case direction == exprDirectionSource && isIPv4:
return ipv4SrcOffset, ipv4Len, zeroXor
case direction == exprDirectionDestination && isIPv4:
return ipv4DestOffset, ipv4Len, zeroXor
case direction == exprDirectionSource && isIPv6:
return ipv6SrcOffset, ipv6Len, zeroXor6
case direction == exprDirectionDestination && isIPv6:
return ipv6DestOffset, ipv6Len, zeroXor6
default:
panic("no matched payload directive")
}
}
// generateCIDRMatcherExpressions generates nftables expressions that matches a CIDR
func generateCIDRMatcherExpressions(direction string, cidr string) []expr.Any {
ip, network, _ := net.ParseCIDR(cidr)
ipToAdd, _ := netip.AddrFromSlice(ip)
add := ipToAdd.Unmap()
offSet, packetLen, zeroXor := getPayloadDirectives(direction, add.Is4(), add.Is6())
return []expr.Any{
// fetch src add
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: offSet,
Len: packetLen,
},
// net mask
&expr.Bitwise{
DestRegister: 1,
SourceRegister: 1,
Len: packetLen,
Mask: network.Mask,
Xor: zeroXor,
},
// net address
&expr.Cmp{
Register: 1,
Data: add.AsSlice(),
},
}
}

View File

@@ -0,0 +1,270 @@
package routemanager
import (
"context"
"github.com/google/nftables"
"github.com/google/nftables/expr"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"testing"
)
func TestNftablesManager_RestoreOrCreateContainers(t *testing.T) {
ctx, cancel := context.WithCancel(context.TODO())
manager := &nftablesManager{
ctx: ctx,
stop: cancel,
conn: &nftables.Conn{},
chains: make(map[string]map[string]*nftables.Chain),
rules: make(map[string]*nftables.Rule),
}
nftablesTestingClient := &nftables.Conn{}
defer manager.CleanRoutingRules()
err := manager.RestoreOrCreateContainers()
require.NoError(t, err, "shouldn't return error")
require.Len(t, manager.chains, 2, "should have created chains for ipv4 and ipv6")
require.Len(t, manager.chains[ipv4], 2, "should have created chains for ipv4")
require.Len(t, manager.chains[ipv4], 2, "should have created chains for ipv6")
require.Len(t, manager.rules, 2, "should have created rules for ipv4 and ipv6")
pair := routerPair{
ID: "abc",
source: "100.100.100.1/32",
destination: "100.100.100.0/24",
masquerade: true,
}
sourceExp := generateCIDRMatcherExpressions("source", pair.source)
destExp := generateCIDRMatcherExpressions("destination", pair.destination)
forward4Exp := append(sourceExp, append(destExp, exprCounterAccept...)...)
forward4RuleKey := genKey(forwardingFormat, pair.ID)
inserted4Forwarding := nftablesTestingClient.InsertRule(&nftables.Rule{
Table: manager.tableIPv4,
Chain: manager.chains[ipv4][nftablesRoutingForwardingChain],
Exprs: forward4Exp,
UserData: []byte(forward4RuleKey),
})
nat4Exp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...)
nat4RuleKey := genKey(natFormat, pair.ID)
inserted4Nat := nftablesTestingClient.InsertRule(&nftables.Rule{
Table: manager.tableIPv4,
Chain: manager.chains[ipv4][nftablesRoutingNatChain],
Exprs: nat4Exp,
UserData: []byte(nat4RuleKey),
})
err = nftablesTestingClient.Flush()
require.NoError(t, err, "shouldn't return error")
pair = routerPair{
ID: "xyz",
source: "fc00::1/128",
destination: "fc11::/64",
masquerade: true,
}
sourceExp = generateCIDRMatcherExpressions("source", pair.source)
destExp = generateCIDRMatcherExpressions("destination", pair.destination)
forward6Exp := append(sourceExp, append(destExp, exprCounterAccept...)...)
forward6RuleKey := genKey(forwardingFormat, pair.ID)
inserted6Forwarding := nftablesTestingClient.InsertRule(&nftables.Rule{
Table: manager.tableIPv6,
Chain: manager.chains[ipv6][nftablesRoutingForwardingChain],
Exprs: forward6Exp,
UserData: []byte(forward6RuleKey),
})
nat6Exp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...)
nat6RuleKey := genKey(natFormat, pair.ID)
inserted6Nat := nftablesTestingClient.InsertRule(&nftables.Rule{
Table: manager.tableIPv6,
Chain: manager.chains[ipv6][nftablesRoutingNatChain],
Exprs: nat6Exp,
UserData: []byte(nat6RuleKey),
})
err = nftablesTestingClient.Flush()
require.NoError(t, err, "shouldn't return error")
manager.tableIPv4 = nil
manager.tableIPv6 = nil
err = manager.RestoreOrCreateContainers()
require.NoError(t, err, "shouldn't return error")
require.Len(t, manager.chains, 2, "should have created chains for ipv4 and ipv6")
require.Len(t, manager.chains[ipv4], 2, "should have created chains for ipv4")
require.Len(t, manager.chains[ipv4], 2, "should have created chains for ipv6")
require.Len(t, manager.rules, 6, "should have restored all rules for ipv4 and ipv6")
foundRule, found := manager.rules[forward4RuleKey]
require.True(t, found, "forwarding rule should exist in the map")
assert.Equal(t, inserted4Forwarding.Exprs, foundRule.Exprs, "stored forwarding rule expressions should match")
foundRule, found = manager.rules[nat4RuleKey]
require.True(t, found, "nat rule should exist in the map")
// match len of output as nftables client doesn't return expressions with masquerade expression
assert.ElementsMatch(t, inserted4Nat.Exprs[:len(foundRule.Exprs)], foundRule.Exprs, "stored nat rule expressions should match")
foundRule, found = manager.rules[forward6RuleKey]
require.True(t, found, "forwarding rule should exist in the map")
assert.Equal(t, inserted6Forwarding.Exprs, foundRule.Exprs, "stored forward rule should match")
foundRule, found = manager.rules[nat6RuleKey]
require.True(t, found, "nat rule should exist in the map")
// match len of output as nftables client doesn't return expressions with masquerade expression
assert.ElementsMatch(t, inserted6Nat.Exprs[:len(foundRule.Exprs)], foundRule.Exprs, "stored nat rule should match")
}
func TestNftablesManager_InsertRoutingRules(t *testing.T) {
for _, testCase := range insertRuleTestCases {
t.Run(testCase.name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.TODO())
manager := &nftablesManager{
ctx: ctx,
stop: cancel,
conn: &nftables.Conn{},
chains: make(map[string]map[string]*nftables.Chain),
rules: make(map[string]*nftables.Rule),
}
nftablesTestingClient := &nftables.Conn{}
defer manager.CleanRoutingRules()
err := manager.RestoreOrCreateContainers()
require.NoError(t, err, "shouldn't return error")
err = manager.InsertRoutingRules(testCase.inputPair)
require.NoError(t, err, "forwarding pair should be inserted")
sourceExp := generateCIDRMatcherExpressions("source", testCase.inputPair.source)
destExp := generateCIDRMatcherExpressions("destination", testCase.inputPair.destination)
testingExpression := append(sourceExp, destExp...)
fwdRuleKey := genKey(forwardingFormat, testCase.inputPair.ID)
found := 0
for _, registeredChains := range manager.chains {
for _, chain := range registeredChains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == fwdRuleKey {
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "forwarding rule elements should match")
found = 1
}
}
}
}
require.Equal(t, 1, found, "should find at least 1 rule to test")
if testCase.inputPair.masquerade {
natRuleKey := genKey(natFormat, testCase.inputPair.ID)
found := 0
for _, registeredChains := range manager.chains {
for _, chain := range registeredChains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "nat rule elements should match")
found = 1
}
}
}
}
require.Equal(t, 1, found, "should find at least 1 rule to test")
}
})
}
}
func TestNftablesManager_RemoveRoutingRules(t *testing.T) {
for _, testCase := range removeRuleTestCases {
t.Run(testCase.name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.TODO())
manager := &nftablesManager{
ctx: ctx,
stop: cancel,
conn: &nftables.Conn{},
chains: make(map[string]map[string]*nftables.Chain),
rules: make(map[string]*nftables.Rule),
}
nftablesTestingClient := &nftables.Conn{}
defer manager.CleanRoutingRules()
err := manager.RestoreOrCreateContainers()
require.NoError(t, err, "shouldn't return error")
table := manager.tableIPv4
if testCase.ipVersion == ipv6 {
table = manager.tableIPv6
}
sourceExp := generateCIDRMatcherExpressions("source", testCase.inputPair.source)
destExp := generateCIDRMatcherExpressions("destination", testCase.inputPair.destination)
forwardExp := append(sourceExp, append(destExp, exprCounterAccept...)...)
forwardRuleKey := genKey(forwardingFormat, testCase.inputPair.ID)
insertedForwarding := nftablesTestingClient.InsertRule(&nftables.Rule{
Table: table,
Chain: manager.chains[testCase.ipVersion][nftablesRoutingForwardingChain],
Exprs: forwardExp,
UserData: []byte(forwardRuleKey),
})
natExp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...)
natRuleKey := genKey(natFormat, testCase.inputPair.ID)
insertedNat := nftablesTestingClient.InsertRule(&nftables.Rule{
Table: table,
Chain: manager.chains[testCase.ipVersion][nftablesRoutingNatChain],
Exprs: natExp,
UserData: []byte(natRuleKey),
})
err = nftablesTestingClient.Flush()
require.NoError(t, err, "shouldn't return error")
manager.tableIPv4 = nil
manager.tableIPv6 = nil
err = manager.RestoreOrCreateContainers()
require.NoError(t, err, "shouldn't return error")
err = manager.RemoveRoutingRules(testCase.inputPair)
require.NoError(t, err, "shouldn't return error")
for _, registeredChains := range manager.chains {
for _, chain := range registeredChains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 {
require.NotEqual(t, insertedForwarding.UserData, rule.UserData, "forwarding rule should exist")
require.NotEqual(t, insertedNat.UserData, rule.UserData, "nat rule should exist")
}
}
}
}
})
}
}

View File

@@ -0,0 +1,67 @@
package routemanager
import (
"github.com/netbirdio/netbird/route"
log "github.com/sirupsen/logrus"
"net/netip"
"sync"
)
type serverRouter struct {
routes map[string]*route.Route
// best effort to keep net forward configuration as it was
netForwardHistoryEnabled bool
mux sync.Mutex
firewall firewallManager
}
type routerPair struct {
ID string
source string
destination string
masquerade bool
}
func routeToRouterPair(source string, route *route.Route) routerPair {
parsed := netip.MustParsePrefix(source).Masked()
return routerPair{
ID: route.ID,
source: parsed.String(),
destination: route.Network.Masked().String(),
masquerade: route.Masquerade,
}
}
func (m *DefaultManager) removeFromServerNetwork(route *route.Route) error {
select {
case <-m.ctx.Done():
log.Infof("not removing from server network because context is done")
return m.ctx.Err()
default:
m.serverRouter.mux.Lock()
defer m.serverRouter.mux.Unlock()
err := m.serverRouter.firewall.RemoveRoutingRules(routeToRouterPair(m.wgInterface.Address.String(), route))
if err != nil {
return err
}
delete(m.serverRouter.routes, route.ID)
return nil
}
}
func (m *DefaultManager) addToServerNetwork(route *route.Route) error {
select {
case <-m.ctx.Done():
log.Infof("not adding to server network because context is done")
return m.ctx.Err()
default:
m.serverRouter.mux.Lock()
defer m.serverRouter.mux.Unlock()
err := m.serverRouter.firewall.InsertRoutingRules(routeToRouterPair(m.wgInterface.Address.String(), route))
if err != nil {
return err
}
m.serverRouter.routes[route.ID] = route
return nil
}
}

View File

@@ -0,0 +1,55 @@
package routemanager
import (
"fmt"
"github.com/libp2p/go-netroute"
log "github.com/sirupsen/logrus"
"net"
"net/netip"
)
var errRouteNotFound = fmt.Errorf("route not found")
func addToRouteTableIfNoExists(prefix netip.Prefix, addr string) error {
gateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0"))
if err != nil && err != errRouteNotFound {
return err
}
prefixGateway, err := getExistingRIBRouteGateway(prefix)
if err != nil && err != errRouteNotFound {
return err
}
if prefixGateway != nil && !prefixGateway.Equal(gateway) {
log.Warnf("route for network %s already exist and is pointing to the gateway: %s, won't add another one", prefix, prefixGateway)
return nil
}
return addToRouteTable(prefix, addr)
}
func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr string) error {
addrIP := net.ParseIP(addr)
prefixGateway, err := getExistingRIBRouteGateway(prefix)
if err != nil {
return err
}
if prefixGateway != nil && !prefixGateway.Equal(addrIP) {
log.Warnf("route for network %s is pointing to a different gateway: %s, should be pointing to: %s, not removing", prefix, prefixGateway, addrIP)
return nil
}
return removeFromRouteTable(prefix)
}
func getExistingRIBRouteGateway(prefix netip.Prefix) (net.IP, error) {
r, err := netroute.New()
if err != nil {
return nil, err
}
_, _, localGatewayAddress, err := r.Route(prefix.Addr().AsSlice())
if err != nil {
log.Errorf("getting routes returned an error: %v", err)
return nil, errRouteNotFound
}
return localGatewayAddress, nil
}

View File

@@ -0,0 +1,73 @@
package routemanager
import (
"github.com/vishvananda/netlink"
"io/ioutil"
"net"
"net/netip"
)
const ipv4ForwardingPath = "/proc/sys/net/ipv4/ip_forward"
func addToRouteTable(prefix netip.Prefix, addr string) error {
_, ipNet, err := net.ParseCIDR(prefix.String())
if err != nil {
return err
}
addrMask := "/32"
if prefix.Addr().Unmap().Is6() {
addrMask = "/128"
}
ip, _, err := net.ParseCIDR(addr + addrMask)
if err != nil {
return err
}
route := &netlink.Route{
Scope: netlink.SCOPE_UNIVERSE,
Dst: ipNet,
Gw: ip,
}
err = netlink.RouteAdd(route)
if err != nil {
return err
}
return nil
}
func removeFromRouteTable(prefix netip.Prefix) error {
_, ipNet, err := net.ParseCIDR(prefix.String())
if err != nil {
return err
}
route := &netlink.Route{
Scope: netlink.SCOPE_UNIVERSE,
Dst: ipNet,
}
err = netlink.RouteDel(route)
if err != nil {
return err
}
return nil
}
func enableIPForwarding() error {
err := ioutil.WriteFile(ipv4ForwardingPath, []byte("1"), 0644)
return err
}
func isNetForwardHistoryEnabled() bool {
out, err := ioutil.ReadFile(ipv4ForwardingPath)
if err != nil {
// todo
panic(err)
}
return string(out) == "1"
}

View File

@@ -0,0 +1,41 @@
//go:build !linux
// +build !linux
package routemanager
import (
log "github.com/sirupsen/logrus"
"net/netip"
"os/exec"
"runtime"
)
func addToRouteTable(prefix netip.Prefix, addr string) error {
cmd := exec.Command("route", "add", prefix.String(), addr)
out, err := cmd.Output()
if err != nil {
return err
}
log.Debugf(string(out))
return nil
}
func removeFromRouteTable(prefix netip.Prefix) error {
cmd := exec.Command("route", "delete", prefix.String())
out, err := cmd.Output()
if err != nil {
return err
}
log.Debugf(string(out))
return nil
}
func enableIPForwarding() error {
log.Infof("enable IP forwarding is not implemented on %s", runtime.GOOS)
return nil
}
func isNetForwardHistoryEnabled() bool {
log.Infof("check netforwad history is not implemented on %s", runtime.GOOS)
return false
}

View File

@@ -0,0 +1,68 @@
package routemanager
import (
"fmt"
"github.com/netbirdio/netbird/iface"
"github.com/stretchr/testify/require"
"net/netip"
"testing"
)
func TestAddRemoveRoutes(t *testing.T) {
testCases := []struct {
name string
prefix netip.Prefix
shouldRouteToWireguard bool
shouldBeRemoved bool
}{
{
name: "Should Add And Remove Route",
prefix: netip.MustParsePrefix("100.66.120.0/24"),
shouldRouteToWireguard: true,
shouldBeRemoved: true,
},
{
name: "Should Not Add Or Remove Route",
prefix: netip.MustParsePrefix("127.0.0.1/32"),
shouldRouteToWireguard: false,
shouldBeRemoved: false,
},
}
for n, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", iface.DefaultMTU)
require.NoError(t, err, "should create testing WGIface interface")
defer wgInterface.Close()
err = wgInterface.Create()
require.NoError(t, err, "should create testing wireguard interface")
err = addToRouteTableIfNoExists(testCase.prefix, wgInterface.GetAddress().IP.String())
require.NoError(t, err, "should not return err")
prefixGateway, err := getExistingRIBRouteGateway(testCase.prefix)
require.NoError(t, err, "should not return err")
if testCase.shouldRouteToWireguard {
require.Equal(t, wgInterface.GetAddress().IP.String(), prefixGateway.String(), "route should point to wireguard interface IP")
} else {
require.NotEqual(t, wgInterface.GetAddress().IP.String(), prefixGateway.String(), "route should point to a different interface")
}
err = removeFromRouteTableIfNonSystem(testCase.prefix, wgInterface.GetAddress().IP.String())
require.NoError(t, err, "should not return err")
prefixGateway, err = getExistingRIBRouteGateway(testCase.prefix)
require.NoError(t, err, "should not return err")
internetGateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0"))
require.NoError(t, err)
if testCase.shouldBeRemoved {
require.Equal(t, internetGateway, prefixGateway, "route should be pointing to default internet gateway")
} else {
require.NotEqual(t, internetGateway, prefixGateway, "route should be pointing to a different gateway than the internet gateway")
}
})
}
}

View File

@@ -47,17 +47,19 @@ type FullStatus struct {
// Status holds a state of peers, signal and management connections
type Status struct {
mux sync.Mutex
peers map[string]PeerState
signal SignalState
management ManagementState
localPeer LocalPeerState
mux sync.Mutex
peers map[string]PeerState
changeNotify map[string]chan struct{}
signal SignalState
management ManagementState
localPeer LocalPeerState
}
// NewRecorder returns a new Status instance
func NewRecorder() *Status {
return &Status{
peers: make(map[string]PeerState),
peers: make(map[string]PeerState),
changeNotify: make(map[string]chan struct{}),
}
}
@@ -74,6 +76,18 @@ func (d *Status) AddPeer(peerPubKey string) error {
return nil
}
// GetPeer adds peer to Daemon status map
func (d *Status) GetPeer(peerPubKey string) (PeerState, error) {
d.mux.Lock()
defer d.mux.Unlock()
state, ok := d.peers[peerPubKey]
if !ok {
return PeerState{}, errors.New("peer not found")
}
return state, nil
}
// RemovePeer removes peer from Daemon status map
func (d *Status) RemovePeer(peerPubKey string) error {
d.mux.Lock()
@@ -113,9 +127,27 @@ func (d *Status) UpdatePeerState(receivedState PeerState) error {
d.peers[receivedState.PubKey] = peerState
ch, found := d.changeNotify[receivedState.PubKey]
if found && ch != nil {
close(ch)
d.changeNotify[receivedState.PubKey] = nil
}
return nil
}
// GetPeerStateChangeNotifier returns a change notifier channel for a peer
func (d *Status) GetPeerStateChangeNotifier(peer string) <-chan struct{} {
d.mux.Lock()
defer d.mux.Unlock()
ch, found := d.changeNotify[peer]
if !found || ch == nil {
ch = make(chan struct{})
d.changeNotify[peer] = ch
}
return ch
}
// UpdateLocalPeerState updates local peer status
func (d *Status) UpdateLocalPeerState(localPeerState LocalPeerState) {
d.mux.Lock()

View File

@@ -19,6 +19,21 @@ func TestAddPeer(t *testing.T) {
assert.Error(t, err, "should return error on duplicate")
}
func TestGetPeer(t *testing.T) {
key := "abc"
status := NewRecorder()
err := status.AddPeer(key)
assert.NoError(t, err, "shouldn't return error")
peerStatus, err := status.GetPeer(key)
assert.NoError(t, err, "shouldn't return error on getting peer")
assert.Equal(t, key, peerStatus.PubKey, "retrieved public key should match")
_, err = status.GetPeer("non_existing_key")
assert.Error(t, err, "should return error when peer doesn't exist")
}
func TestUpdatePeerState(t *testing.T) {
key := "abc"
ip := "10.10.10.10"
@@ -39,6 +54,31 @@ func TestUpdatePeerState(t *testing.T) {
assert.Equal(t, ip, state.IP, "ip should be equal")
}
func TestGetPeerStateChangeNotifierLogic(t *testing.T) {
key := "abc"
ip := "10.10.10.10"
status := NewRecorder()
peerState := PeerState{
PubKey: key,
}
status.peers[key] = peerState
ch := status.GetPeerStateChangeNotifier(key)
assert.NotNil(t, ch, "channel shouldn't be nil")
peerState.IP = ip
err := status.UpdatePeerState(peerState)
assert.NoError(t, err, "shouldn't return error")
select {
case <-ch:
default:
t.Errorf("channel wasn't closed after update")
}
}
func TestRemovePeer(t *testing.T) {
key := "abc"
status := NewRecorder()

6
go.mod
View File

@@ -30,10 +30,13 @@ require (
require (
fyne.io/fyne/v2 v2.1.4
github.com/c-robinson/iplib v1.0.3
github.com/coreos/go-iptables v0.6.0
github.com/creack/pty v1.1.18
github.com/eko/gocache/v2 v2.3.1
github.com/getlantern/systray v1.2.1
github.com/gliderlabs/ssh v0.3.4
github.com/google/nftables v0.0.0-20220808154552-2eca00135732
github.com/libp2p/go-netroute v0.2.0
github.com/magiconair/properties v1.8.5
github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/rs/xid v1.3.0
@@ -67,6 +70,7 @@ require (
github.com/godbus/dbus/v5 v5.0.4 // indirect
github.com/goki/freetype v0.0.0-20181231101311-fa8a33aabaff // indirect
github.com/google/go-cmp v0.5.7 // indirect
github.com/google/gopacket v1.1.19 // indirect
github.com/inconshreveable/mousetrap v1.0.0 // indirect
github.com/josharian/native v0.0.0-20200817173448-b6b71def0850 // indirect
github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect
@@ -114,3 +118,5 @@ require (
)
replace github.com/pion/ice/v2 => github.com/wiretrustee/ice/v2 v2.1.21-0.20220218121004-dc81faead4bb
replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-20220905002524-6ac14ad5ea84

14
go.sum
View File

@@ -115,6 +115,8 @@ github.com/cncf/xds/go v0.0.0-20211001041855-01bcc9b48dfe/go.mod h1:eXthEFrGJvWH
github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs=
github.com/cncf/xds/go v0.0.0-20211130200136-a8f946100490/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs=
github.com/coocood/freecache v1.2.1 h1:/v1CqMq45NFH9mp/Pt142reundeBM0dVUD3osQBeu/U=
github.com/coreos/go-iptables v0.6.0 h1:is9qnZMPYjLd8LYqmm/qlE+wwEgJIkTYdhV3rfZo4jk=
github.com/coreos/go-iptables v0.6.0/go.mod h1:Qe8Bv2Xik5FyTXwgIbLAnv2sWSBmvWdFETJConOQ//Q=
github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk=
github.com/coreos/go-systemd/v22 v22.3.2/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU=
@@ -283,10 +285,14 @@ github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8
github.com/google/gofuzz v0.0.0-20161122191042-44d81051d367/go.mod h1:HP5RmnzzSNb993RKQDq4+1A4ia9nllfqcQFTQJedwGI=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/gofuzz v1.1.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8=
github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo=
github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs=
github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0=
github.com/google/martian/v3 v3.1.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0=
github.com/google/martian/v3 v3.2.1/go.mod h1:oBOf6HBosgwRXnUGWUB05QECsc6uvmMiJ3+6W4l/CUk=
github.com/google/nftables v0.0.0-20220808154552-2eca00135732 h1:csc7dT82JiSLvq4aMyQMIQDL7986NH6Wxf/QrvOj55A=
github.com/google/nftables v0.0.0-20220808154552-2eca00135732/go.mod h1:b97ulCCFipUC+kSin+zygkvUVpx0vyIAwxXFdY3PlNc=
github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc=
github.com/google/pprof v0.0.0-20190515194954-54271f7e092f/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc=
github.com/google/pprof v0.0.0-20191218002539-d4f498aebedc/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM=
@@ -383,8 +389,6 @@ github.com/jstemmer/go-junit-report v0.9.1/go.mod h1:Brl9GWCQeLvo8nXZwPNNblvFj/X
github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo=
github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w=
github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM=
github.com/kardianos/service v1.2.1-0.20210728001519-a323c3813bc7 h1:oohm9Rk9JAxxmp2NLZa7Kebgz9h4+AJDcc64txg3dQ0=
github.com/kardianos/service v1.2.1-0.20210728001519-a323c3813bc7/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/kisielk/errcheck v1.2.0/go.mod h1:/BMXB+zMLi60iA8Vv6Ksmxu/1UDYcXs4uQLJ+jE2L00=
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
@@ -401,6 +405,8 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/leodido/go-urn v1.1.0/go.mod h1:+cyI34gQWZcE1eQU7NVgKkkzdXDQHr1dBMtdAPozLkw=
github.com/libp2p/go-netroute v0.2.0 h1:0FpsbsvuSnAhXFnCY0VLFbJOzaK0VnP0r1QT/o4nWRE=
github.com/libp2p/go-netroute v0.2.0/go.mod h1:Vio7LTzZ+6hoT4CMZi5/6CpY3Snzh2vgZhWgxMNwlQI=
github.com/lucor/goinfo v0.0.0-20210802170112-c078a2b0f08b/go.mod h1:PRq09yoB+Q2OJReAmwzKivcYyremnibWGbK7WfftHzc=
github.com/lyft/protoc-gen-star v0.5.3/go.mod h1:V0xaHgaf5oCCqmcxYcWiDfTiKsZsRc87/1qhoTACD8w=
github.com/magiconair/properties v1.8.5 h1:b6kJs+EmPFMYGkow9GiUyCyOvIwYetYJ3fSaWak/Gls=
@@ -466,6 +472,8 @@ github.com/munnerz/goautoneg v0.0.0-20120707110453-a547fc61f48d/go.mod h1:+n7T8m
github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f/go.mod h1:ZdcZmHo+o7JKHSa8/e818NopupXU1YMK5fe1lsApnBw=
github.com/netbirdio/service v0.0.0-20220905002524-6ac14ad5ea84 h1:u8kpzR9ld1uAeH/BAXsS0SfcnhooLWeO7UgHSBVPD9I=
github.com/netbirdio/service v0.0.0-20220905002524-6ac14ad5ea84/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8=
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno=
github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A=
@@ -748,6 +756,7 @@ golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v
golang.org/x/net v0.0.0-20210316092652-d523dce5a7f4/go.mod h1:RBQZq4jEuRlivfhVLdyRGr576XBO4/greRjx4P4O3yc=
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
golang.org/x/net v0.0.0-20210410081132-afb366fc7cd1/go.mod h1:9tjilg8BloeKEkVJvy7fQ90B1CfIiPueXVOjqfkSzI8=
golang.org/x/net v0.0.0-20210423184538-5f58ad60dda6/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk=
golang.org/x/net v0.0.0-20210428140749-89ef3d95e781/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk=
golang.org/x/net v0.0.0-20210503060351-7fd8e65b6420/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20210525063256-abc453219eb5/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
@@ -870,6 +879,7 @@ golang.org/x/sys v0.0.0-20210320140829-1e4c9ba3b0c4/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210403161142-5e06dd20ab57/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210426080607-c94f62235c83/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210514084401-e8d321eab015/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210525143221-35b2ab0089ea/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=

View File

@@ -9,6 +9,16 @@ import (
"time"
)
// GetName returns the interface name
func (w *WGIface) GetName() string {
return w.Name
}
// GetAddress returns the interface address
func (w *WGIface) GetAddress() WGAddress {
return w.Address
}
// configureDevice configures the wireguard device
func (w *WGIface) configureDevice(config wgtypes.Config) error {
wg, err := wgctrl.New()
@@ -112,6 +122,114 @@ func (w *WGIface) UpdatePeer(peerKey string, allowedIps string, keepAlive time.D
return nil
}
// AddAllowedIP adds a prefix to the allowed IPs list of peer
func (w *WGIface) AddAllowedIP(peerKey string, allowedIP string) error {
w.mu.Lock()
defer w.mu.Unlock()
log.Debugf("adding allowed IP to interface %s and peer %s: allowed IP %s ", w.Name, peerKey, allowedIP)
_, ipNet, err := net.ParseCIDR(allowedIP)
if err != nil {
return err
}
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
return err
}
peer := wgtypes.PeerConfig{
PublicKey: peerKeyParsed,
UpdateOnly: true,
ReplaceAllowedIPs: false,
AllowedIPs: []net.IPNet{*ipNet},
}
config := wgtypes.Config{
Peers: []wgtypes.PeerConfig{peer},
}
err = w.configureDevice(config)
if err != nil {
return fmt.Errorf("received error \"%v\" while adding allowed Ip to peer on interface %s with settings: allowed ips %s", err, w.Name, allowedIP)
}
return nil
}
// RemoveAllowedIP removes a prefix from the allowed IPs list of peer
func (w *WGIface) RemoveAllowedIP(peerKey string, allowedIP string) error {
w.mu.Lock()
defer w.mu.Unlock()
log.Debugf("removing allowed IP from interface %s and peer %s: allowed IP %s ", w.Name, peerKey, allowedIP)
_, ipNet, err := net.ParseCIDR(allowedIP)
if err != nil {
return err
}
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
return err
}
existingPeer, err := getPeer(w.Name, peerKey)
if err != nil {
return err
}
newAllowedIPs := existingPeer.AllowedIPs
for i, existingAllowedIP := range existingPeer.AllowedIPs {
if existingAllowedIP.String() == ipNet.String() {
newAllowedIPs = append(existingPeer.AllowedIPs[:i], existingPeer.AllowedIPs[i+1:]...)
break
}
}
if err != nil {
return err
}
peer := wgtypes.PeerConfig{
PublicKey: peerKeyParsed,
UpdateOnly: true,
ReplaceAllowedIPs: true,
AllowedIPs: newAllowedIPs,
}
config := wgtypes.Config{
Peers: []wgtypes.PeerConfig{peer},
}
err = w.configureDevice(config)
if err != nil {
return fmt.Errorf("received error \"%v\" while removing allowed IP from peer on interface %s with settings: allowed ips %s", err, w.Name, allowedIP)
}
return nil
}
func getPeer(ifaceName, peerPubKey string) (wgtypes.Peer, error) {
wg, err := wgctrl.New()
if err != nil {
return wgtypes.Peer{}, err
}
defer func() {
err = wg.Close()
if err != nil {
log.Errorf("got error while closing wgctl: %v", err)
}
}()
wgDevice, err := wg.Device(ifaceName)
if err != nil {
return wgtypes.Peer{}, err
}
for _, peer := range wgDevice.Peers {
if peer.PublicKey.String() == peerPubKey {
return peer, nil
}
}
return wgtypes.Peer{}, fmt.Errorf("peer not found")
}
// RemovePeer removes a Wireguard Peer from the interface iface
func (w *WGIface) RemovePeer(peerKey string) error {
w.mu.Lock()

View File

@@ -229,7 +229,7 @@ func Test_UpdatePeer(t *testing.T) {
if err != nil {
t.Fatal(err)
}
peer, err := getPeer(ifaceName, peerPubKey, t)
peer, err := getPeer(ifaceName, peerPubKey)
if err != nil {
t.Fatal(err)
}
@@ -289,7 +289,7 @@ func Test_RemovePeer(t *testing.T) {
if err != nil {
t.Fatal(err)
}
_, err = getPeer(ifaceName, peerPubKey, t)
_, err = getPeer(ifaceName, peerPubKey)
if err.Error() != "peer not found" {
t.Fatal(err)
}
@@ -378,7 +378,7 @@ func Test_ConnectPeers(t *testing.T) {
t.Fatalf("waiting for peer handshake timeout after %s", timeout.String())
default:
}
peer, gpErr := getPeer(peer1ifaceName, peer2Key.PublicKey().String(), t)
peer, gpErr := getPeer(peer1ifaceName, peer2Key.PublicKey().String())
if gpErr != nil {
t.Fatal(gpErr)
}
@@ -389,28 +389,3 @@ func Test_ConnectPeers(t *testing.T) {
}
}
func getPeer(ifaceName, peerPubKey string, t *testing.T) (wgtypes.Peer, error) {
emptyPeer := wgtypes.Peer{}
wg, err := wgctrl.New()
if err != nil {
return emptyPeer, err
}
defer func() {
err = wg.Close()
if err != nil {
t.Error(err)
}
}()
wgDevice, err := wg.Device(ifaceName)
if err != nil {
return emptyPeer, err
}
for _, peer := range wgDevice.Peers {
if peer.PublicKey.String() == peerPubKey {
return peer, nil
}
}
return emptyPeer, fmt.Errorf("peer not found")
}

View File

@@ -340,11 +340,6 @@ func loadMgmtConfig(mgmtConfigPath string) (*server.Config, error) {
u.Host, config.DeviceAuthorizationFlow.ProviderConfig.Domain)
config.DeviceAuthorizationFlow.ProviderConfig.Domain = u.Host
}
if config.IdpManagerConfig != nil {
log.Infof("overriding Auth0ClientCredentials.AuthIssuer with a new value: %s, previously configured value: %s",
oidcConfig.Issuer, config.IdpManagerConfig.Auth0ClientCredentials.AuthIssuer)
config.IdpManagerConfig.Auth0ClientCredentials.AuthIssuer = oidcConfig.Issuer
}
}
return config, err

View File

@@ -31,14 +31,15 @@ const (
type AccountManager interface {
GetOrCreateAccountByUser(userId, domain string) (*Account, error)
GetAccountByUser(userId string) (*Account, error)
AddSetupKey(
CreateSetupKey(
accountId string,
keyName string,
keyType SetupKeyType,
expiresIn time.Duration,
autoGroups []string,
) (*SetupKey, error)
RevokeSetupKey(accountId string, keyId string) (*SetupKey, error)
RenameSetupKey(accountId string, keyId string, newName string) (*SetupKey, error)
SaveSetupKey(accountID string, key *SetupKey) (*SetupKey, error)
GetSetupKey(accountID, keyID string) (*SetupKey, error)
GetAccountById(accountId string) (*Account, error)
GetAccountByUserOrAccountId(userId, accountId, domain string) (*Account, error)
GetAccountWithAuthorizationClaims(claims jwtclaims.AuthorizationClaims) (*Account, error)
@@ -75,6 +76,7 @@ type AccountManager interface {
UpdateRoute(accountID string, routeID string, operations []RouteUpdateOperation) (*route.Route, error)
DeleteRoute(accountID, routeID string) error
ListRoutes(accountID string) ([]*route.Route, error)
ListSetupKeys(accountID string) ([]*SetupKey, error)
}
type DefaultAccountManager struct {
@@ -244,93 +246,6 @@ func (am *DefaultAccountManager) warmupIDPCache() error {
return nil
}
// AddSetupKey generates a new setup key with a given name and type, and adds it to the specified account
func (am *DefaultAccountManager) AddSetupKey(
accountId string,
keyName string,
keyType SetupKeyType,
expiresIn time.Duration,
) (*SetupKey, error) {
am.mux.Lock()
defer am.mux.Unlock()
keyDuration := DefaultSetupKeyDuration
if expiresIn != 0 {
keyDuration = expiresIn
}
account, err := am.Store.GetAccount(accountId)
if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found")
}
setupKey := GenerateSetupKey(keyName, keyType, keyDuration)
account.SetupKeys[setupKey.Key] = setupKey
err = am.Store.SaveAccount(account)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed adding account key")
}
return setupKey, nil
}
// RevokeSetupKey marks SetupKey as revoked - becomes not valid anymore
func (am *DefaultAccountManager) RevokeSetupKey(accountId string, keyId string) (*SetupKey, error) {
am.mux.Lock()
defer am.mux.Unlock()
account, err := am.Store.GetAccount(accountId)
if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found")
}
setupKey := getAccountSetupKeyById(account, keyId)
if setupKey == nil {
return nil, status.Errorf(codes.NotFound, "unknown setupKey %s", keyId)
}
keyCopy := setupKey.Copy()
keyCopy.Revoked = true
account.SetupKeys[keyCopy.Key] = keyCopy
err = am.Store.SaveAccount(account)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed adding account key")
}
return keyCopy, nil
}
// RenameSetupKey renames existing setup key of the specified account.
func (am *DefaultAccountManager) RenameSetupKey(
accountId string,
keyId string,
newName string,
) (*SetupKey, error) {
am.mux.Lock()
defer am.mux.Unlock()
account, err := am.Store.GetAccount(accountId)
if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found")
}
setupKey := getAccountSetupKeyById(account, keyId)
if setupKey == nil {
return nil, status.Errorf(codes.NotFound, "unknown setupKey %s", keyId)
}
keyCopy := setupKey.Copy()
keyCopy.Name = newName
account.SetupKeys[keyCopy.Key] = keyCopy
err = am.Store.SaveAccount(account)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed adding account key")
}
return keyCopy, nil
}
// GetAccountById returns an existing account using its ID or error (NotFound) if doesn't exist
func (am *DefaultAccountManager) GetAccountById(accountId string) (*Account, error) {
am.mux.Lock()
@@ -504,7 +419,6 @@ func (am *DefaultAccountManager) updateAccountDomainAttributes(
// handleExistingUserAccount handles existing User accounts and update its domain attributes.
//
//
// If there is no primary domain account yet, we set the account as primary for the domain. Otherwise,
// we compare the account's ID with the domain account ID, and if they don't match, we set the account as
// non-primary account for the domain. We don't merge accounts at this stage, because of cases when a domain
@@ -688,7 +602,7 @@ func newAccountWithId(accountId, userId, domain string) *Account {
setupKeys := make(map[string]*SetupKey)
defaultKey := GenerateDefaultSetupKey()
oneOffKey := GenerateSetupKey("One-off key", SetupKeyOneOff, DefaultSetupKeyDuration)
oneOffKey := GenerateSetupKey("One-off key", SetupKeyOneOff, DefaultSetupKeyDuration, []string{})
setupKeys[defaultKey.Key] = defaultKey
setupKeys[oneOffKey.Key] = oneOffKey
network := NewNetwork()
@@ -713,15 +627,6 @@ func newAccountWithId(accountId, userId, domain string) *Account {
return acc
}
func getAccountSetupKeyById(acc *Account, keyId string) *SetupKey {
for _, k := range acc.SetupKeys {
if keyId == k.Id {
return k
}
}
return nil
}
func getAccountSetupKeyByKey(acc *Account, key string) *SetupKey {
for _, k := range acc.SetupKeys {
if key == k.Key {

View File

@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"github.com/netbirdio/netbird/route"
gPeer "google.golang.org/grpc/peer"
"strings"
"time"
@@ -88,17 +89,24 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
peer, err := s.accountManager.GetPeer(peerKey.String())
if err != nil {
return status.Errorf(codes.PermissionDenied, "provided peer with the key wgPubKey %s is not registered", peerKey.String())
p, _ := gPeer.FromContext(srv.Context())
msg := status.Errorf(codes.PermissionDenied, "provided peer with the key wgPubKey %s is not registered, remote addr is %s", peerKey.String(), p.Addr.String())
log.Debug(msg)
return msg
}
syncReq := &proto.SyncRequest{}
err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, syncReq)
if err != nil {
return status.Errorf(codes.InvalidArgument, "invalid request message")
p, _ := gPeer.FromContext(srv.Context())
msg := status.Errorf(codes.InvalidArgument, "invalid request message from %s,remote addr is %s", peerKey.String(), p.Addr.String())
log.Debug(msg)
return msg
}
err = s.sendInitialSync(peerKey, peer, srv)
if err != nil {
log.Debugf("error while sending initial sync for %s: %v", peerKey.String(), err)
return err
}
@@ -117,7 +125,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
// condition when there are some updates
case update, open := <-updates:
if !open {
// updates channel has been closed
log.Debugf("updates channel for peer %s was closed", peerKey.String())
return nil
}
log.Debugf("recevied an update for peer %s", peerKey.String())
@@ -266,8 +274,13 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
if errStatus, ok := status.FromError(err); ok && errStatus.Code() == codes.NotFound {
// peer doesn't exist -> check if setup key was provided
if loginReq.GetJwtToken() == "" && loginReq.GetSetupKey() == "" {
// absent setup key -> permission denied
return nil, status.Errorf(codes.PermissionDenied, "provided peer with the key wgPubKey %s is not registered and no setup key or jwt was provided", peerKey.String())
// absent setup key or jwt -> permission denied
p, _ := gPeer.FromContext(ctx)
msg := status.Errorf(codes.PermissionDenied,
"provided peer with the key wgPubKey %s is not registered and no setup key or jwt was provided,"+
" remote addr is %s", peerKey.String(), p.Addr.String())
log.Debug(msg)
return nil, msg
}
// setup key or jwt is present -> try normal registration flow

View File

@@ -134,6 +134,15 @@ components:
state:
description: Setup key status, "valid", "overused","expired" or "revoked"
type: string
auto_groups:
description: Setup key groups to auto-assign to peers registered with this key
type: array
items:
$ref: '#/components/schemas/GroupMinimum'
updated_at:
description: Setup key last update date
type: string
format: date-time
required:
- id
- key
@@ -145,6 +154,8 @@ components:
- used_times
- last_used
- state
- auto_groups
- updated_at
SetupKeyRequest:
type: object
properties:
@@ -160,11 +171,17 @@ components:
revoked:
description: Setup key revocation status
type: boolean
auto_groups:
description: Setup key groups to auto-assign to peers registered with this key
type: array
items:
type: string
required:
- name
- type
- expires_in
- revoked
- auto_groups
GroupMinimum:
type: object
properties:

View File

@@ -299,6 +299,9 @@ type RulePatchOperationPath string
// SetupKey defines model for SetupKey.
type SetupKey struct {
// Setup key groups to auto-assign to peers registered with this key
AutoGroups []GroupMinimum `json:"auto_groups"`
// Setup Key expiration date
Expires time.Time `json:"expires"`
@@ -323,6 +326,9 @@ type SetupKey struct {
// Setup key type, one-off for single time usage and reusable
Type string `json:"type"`
// Setup key last update date
UpdatedAt time.Time `json:"updated_at"`
// Usage count of setup key
UsedTimes int `json:"used_times"`
@@ -332,6 +338,9 @@ type SetupKey struct {
// SetupKeyRequest defines model for SetupKeyRequest.
type SetupKeyRequest struct {
// Setup key groups to auto-assign to peers registered with this key
AutoGroups []string `json:"auto_groups"`
// Expiration time in seconds
ExpiresIn int `json:"expires_in"`

View File

@@ -344,6 +344,14 @@ func peerIPsToKeys(account *server.Account, peerIPs *[]string) []string {
return mappedPeerKeys
}
func toGroupMinimumResponse(group *server.Group) *api.GroupMinimum {
return &api.GroupMinimum{
Id: group.ID,
Name: group.Name,
PeersCount: len(group.Peers),
}
}
func toGroupResponse(account *server.Account, group *server.Group) *api.Group {
cache := make(map[string]api.PeerMinimum)
gr := api.Group{

View File

@@ -39,12 +39,11 @@ func APIHandler(accountManager s.AccountManager, authIssuer string, authAudience
apiHandler.HandleFunc("/api/peers/{id}", peersHandler.HandlePeer).
Methods("GET", "PUT", "DELETE", "OPTIONS")
apiHandler.HandleFunc("/api/users", userHandler.GetUsers).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/api/setup-keys", keysHandler.GetKeys).Methods("GET", "POST", "OPTIONS")
apiHandler.HandleFunc("/api/setup-keys/{id}", keysHandler.HandleKey).Methods("GET", "PUT", "OPTIONS")
apiHandler.HandleFunc("/api/setup-keys", keysHandler.GetKeys).Methods("POST", "OPTIONS")
apiHandler.HandleFunc("/api/setup-keys/{id}", keysHandler.HandleKey).
Methods("GET", "PUT", "DELETE", "OPTIONS")
apiHandler.HandleFunc("/api/setup-keys", keysHandler.GetAllSetupKeysHandler).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/api/setup-keys", keysHandler.CreateSetupKeyHandler).Methods("POST", "OPTIONS")
apiHandler.HandleFunc("/api/setup-keys/{id}", keysHandler.GetSetupKeyHandler).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/api/setup-keys/{id}", keysHandler.UpdateSetupKeyHandler).Methods("PUT", "OPTIONS")
apiHandler.HandleFunc("/api/rules", rulesHandler.GetAllRulesHandler).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/api/rules", rulesHandler.CreateRuleHandler).Methods("POST", "OPTIONS")

View File

@@ -348,6 +348,11 @@ func (h *Routes) DeleteRouteHandler(w http.ResponseWriter, r *http.Request) {
err = h.accountManager.DeleteRoute(account.Id, routeID)
if err != nil {
errStatus, ok := status.FromError(err)
if ok && errStatus.Code() == codes.NotFound {
http.Error(w, fmt.Sprintf("route %s not found under account %s", routeID, account.Id), http.StatusNotFound)
return
}
log.Errorf("failed delete route %s under account %s %v", routeID, account.Id, err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return

View File

@@ -78,7 +78,10 @@ func initRoutesTestData() *Routes {
SaveRouteFunc: func(_ string, _ *route.Route) error {
return nil
},
DeleteRouteFunc: func(_ string, _ string) error {
DeleteRouteFunc: func(_ string, peerIP string) error {
if peerIP != existingRouteID {
return status.Errorf(codes.NotFound, "Peer with ID %s not found", peerIP)
}
return nil
},
GetPeerByIPFunc: func(_ string, peerIP string) (*server.Peer, error) {
@@ -155,7 +158,7 @@ func TestRoutesHandlers(t *testing.T) {
{
name: "Get Not Existing Route",
requestType: http.MethodGet,
requestPath: "/api/rules/" + notFoundRouteID,
requestPath: "/api/routes/" + notFoundRouteID,
expectedStatus: http.StatusNotFound,
},
{
@@ -168,7 +171,7 @@ func TestRoutesHandlers(t *testing.T) {
{
name: "Delete Not Existing Route",
requestType: http.MethodDelete,
requestPath: "/api/rules/" + notFoundRouteID,
requestPath: "/api/routes/" + notFoundRouteID,
expectedStatus: http.StatusNotFound,
},
{

View File

@@ -2,6 +2,7 @@ package http
import (
"encoding/json"
"fmt"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/api"
@@ -28,54 +29,17 @@ func NewSetupKeysHandler(accountManager server.AccountManager, authAudience stri
}
}
func (h *SetupKeys) updateKey(accountId string, keyId string, w http.ResponseWriter, r *http.Request) {
req := &api.PutApiSetupKeysIdJSONRequestBody{}
err := json.NewDecoder(r.Body).Decode(&req)
// CreateSetupKeyHandler is a POST requests that creates a new SetupKey
func (h *SetupKeys) CreateSetupKeyHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
var key *server.SetupKey
if req.Revoked {
//handle only if being revoked, don't allow to enable key again for now
key, err = h.accountManager.RevokeSetupKey(accountId, keyId)
if err != nil {
http.Error(w, "failed revoking key", http.StatusInternalServerError)
return
}
}
if len(req.Name) != 0 {
key, err = h.accountManager.RenameSetupKey(accountId, keyId, req.Name)
if err != nil {
http.Error(w, "failed renaming key", http.StatusInternalServerError)
return
}
}
if key != nil {
writeSuccess(w, key)
}
}
func (h *SetupKeys) getKey(accountId string, keyId string, w http.ResponseWriter, r *http.Request) {
account, err := h.accountManager.GetAccountById(accountId)
if err != nil {
http.Error(w, "account doesn't exist", http.StatusInternalServerError)
return
}
for _, key := range account.SetupKeys {
if key.Id == keyId {
writeSuccess(w, key)
return
}
}
http.Error(w, "setup key not found", http.StatusNotFound)
}
func (h *SetupKeys) createKey(accountId string, w http.ResponseWriter, r *http.Request) {
req := &api.PostApiSetupKeysJSONRequestBody{}
err := json.NewDecoder(r.Body).Decode(&req)
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
@@ -95,7 +59,13 @@ func (h *SetupKeys) createKey(accountId string, w http.ResponseWriter, r *http.R
expiresIn := time.Duration(req.ExpiresIn) * time.Second
setupKey, err := h.accountManager.AddSetupKey(accountId, req.Name, server.SetupKeyType(req.Type), expiresIn)
if req.AutoGroups == nil {
req.AutoGroups = []string{}
}
// newExpiresIn := time.Duration(req.ExpiresIn) * time.Second
// newKey.ExpiresAt = time.Now().Add(newExpiresIn)
setupKey, err := h.accountManager.CreateSetupKey(account.Id, req.Name, server.SetupKeyType(req.Type), expiresIn,
req.AutoGroups)
if err != nil {
errStatus, ok := status.FromError(err)
if ok && errStatus.Code() == codes.NotFound {
@@ -109,7 +79,8 @@ func (h *SetupKeys) createKey(accountId string, w http.ResponseWriter, r *http.R
writeSuccess(w, setupKey)
}
func (h *SetupKeys) HandleKey(w http.ResponseWriter, r *http.Request) {
// GetSetupKeyHandler is a GET request to get a SetupKey by ID
func (h *SetupKeys) GetSetupKeyHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
@@ -118,25 +89,84 @@ func (h *SetupKeys) HandleKey(w http.ResponseWriter, r *http.Request) {
}
vars := mux.Vars(r)
keyId := vars["id"]
if len(keyId) == 0 {
keyID := vars["id"]
if len(keyID) == 0 {
http.Error(w, "invalid key Id", http.StatusBadRequest)
return
}
switch r.Method {
case http.MethodPut:
h.updateKey(account.Id, keyId, w, r)
key, err := h.accountManager.GetSetupKey(account.Id, keyID)
if err != nil {
errStatus, ok := status.FromError(err)
if ok && errStatus.Code() == codes.NotFound {
http.Error(w, fmt.Sprintf("setup key %s not found under account %s", keyID, account.Id), http.StatusNotFound)
return
}
log.Errorf("failed getting setup key %s under account %s %v", keyID, account.Id, err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
case http.MethodGet:
h.getKey(account.Id, keyId, w, r)
return
default:
http.Error(w, "", http.StatusNotFound)
}
writeSuccess(w, key)
}
func (h *SetupKeys) GetKeys(w http.ResponseWriter, r *http.Request) {
// UpdateSetupKeyHandler is a PUT request to update server.SetupKey
func (h *SetupKeys) UpdateSetupKeyHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
vars := mux.Vars(r)
keyID := vars["id"]
if len(keyID) == 0 {
http.Error(w, "invalid key Id", http.StatusBadRequest)
return
}
req := &api.PutApiSetupKeysIdJSONRequestBody{}
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
if req.Name == "" {
http.Error(w, fmt.Sprintf("setup key name field is invalid: %s", req.Name), http.StatusBadRequest)
return
}
if req.AutoGroups == nil {
http.Error(w, fmt.Sprintf("setup key AutoGroups field is invalid: %s", req.AutoGroups), http.StatusBadRequest)
return
}
newKey := &server.SetupKey{}
newKey.AutoGroups = req.AutoGroups
newKey.Revoked = req.Revoked
newKey.Name = req.Name
newKey.Id = keyID
newKey, err = h.accountManager.SaveSetupKey(account.Id, newKey)
if err != nil {
if e, ok := status.FromError(err); ok {
switch e.Code() {
case codes.NotFound:
http.Error(w, fmt.Sprintf("couldn't find setup key for ID %s", keyID), http.StatusNotFound)
default:
http.Error(w, "failed updating setup key", http.StatusInternalServerError)
}
}
return
}
writeSuccess(w, newKey)
}
// GetAllSetupKeysHandler is a GET request that returns a list of SetupKey
func (h *SetupKeys) GetAllSetupKeysHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
@@ -145,41 +175,40 @@ func (h *SetupKeys) GetKeys(w http.ResponseWriter, r *http.Request) {
return
}
switch r.Method {
case http.MethodPost:
h.createKey(account.Id, w, r)
setupKeys, err := h.accountManager.ListSetupKeys(account.Id)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
case http.MethodGet:
w.WriteHeader(200)
w.Header().Set("Content-Type", "application/json")
respBody := []*api.SetupKey{}
for _, key := range account.SetupKeys {
respBody = append(respBody, toResponseBody(key))
}
err = json.NewEncoder(w).Encode(respBody)
if err != nil {
log.Errorf("failed encoding account peers %s: %v", account.Id, err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
default:
http.Error(w, "", http.StatusNotFound)
}
groups, err := h.accountManager.ListGroups(account.Id)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
apiSetupKeys := make([]*api.SetupKey, 0)
for _, key := range setupKeys {
apiSetupKeys = append(apiSetupKeys, toResponseBody(groups, key))
}
writeJSONObject(w, apiSetupKeys)
}
func writeSuccess(w http.ResponseWriter, key *server.SetupKey) {
w.WriteHeader(200)
w.Header().Set("Content-Type", "application/json")
err := json.NewEncoder(w).Encode(toResponseBody(key))
err := json.NewEncoder(w).Encode(toResponseBody(nil, key))
if err != nil {
http.Error(w, "failed handling request", http.StatusInternalServerError)
return
}
}
func toResponseBody(key *server.SetupKey) *api.SetupKey {
// toResponseBody takes a list of all groups, a key, finds groups that belong to the key and add them to the response
func toResponseBody(groups []*server.Group, key *server.SetupKey) *api.SetupKey {
var state string
if key.IsExpired() {
state = "expired"
@@ -190,16 +219,30 @@ func toResponseBody(key *server.SetupKey) *api.SetupKey {
} else {
state = "valid"
}
// should be instantiated like that to ensure if empty, we return an empty array to the client
autoGroups := []api.GroupMinimum{}
for _, group := range groups {
for _, keyGroup := range key.AutoGroups {
if group.ID == keyGroup {
autoGroups = append(autoGroups, *toGroupMinimumResponse(group))
break
}
}
}
return &api.SetupKey{
Id: key.Id,
Key: key.Key,
Name: key.Name,
Expires: key.ExpiresAt,
Type: string(key.Type),
Valid: key.IsValid(),
Revoked: key.Revoked,
UsedTimes: key.UsedTimes,
LastUsed: key.LastUsed,
State: state,
Id: key.Id,
Key: key.Key,
Name: key.Name,
Expires: key.ExpiresAt,
Type: string(key.Type),
Valid: key.IsValid(),
Revoked: key.Revoked,
UsedTimes: key.UsedTimes,
LastUsed: key.LastUsed,
State: state,
AutoGroups: autoGroups,
UpdatedAt: key.UpdatedAt,
}
}

View File

@@ -0,0 +1,222 @@
package http
import (
"bytes"
"encoding/json"
"fmt"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"io"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/mock_server"
)
const (
existingSetupKeyID = "existingSetupKeyID"
newSetupKeyName = "New Setup Key"
updatedSetupKeyName = "KKKey"
notFoundSetupKeyID = "notFoundSetupKeyID"
)
func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.SetupKey, updatedSetupKey *server.SetupKey) *SetupKeys {
return &SetupKeys{
accountManager: &mock_server.MockAccountManager{
GetAccountWithAuthorizationClaimsFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) {
return &server.Account{
Id: testAccountID,
Domain: "hotmail.com",
SetupKeys: map[string]*server.SetupKey{
defaultKey.Key: defaultKey,
},
Groups: map[string]*server.Group{
"group-1": {ID: "group-1", Peers: []string{"A", "B"}},
"id-all": {ID: "id-all", Name: "All"}},
}, nil
},
CreateSetupKeyFunc: func(_ string, keyName string, typ server.SetupKeyType, _ time.Duration, _ []string) (*server.SetupKey, error) {
if keyName == newKey.Name || typ != newKey.Type {
return newKey, nil
}
return nil, fmt.Errorf("failed creating setup key")
},
GetSetupKeyFunc: func(accountID string, keyID string) (*server.SetupKey, error) {
switch keyID {
case defaultKey.Id:
return defaultKey, nil
case newKey.Id:
return newKey, nil
default:
return nil, status.Errorf(codes.NotFound, "key %s not found", keyID)
}
},
SaveSetupKeyFunc: func(accountID string, key *server.SetupKey) (*server.SetupKey, error) {
if key.Id == updatedSetupKey.Id {
return updatedSetupKey, nil
}
return nil, status.Errorf(codes.NotFound, "key %s not found", key.Id)
},
ListSetupKeysFunc: func(accountID string) ([]*server.SetupKey, error) {
return []*server.SetupKey{defaultKey}, nil
},
},
authAudience: "",
jwtExtractor: jwtclaims.ClaimsExtractor{
ExtractClaimsFromRequestContext: func(r *http.Request, authAudience string) jwtclaims.AuthorizationClaims {
return jwtclaims.AuthorizationClaims{
UserId: "test_user",
Domain: "hotmail.com",
AccountId: testAccountID,
}
},
},
}
}
func TestSetupKeysHandlers(t *testing.T) {
defaultSetupKey := server.GenerateDefaultSetupKey()
defaultSetupKey.Id = existingSetupKeyID
newSetupKey := server.GenerateSetupKey(newSetupKeyName, server.SetupKeyReusable, 0, []string{"group-1"})
updatedDefaultSetupKey := defaultSetupKey.Copy()
updatedDefaultSetupKey.AutoGroups = []string{"group-1"}
updatedDefaultSetupKey.Name = updatedSetupKeyName
updatedDefaultSetupKey.Revoked = true
tt := []struct {
name string
requestType string
requestPath string
requestBody io.Reader
expectedStatus int
expectedBody bool
expectedSetupKey *api.SetupKey
expectedSetupKeys []*api.SetupKey
}{
{
name: "Get Setup Keys",
requestType: http.MethodGet,
requestPath: "/api/setup-keys",
expectedStatus: http.StatusOK,
expectedBody: true,
expectedSetupKeys: []*api.SetupKey{toResponseBody(nil, defaultSetupKey)},
},
{
name: "Get Existing Setup Key",
requestType: http.MethodGet,
requestPath: "/api/setup-keys/" + existingSetupKeyID,
expectedStatus: http.StatusOK,
expectedBody: true,
expectedSetupKey: toResponseBody(nil, defaultSetupKey),
},
{
name: "Get Not Existing Setup Key",
requestType: http.MethodGet,
requestPath: "/api/setup-keys/" + notFoundSetupKeyID,
expectedStatus: http.StatusNotFound,
expectedBody: false,
},
{
name: "Create Setup Key",
requestType: http.MethodPost,
requestPath: "/api/setup-keys",
requestBody: bytes.NewBuffer(
[]byte(fmt.Sprintf("{\"name\":\"%s\",\"type\":\"%s\"}", newSetupKey.Name, newSetupKey.Type))),
expectedStatus: http.StatusOK,
expectedBody: true,
expectedSetupKey: toResponseBody(nil, newSetupKey),
},
{
name: "Update Setup Key",
requestType: http.MethodPut,
requestPath: "/api/setup-keys/" + defaultSetupKey.Id,
requestBody: bytes.NewBuffer(
[]byte(fmt.Sprintf("{\"name\":\"%s\",\"auto_groups\":[\"%s\"], \"revoked\":%v}",
updatedDefaultSetupKey.Type,
updatedDefaultSetupKey.AutoGroups[0],
updatedDefaultSetupKey.Revoked,
))),
expectedStatus: http.StatusOK,
expectedBody: true,
expectedSetupKey: toResponseBody(nil, updatedDefaultSetupKey),
},
}
handler := initSetupKeysTestMetaData(defaultSetupKey, newSetupKey, updatedDefaultSetupKey)
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
router := mux.NewRouter()
router.HandleFunc("/api/setup-keys", handler.GetAllSetupKeysHandler).Methods("GET", "OPTIONS")
router.HandleFunc("/api/setup-keys", handler.CreateSetupKeyHandler).Methods("POST", "OPTIONS")
router.HandleFunc("/api/setup-keys/{id}", handler.GetSetupKeyHandler).Methods("GET", "OPTIONS")
router.HandleFunc("/api/setup-keys/{id}", handler.UpdateSetupKeyHandler).Methods("PUT", "OPTIONS")
router.ServeHTTP(recorder, req)
res := recorder.Result()
defer res.Body.Close()
content, err := io.ReadAll(res.Body)
if err != nil {
t.Fatalf("I don't know what I expected; %v", err)
}
if status := recorder.Code; status != tc.expectedStatus {
t.Errorf("handler returned wrong status code: got %v want %v, content: %s",
status, tc.expectedStatus, string(content))
return
}
if !tc.expectedBody {
return
}
if tc.expectedSetupKey != nil {
got := &api.SetupKey{}
if err = json.Unmarshal(content, &got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assertKeys(t, got, tc.expectedSetupKey)
return
}
if len(tc.expectedSetupKeys) > 0 {
var got []*api.SetupKey
if err = json.Unmarshal(content, &got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assertKeys(t, got[0], tc.expectedSetupKeys[0])
return
}
})
}
}
func assertKeys(t *testing.T, got *api.SetupKey, expected *api.SetupKey) {
// this comparison is done manually because when converting to JSON dates formatted differently
// assert.Equal(t, got.UpdatedAt, tc.expectedSetupKey.UpdatedAt) //doesn't work
assert.WithinDurationf(t, got.UpdatedAt, expected.UpdatedAt, 0, "")
assert.WithinDurationf(t, got.Expires, expected.Expires, 0, "")
assert.Equal(t, got.Name, expected.Name)
assert.Equal(t, got.Id, expected.Id)
assert.Equal(t, got.Key, expected.Key)
assert.Equal(t, got.Type, expected.Type)
assert.Equal(t, got.UsedTimes, expected.UsedTimes)
assert.Equal(t, got.Revoked, expected.Revoked)
assert.ElementsMatch(t, got.AutoGroups, expected.AutoGroups)
}

View File

@@ -12,9 +12,8 @@ import (
type MockAccountManager struct {
GetOrCreateAccountByUserFunc func(userId, domain string) (*server.Account, error)
GetAccountByUserFunc func(userId string) (*server.Account, error)
AddSetupKeyFunc func(accountId string, keyName string, keyType server.SetupKeyType, expiresIn time.Duration) (*server.SetupKey, error)
RevokeSetupKeyFunc func(accountId string, keyId string) (*server.SetupKey, error)
RenameSetupKeyFunc func(accountId string, keyId string, newName string) (*server.SetupKey, error)
CreateSetupKeyFunc func(accountId string, keyName string, keyType server.SetupKeyType, expiresIn time.Duration, autoGroups []string) (*server.SetupKey, error)
GetSetupKeyFunc func(accountID string, keyID string) (*server.SetupKey, error)
GetAccountByIdFunc func(accountId string) (*server.Account, error)
GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error)
GetAccountWithAuthorizationClaimsFunc func(claims jwtclaims.AuthorizationClaims) (*server.Account, error)
@@ -51,6 +50,8 @@ type MockAccountManager struct {
UpdateRouteFunc func(accountID string, routeID string, operations []server.RouteUpdateOperation) (*route.Route, error)
DeleteRouteFunc func(accountID, routeID string) error
ListRoutesFunc func(accountID string) ([]*route.Route, error)
SaveSetupKeyFunc func(accountID string, key *server.SetupKey) (*server.SetupKey, error)
ListSetupKeysFunc func(accountID string) ([]*server.SetupKey, error)
}
// GetUsersFromAccount mock implementation of GetUsersFromAccount from server.AccountManager interface
@@ -82,40 +83,18 @@ func (am *MockAccountManager) GetAccountByUser(userId string) (*server.Account,
return nil, status.Errorf(codes.Unimplemented, "method GetAccountByUser is not implemented")
}
// AddSetupKey mock implementation of AddSetupKey from server.AccountManager interface
func (am *MockAccountManager) AddSetupKey(
// CreateSetupKey mock implementation of CreateSetupKey from server.AccountManager interface
func (am *MockAccountManager) CreateSetupKey(
accountId string,
keyName string,
keyType server.SetupKeyType,
expiresIn time.Duration,
autoGroups []string,
) (*server.SetupKey, error) {
if am.AddSetupKeyFunc != nil {
return am.AddSetupKeyFunc(accountId, keyName, keyType, expiresIn)
if am.CreateSetupKeyFunc != nil {
return am.CreateSetupKeyFunc(accountId, keyName, keyType, expiresIn, autoGroups)
}
return nil, status.Errorf(codes.Unimplemented, "method AddSetupKey is not implemented")
}
// RevokeSetupKey mock implementation of RevokeSetupKey from server.AccountManager interface
func (am *MockAccountManager) RevokeSetupKey(
accountId string,
keyId string,
) (*server.SetupKey, error) {
if am.RevokeSetupKeyFunc != nil {
return am.RevokeSetupKeyFunc(accountId, keyId)
}
return nil, status.Errorf(codes.Unimplemented, "method RevokeSetupKey is not implemented")
}
// RenameSetupKey mock implementation of RenameSetupKey from server.AccountManager interface
func (am *MockAccountManager) RenameSetupKey(
accountId string,
keyId string,
newName string,
) (*server.SetupKey, error) {
if am.RenameSetupKeyFunc != nil {
return am.RenameSetupKeyFunc(accountId, keyId, newName)
}
return nil, status.Errorf(codes.Unimplemented, "method RenameSetupKey is not implemented")
return nil, status.Errorf(codes.Unimplemented, "method CreateSetupKey is not implemented")
}
// GetAccountById mock implementation of GetAccountById from server.AccountManager interface
@@ -415,3 +394,30 @@ func (am *MockAccountManager) ListRoutes(accountID string) ([]*route.Route, erro
}
return nil, status.Errorf(codes.Unimplemented, "method ListRoutes is not implemented")
}
// SaveSetupKey mocks SaveSetupKey of the AccountManager interface
func (am *MockAccountManager) SaveSetupKey(accountID string, key *server.SetupKey) (*server.SetupKey, error) {
if am.SaveSetupKeyFunc != nil {
return am.SaveSetupKeyFunc(accountID, key)
}
return nil, status.Errorf(codes.Unimplemented, "method SaveSetupKey is not implemented")
}
// GetSetupKey mocks GetSetupKey of the AccountManager interface
func (am *MockAccountManager) GetSetupKey(accountID, keyID string) (*server.SetupKey, error) {
if am.GetSetupKeyFunc != nil {
return am.GetSetupKeyFunc(accountID, keyID)
}
return nil, status.Errorf(codes.Unimplemented, "method GetSetupKey is not implemented")
}
// ListSetupKeys mocks ListSetupKeys of the AccountManager interface
func (am *MockAccountManager) ListSetupKeys(accountID string) ([]*server.SetupKey, error) {
if am.ListSetupKeysFunc != nil {
return am.ListSetupKeysFunc(accountID)
}
return nil, status.Errorf(codes.Unimplemented, "method ListSetupKeys is not implemented")
}

View File

@@ -390,6 +390,11 @@ func (am *DefaultAccountManager) UpdatePeerSSHKey(peerKey string, sshKey string)
return err
}
if peer.SSHKey == sshKey {
log.Debugf("same SSH key provided for peer %s, skipping update", peerKey)
return nil
}
account, err := am.Store.GetPeerAccount(peerKey)
if err != nil {
return err

View File

@@ -1,7 +1,10 @@
package server
import (
"fmt"
"github.com/google/uuid"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"hash/fnv"
"strconv"
"strings"
@@ -18,8 +21,41 @@ const (
DefaultSetupKeyDuration = 24 * 30 * time.Hour
// DefaultSetupKeyName is a default name of the default setup key
DefaultSetupKeyName = "Default key"
// UpdateSetupKeyName indicates a setup key name update operation
UpdateSetupKeyName SetupKeyUpdateOperationType = iota
// UpdateSetupKeyRevoked indicates a setup key revoked filed update operation
UpdateSetupKeyRevoked
// UpdateSetupKeyAutoGroups indicates a setup key auto-assign groups update operation
UpdateSetupKeyAutoGroups
// UpdateSetupKeyExpiresAt indicates a setup key expiration time update operation
UpdateSetupKeyExpiresAt
)
// SetupKeyUpdateOperationType operation type
type SetupKeyUpdateOperationType int
func (t SetupKeyUpdateOperationType) String() string {
switch t {
case UpdateSetupKeyName:
return "UpdateSetupKeyName"
case UpdateSetupKeyRevoked:
return "UpdateSetupKeyRevoked"
case UpdateSetupKeyAutoGroups:
return "UpdateSetupKeyAutoGroups"
case UpdateSetupKeyExpiresAt:
return "UpdateSetupKeyExpiresAt"
default:
return "InvalidOperation"
}
}
// SetupKeyUpdateOperation operation object with type and values to be applied
type SetupKeyUpdateOperation struct {
Type SetupKeyUpdateOperationType
Values []string
}
// SetupKeyType is the type of setup key
type SetupKeyType string
@@ -31,30 +67,38 @@ type SetupKey struct {
Type SetupKeyType
CreatedAt time.Time
ExpiresAt time.Time
UpdatedAt time.Time
// Revoked indicates whether the key was revoked or not (we don't remove them for tracking purposes)
Revoked bool
// UsedTimes indicates how many times the key was used
UsedTimes int
// LastUsed last time the key was used for peer registration
LastUsed time.Time
// AutoGroups is a list of Group IDs that are auto assigned to a Peer when it uses this key to register
AutoGroups []string
}
//Copy copies SetupKey to a new object
// Copy copies SetupKey to a new object
func (key *SetupKey) Copy() *SetupKey {
if key.UpdatedAt.IsZero() {
key.UpdatedAt = key.CreatedAt
}
return &SetupKey{
Id: key.Id,
Key: key.Key,
Name: key.Name,
Type: key.Type,
CreatedAt: key.CreatedAt,
ExpiresAt: key.ExpiresAt,
Revoked: key.Revoked,
UsedTimes: key.UsedTimes,
LastUsed: key.LastUsed,
Id: key.Id,
Key: key.Key,
Name: key.Name,
Type: key.Type,
CreatedAt: key.CreatedAt,
ExpiresAt: key.ExpiresAt,
UpdatedAt: key.UpdatedAt,
Revoked: key.Revoked,
UsedTimes: key.UsedTimes,
LastUsed: key.LastUsed,
AutoGroups: key.AutoGroups,
}
}
//IncrementUsage makes a copy of a key, increments the UsedTimes by 1 and sets LastUsed to now
// IncrementUsage makes a copy of a key, increments the UsedTimes by 1 and sets LastUsed to now
func (key *SetupKey) IncrementUsage() *SetupKey {
c := key.Copy()
c.UsedTimes = c.UsedTimes + 1
@@ -83,24 +127,25 @@ func (key *SetupKey) IsOverUsed() bool {
}
// GenerateSetupKey generates a new setup key
func GenerateSetupKey(name string, t SetupKeyType, validFor time.Duration) *SetupKey {
func GenerateSetupKey(name string, t SetupKeyType, validFor time.Duration, autoGroups []string) *SetupKey {
key := strings.ToUpper(uuid.New().String())
createdAt := time.Now()
return &SetupKey{
Id: strconv.Itoa(int(Hash(key))),
Key: key,
Name: name,
Type: t,
CreatedAt: createdAt,
ExpiresAt: createdAt.Add(validFor),
Revoked: false,
UsedTimes: 0,
Id: strconv.Itoa(int(Hash(key))),
Key: key,
Name: name,
Type: t,
CreatedAt: time.Now(),
ExpiresAt: time.Now().Add(validFor),
UpdatedAt: time.Now(),
Revoked: false,
UsedTimes: 0,
AutoGroups: autoGroups,
}
}
// GenerateDefaultSetupKey generates a default setup key
func GenerateDefaultSetupKey() *SetupKey {
return GenerateSetupKey(DefaultSetupKeyName, SetupKeyReusable, DefaultSetupKeyDuration)
return GenerateSetupKey(DefaultSetupKeyName, SetupKeyReusable, DefaultSetupKeyDuration, []string{})
}
func Hash(s string) uint32 {
@@ -111,3 +156,127 @@ func Hash(s string) uint32 {
}
return h.Sum32()
}
// CreateSetupKey generates a new setup key with a given name, type, list of groups IDs to auto-assign to peers registered with this key,
// and adds it to the specified account. A list of autoGroups IDs can be empty.
func (am *DefaultAccountManager) CreateSetupKey(accountID string, keyName string, keyType SetupKeyType,
expiresIn time.Duration, autoGroups []string) (*SetupKey, error) {
am.mux.Lock()
defer am.mux.Unlock()
keyDuration := DefaultSetupKeyDuration
if expiresIn != 0 {
keyDuration = expiresIn
}
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found")
}
for _, group := range autoGroups {
if _, ok := account.Groups[group]; !ok {
return nil, fmt.Errorf("group %s doesn't exist", group)
}
}
setupKey := GenerateSetupKey(keyName, keyType, keyDuration, autoGroups)
account.SetupKeys[setupKey.Key] = setupKey
err = am.Store.SaveAccount(account)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed adding account key")
}
return setupKey, nil
}
// SaveSetupKey saves the provided SetupKey to the database overriding the existing one.
// Due to the unique nature of a SetupKey certain properties must not be overwritten
// (e.g. the key itself, creation date, ID, etc).
// These properties are overwritten: Name, AutoGroups, Revoked. The rest is copied from the existing key.
func (am *DefaultAccountManager) SaveSetupKey(accountID string, keyToSave *SetupKey) (*SetupKey, error) {
am.mux.Lock()
defer am.mux.Unlock()
if keyToSave == nil {
return nil, status.Errorf(codes.InvalidArgument, "provided setup key to update is nil")
}
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found")
}
var oldKey *SetupKey
for _, key := range account.SetupKeys {
if key.Id == keyToSave.Id {
oldKey = key.Copy()
break
}
}
if oldKey == nil {
return nil, status.Errorf(codes.NotFound, "setup key not found")
}
// only auto groups, revoked status, and name can be updated for now
newKey := oldKey.Copy()
newKey.Name = keyToSave.Name
newKey.AutoGroups = keyToSave.AutoGroups
newKey.Revoked = keyToSave.Revoked
newKey.UpdatedAt = time.Now()
account.SetupKeys[newKey.Key] = newKey
if err = am.Store.SaveAccount(account); err != nil {
return nil, err
}
return newKey, am.updateAccountPeers(account)
}
// ListSetupKeys returns a list of all setup keys of the account
func (am *DefaultAccountManager) ListSetupKeys(accountID string) ([]*SetupKey, error) {
am.mux.Lock()
defer am.mux.Unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found")
}
keys := make([]*SetupKey, 0, len(account.SetupKeys))
for _, key := range account.SetupKeys {
keys = append(keys, key.Copy())
}
return keys, nil
}
// GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found.
func (am *DefaultAccountManager) GetSetupKey(accountID, keyID string) (*SetupKey, error) {
am.mux.Lock()
defer am.mux.Unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found")
}
var foundKey *SetupKey
for _, key := range account.SetupKeys {
if key.Id == keyID {
foundKey = key.Copy()
break
}
}
if foundKey == nil {
return nil, status.Errorf(codes.NotFound, "setup key not found")
}
// the UpdatedAt field was introduced later, so there might be that some keys have a Zero value (e.g, null in the store file)
if foundKey.UpdatedAt.IsZero() {
foundKey.UpdatedAt = foundKey.CreatedAt
}
return foundKey, nil
}

View File

@@ -2,23 +2,159 @@ package server
import (
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"strconv"
"testing"
"time"
)
func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
manager, err := createManager(t)
if err != nil {
t.Fatal(err)
}
userID := "test_user"
account, err := manager.GetOrCreateAccountByUser(userID, "")
if err != nil {
t.Fatal(err)
}
err = manager.SaveGroup(account.Id, &Group{
ID: "group_1",
Name: "group_name_1",
Peers: []string{},
})
if err != nil {
t.Fatal(err)
}
expiresIn := time.Hour
keyName := "my-test-key"
key, err := manager.CreateSetupKey(account.Id, keyName, SetupKeyReusable, expiresIn, []string{})
if err != nil {
t.Fatal(err)
}
autoGroups := []string{"group_1", "group_2"}
newKeyName := "my-new-test-key"
revoked := true
newKey, err := manager.SaveSetupKey(account.Id, &SetupKey{
Id: key.Id,
Name: newKeyName,
Revoked: revoked,
AutoGroups: autoGroups,
})
if err != nil {
t.Fatal(err)
}
assertKey(t, newKey, newKeyName, revoked, "reusable", 0, key.CreatedAt, key.ExpiresAt,
key.Id, time.Now(), autoGroups)
}
func TestDefaultAccountManager_CreateSetupKey(t *testing.T) {
manager, err := createManager(t)
if err != nil {
t.Fatal(err)
}
userID := "test_user"
account, err := manager.GetOrCreateAccountByUser(userID, "")
if err != nil {
t.Fatal(err)
}
err = manager.SaveGroup(account.Id, &Group{
ID: "group_1",
Name: "group_name_1",
Peers: []string{},
})
if err != nil {
t.Fatal(err)
}
err = manager.SaveGroup(account.Id, &Group{
ID: "group_2",
Name: "group_name_2",
Peers: []string{},
})
if err != nil {
t.Fatal(err)
}
type testCase struct {
name string
expectedKeyName string
expectedUsedTimes int
expectedType string
expectedGroups []string
expectedCreatedAt time.Time
expectedUpdatedAt time.Time
expectedExpiresAt time.Time
expectedFailure bool //indicates whether key creation should fail
}
now := time.Now()
expiresIn := time.Hour
testCase1 := testCase{
name: "Should Create Setup Key successfully",
expectedKeyName: "my-test-key",
expectedUsedTimes: 0,
expectedType: "reusable",
expectedGroups: []string{"group_1", "group_2"},
expectedCreatedAt: now,
expectedUpdatedAt: now,
expectedExpiresAt: now.Add(expiresIn),
expectedFailure: false,
}
testCase2 := testCase{
name: "Create Setup Key should fail because of unexistent group",
expectedKeyName: "my-test-key",
expectedGroups: []string{"FAKE"},
expectedFailure: true,
}
for _, tCase := range []testCase{testCase1, testCase2} {
t.Run(tCase.name, func(t *testing.T) {
key, err := manager.CreateSetupKey(account.Id, tCase.expectedKeyName, SetupKeyReusable, expiresIn,
tCase.expectedGroups)
if tCase.expectedFailure {
if err == nil {
t.Fatal("expected to fail")
}
return
}
if err != nil {
t.Fatal(err)
}
assertKey(t, key, tCase.expectedKeyName, false, tCase.expectedType, tCase.expectedUsedTimes,
tCase.expectedCreatedAt, tCase.expectedExpiresAt, strconv.Itoa(int(Hash(key.Key))),
tCase.expectedUpdatedAt, tCase.expectedGroups)
})
}
}
func TestGenerateDefaultSetupKey(t *testing.T) {
expectedName := "Default key"
expectedRevoke := false
expectedType := "reusable"
expectedUsedTimes := 0
expectedCreatedAt := time.Now()
expectedUpdatedAt := time.Now()
expectedExpiresAt := time.Now().Add(24 * 30 * time.Hour)
var expectedAutoGroups []string
key := GenerateDefaultSetupKey()
assertKey(t, key, expectedName, expectedRevoke, expectedType, expectedUsedTimes, expectedCreatedAt,
expectedExpiresAt, strconv.Itoa(int(Hash(key.Key))))
expectedExpiresAt, strconv.Itoa(int(Hash(key.Key))), expectedUpdatedAt, expectedAutoGroups)
}
@@ -29,41 +165,44 @@ func TestGenerateSetupKey(t *testing.T) {
expectedUsedTimes := 0
expectedCreatedAt := time.Now()
expectedExpiresAt := time.Now().Add(time.Hour)
expectedUpdatedAt := time.Now()
var expectedAutoGroups []string
key := GenerateSetupKey(expectedName, SetupKeyOneOff, time.Hour)
key := GenerateSetupKey(expectedName, SetupKeyOneOff, time.Hour, []string{})
assertKey(t, key, expectedName, expectedRevoke, expectedType, expectedUsedTimes, expectedCreatedAt, expectedExpiresAt, strconv.Itoa(int(Hash(key.Key))))
assertKey(t, key, expectedName, expectedRevoke, expectedType, expectedUsedTimes, expectedCreatedAt,
expectedExpiresAt, strconv.Itoa(int(Hash(key.Key))), expectedUpdatedAt, expectedAutoGroups)
}
func TestSetupKey_IsValid(t *testing.T) {
validKey := GenerateSetupKey("valid key", SetupKeyOneOff, time.Hour)
validKey := GenerateSetupKey("valid key", SetupKeyOneOff, time.Hour, []string{})
if !validKey.IsValid() {
t.Errorf("expected key to be valid, got invalid %v", validKey)
}
// expired
expiredKey := GenerateSetupKey("invalid key", SetupKeyOneOff, -time.Hour)
expiredKey := GenerateSetupKey("invalid key", SetupKeyOneOff, -time.Hour, []string{})
if expiredKey.IsValid() {
t.Errorf("expected key to be invalid due to expiration, got valid %v", expiredKey)
}
// revoked
revokedKey := GenerateSetupKey("invalid key", SetupKeyOneOff, time.Hour)
revokedKey := GenerateSetupKey("invalid key", SetupKeyOneOff, time.Hour, []string{})
revokedKey.Revoked = true
if revokedKey.IsValid() {
t.Errorf("expected revoked key to be invalid, got valid %v", revokedKey)
}
// overused
overUsedKey := GenerateSetupKey("invalid key", SetupKeyOneOff, time.Hour)
overUsedKey := GenerateSetupKey("invalid key", SetupKeyOneOff, time.Hour, []string{})
overUsedKey.UsedTimes = 1
if overUsedKey.IsValid() {
t.Errorf("expected overused key to be invalid, got valid %v", overUsedKey)
}
// overused
reusableKey := GenerateSetupKey("valid key", SetupKeyReusable, time.Hour)
reusableKey := GenerateSetupKey("valid key", SetupKeyReusable, time.Hour, []string{})
reusableKey.UsedTimes = 99
if !reusableKey.IsValid() {
t.Errorf("expected reusable key to be valid when used many times, got valid %v", reusableKey)
@@ -71,7 +210,8 @@ func TestSetupKey_IsValid(t *testing.T) {
}
func assertKey(t *testing.T, key *SetupKey, expectedName string, expectedRevoke bool, expectedType string,
expectedUsedTimes int, expectedCreatedAt time.Time, expectedExpiresAt time.Time, expectedID string) {
expectedUsedTimes int, expectedCreatedAt time.Time, expectedExpiresAt time.Time, expectedID string,
expectedUpdatedAt time.Time, expectedAutoGroups []string) {
if key.Name != expectedName {
t.Errorf("expected setup key to have Name %v, got %v", expectedName, key.Name)
}
@@ -92,6 +232,10 @@ func assertKey(t *testing.T, key *SetupKey, expectedName string, expectedRevoke
t.Errorf("expected setup key to have ExpiresAt ~ %v, got %v", expectedExpiresAt, key.ExpiresAt)
}
if key.UpdatedAt.Sub(expectedUpdatedAt).Round(time.Hour) != 0 {
t.Errorf("expected setup key to have UpdatedAt ~ %v, got %v", expectedUpdatedAt, key.UpdatedAt)
}
if key.CreatedAt.Sub(expectedCreatedAt).Round(time.Hour) != 0 {
t.Errorf("expected setup key to have CreatedAt ~ %v, got %v", expectedCreatedAt, key.CreatedAt)
}
@@ -104,13 +248,19 @@ func assertKey(t *testing.T, key *SetupKey, expectedName string, expectedRevoke
if key.Id != strconv.Itoa(int(Hash(key.Key))) {
t.Errorf("expected key Id t= %v, got %v", expectedID, key.Id)
}
if len(key.AutoGroups) != len(expectedAutoGroups) {
t.Errorf("expected key AutoGroups size=%d, got %d", len(expectedAutoGroups), len(key.AutoGroups))
}
assert.ElementsMatch(t, key.AutoGroups, expectedAutoGroups, "expected key AutoGroups to be equal")
}
func TestSetupKey_Copy(t *testing.T) {
key := GenerateSetupKey("key name", SetupKeyOneOff, time.Hour)
key := GenerateSetupKey("key name", SetupKeyOneOff, time.Hour, []string{})
keyCopy := key.Copy()
assertKey(t, keyCopy, key.Name, key.Revoked, string(key.Type), key.UsedTimes, key.CreatedAt, key.ExpiresAt, key.Id)
assertKey(t, keyCopy, key.Name, key.Revoked, string(key.Type), key.UsedTimes, key.CreatedAt, key.ExpiresAt, key.Id,
key.UpdatedAt, key.AutoGroups)
}

View File

@@ -85,15 +85,18 @@ func (m *TimeBasedAuthSecretsManager) SetupRefresh(peerKey string) {
m.cancel(peerKey)
cancel := make(chan struct{}, 1)
m.cancelMap[peerKey] = cancel
log.Debugf("starting turn refresh for %s", peerKey)
go func() {
//we don't want to regenerate credentials right on expiration, so we do it slightly before (at 3/4 of TTL)
ticker := time.NewTicker(m.config.CredentialsTTL.Duration / 4 * 3)
for {
select {
case <-cancel:
log.Debugf("stopping turn refresh for %s", peerKey)
return
default:
//we don't want to regenerate credentials right on expiration, so we do it slightly before (at 3/4 of TTL)
time.Sleep(m.config.CredentialsTTL.Duration / 4 * 3)
case <-ticker.C:
c := m.GenerateCredentials()
var turns []*proto.ProtectedHostConfig
for _, host := range m.config.Turns {

View File

@@ -6,6 +6,8 @@ import (
"sync"
)
const channelBufferSize = 100
type UpdateMessage struct {
Update *proto.SyncResponse
}
@@ -28,7 +30,12 @@ func (p *PeersUpdateManager) SendUpdate(peer string, update *UpdateMessage) erro
p.channelsMux.Lock()
defer p.channelsMux.Unlock()
if channel, ok := p.peerChannels[peer]; ok {
channel <- update
select {
case channel <- update:
log.Infof("update was sent to channel for peer %s", peer)
default:
log.Warnf("channel for peer %s is %d full", peer, len(channel))
}
return nil
}
log.Debugf("peer %s has no channel", peer)
@@ -45,7 +52,7 @@ func (p *PeersUpdateManager) CreateChannel(peerKey string) chan *UpdateMessage {
close(channel)
}
//mbragin: todo shouldn't it be more? or configurable?
channel := make(chan *UpdateMessage, 100)
channel := make(chan *UpdateMessage, channelBufferSize)
p.peerChannels[peerKey] = channel
log.Debugf("opened updates channel for a peer %s", peerKey)

View File

@@ -3,13 +3,14 @@ package server
import (
"github.com/netbirdio/netbird/management/proto"
"testing"
"time"
)
var peersUpdater *PeersUpdateManager
//var peersUpdater *PeersUpdateManager
func TestCreateChannel(t *testing.T) {
peer := "test-create"
peersUpdater = NewPeersUpdateManager()
peersUpdater := NewPeersUpdateManager()
defer peersUpdater.CloseChannel(peer)
_ = peersUpdater.CreateChannel(peer)
@@ -20,12 +21,17 @@ func TestCreateChannel(t *testing.T) {
func TestSendUpdate(t *testing.T) {
peer := "test-sendupdate"
update := &UpdateMessage{Update: &proto.SyncResponse{}}
peersUpdater := NewPeersUpdateManager()
update1 := &UpdateMessage{Update: &proto.SyncResponse{
NetworkMap: &proto.NetworkMap{
Serial: 0,
},
}}
_ = peersUpdater.CreateChannel(peer)
if _, ok := peersUpdater.peerChannels[peer]; !ok {
t.Error("Error creating the channel")
}
err := peersUpdater.SendUpdate(peer, update)
err := peersUpdater.SendUpdate(peer, update1)
if err != nil {
t.Error("Error sending update: ", err)
}
@@ -34,10 +40,41 @@ func TestSendUpdate(t *testing.T) {
default:
t.Error("Update wasn't send")
}
for range [channelBufferSize]int{} {
err = peersUpdater.SendUpdate(peer, update1)
if err != nil {
t.Errorf("got an early error sending update: %v ", err)
}
}
update2 := &UpdateMessage{Update: &proto.SyncResponse{
NetworkMap: &proto.NetworkMap{
Serial: 10,
},
}}
err = peersUpdater.SendUpdate(peer, update2)
if err != nil {
t.Error("update shouldn't return an error when channel buffer is full")
}
timeout := time.After(5 * time.Second)
for range [channelBufferSize]int{} {
select {
case <-timeout:
t.Error("timed out reading previously sent updates")
case updateReader := <-peersUpdater.peerChannels[peer]:
if updateReader.Update.NetworkMap.Serial == update2.Update.NetworkMap.Serial {
t.Error("got the update that shouldn't have been sent")
}
}
}
}
func TestCloseChannel(t *testing.T) {
peer := "test-close"
peersUpdater := NewPeersUpdateManager()
_ = peersUpdater.CreateChannel(peer)
if _, ok := peersUpdater.peerChannels[peer]; !ok {
t.Error("Error creating the channel")

View File

@@ -2,6 +2,7 @@ package client
import (
"fmt"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/signal/proto"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"io"
@@ -41,13 +42,15 @@ func UnMarshalCredential(msg *proto.Message) (*Credential, error) {
}
// MarshalCredential marsharl a Credential instance and returns a Message object
func MarshalCredential(myKey wgtypes.Key, remoteKey wgtypes.Key, credential *Credential, t proto.Body_Type) (*proto.Message, error) {
func MarshalCredential(myKey wgtypes.Key, myPort int, remoteKey wgtypes.Key, credential *Credential, t proto.Body_Type) (*proto.Message, error) {
return &proto.Message{
Key: myKey.PublicKey().String(),
RemoteKey: remoteKey.String(),
Body: &proto.Body{
Type: t,
Payload: fmt.Sprintf("%s:%s", credential.UFrag, credential.Pwd),
Type: t,
Payload: fmt.Sprintf("%s:%s", credential.UFrag, credential.Pwd),
WgListenPort: uint32(myPort),
NetBirdVersion: system.NetbirdVersion(),
},
}, nil
}

4
signal/proto/generate.sh Executable file
View File

@@ -0,0 +1,4 @@
#!/bin/bash
go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.26
go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@v1.1
protoc -I proto/ proto/signalexchange.proto --go_out=. --go-grpc_out=.

View File

@@ -214,6 +214,9 @@ type Body struct {
Type Body_Type `protobuf:"varint,1,opt,name=type,proto3,enum=signalexchange.Body_Type" json:"type,omitempty"`
Payload string `protobuf:"bytes,2,opt,name=payload,proto3" json:"payload,omitempty"`
// wgListenPort is an actual WireGuard listen port
WgListenPort uint32 `protobuf:"varint,3,opt,name=wgListenPort,proto3" json:"wgListenPort,omitempty"`
NetBirdVersion string `protobuf:"bytes,4,opt,name=netBirdVersion,proto3" json:"netBirdVersion,omitempty"`
}
func (x *Body) Reset() {
@@ -262,6 +265,20 @@ func (x *Body) GetPayload() string {
return ""
}
func (x *Body) GetWgListenPort() uint32 {
if x != nil {
return x.WgListenPort
}
return 0
}
func (x *Body) GetNetBirdVersion() string {
if x != nil {
return x.NetBirdVersion
}
return ""
}
var File_signalexchange_proto protoreflect.FileDescriptor
var file_signalexchange_proto_rawDesc = []byte{
@@ -281,28 +298,32 @@ var file_signalexchange_proto_rawDesc = []byte{
0x52, 0x09, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x4b, 0x65, 0x79, 0x12, 0x28, 0x0a, 0x04, 0x62,
0x6f, 0x64, 0x79, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x73, 0x69, 0x67, 0x6e,
0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x42, 0x6f, 0x64, 0x79, 0x52,
0x04, 0x62, 0x6f, 0x64, 0x79, 0x22, 0x7d, 0x0a, 0x04, 0x42, 0x6f, 0x64, 0x79, 0x12, 0x2d, 0x0a,
0x04, 0x74, 0x79, 0x70, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x19, 0x2e, 0x73, 0x69,
0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x42, 0x6f, 0x64,
0x79, 0x2e, 0x54, 0x79, 0x70, 0x65, 0x52, 0x04, 0x74, 0x79, 0x70, 0x65, 0x12, 0x18, 0x0a, 0x07,
0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x70,
0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x22, 0x2c, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x09,
0x0a, 0x05, 0x4f, 0x46, 0x46, 0x45, 0x52, 0x10, 0x00, 0x12, 0x0a, 0x0a, 0x06, 0x41, 0x4e, 0x53,
0x57, 0x45, 0x52, 0x10, 0x01, 0x12, 0x0d, 0x0a, 0x09, 0x43, 0x41, 0x4e, 0x44, 0x49, 0x44, 0x41,
0x54, 0x45, 0x10, 0x02, 0x32, 0xb9, 0x01, 0x0a, 0x0e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x45,
0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x12, 0x4c, 0x0a, 0x04, 0x53, 0x65, 0x6e, 0x64, 0x12,
0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65,
0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67,
0x65, 0x1a, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e,
0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73,
0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x59, 0x0a, 0x0d, 0x43, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74,
0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x12, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65,
0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65,
0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61,
0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70,
0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x28, 0x01, 0x30, 0x01,
0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74,
0x6f, 0x33,
0x04, 0x62, 0x6f, 0x64, 0x79, 0x22, 0xc9, 0x01, 0x0a, 0x04, 0x42, 0x6f, 0x64, 0x79, 0x12, 0x2d,
0x0a, 0x04, 0x74, 0x79, 0x70, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x19, 0x2e, 0x73,
0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x42, 0x6f,
0x64, 0x79, 0x2e, 0x54, 0x79, 0x70, 0x65, 0x52, 0x04, 0x74, 0x79, 0x70, 0x65, 0x12, 0x18, 0x0a,
0x07, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07,
0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x12, 0x22, 0x0a, 0x0c, 0x77, 0x67, 0x4c, 0x69, 0x73,
0x74, 0x65, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0c, 0x77,
0x67, 0x4c, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x12, 0x26, 0x0a, 0x0e, 0x6e,
0x65, 0x74, 0x42, 0x69, 0x72, 0x64, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x04, 0x20,
0x01, 0x28, 0x09, 0x52, 0x0e, 0x6e, 0x65, 0x74, 0x42, 0x69, 0x72, 0x64, 0x56, 0x65, 0x72, 0x73,
0x69, 0x6f, 0x6e, 0x22, 0x2c, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x09, 0x0a, 0x05, 0x4f,
0x46, 0x46, 0x45, 0x52, 0x10, 0x00, 0x12, 0x0a, 0x0a, 0x06, 0x41, 0x4e, 0x53, 0x57, 0x45, 0x52,
0x10, 0x01, 0x12, 0x0d, 0x0a, 0x09, 0x43, 0x41, 0x4e, 0x44, 0x49, 0x44, 0x41, 0x54, 0x45, 0x10,
0x02, 0x32, 0xb9, 0x01, 0x0a, 0x0e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x45, 0x78, 0x63, 0x68,
0x61, 0x6e, 0x67, 0x65, 0x12, 0x4c, 0x0a, 0x04, 0x53, 0x65, 0x6e, 0x64, 0x12, 0x20, 0x2e, 0x73,
0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e,
0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x20,
0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e,
0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65,
0x22, 0x00, 0x12, 0x59, 0x0a, 0x0d, 0x43, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x53, 0x74, 0x72,
0x65, 0x61, 0x6d, 0x12, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68,
0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65,
0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78,
0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64,
0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x28, 0x01, 0x30, 0x01, 0x42, 0x08, 0x5a,
0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}
var (

View File

@@ -49,4 +49,7 @@ message Body {
}
Type type = 1;
string payload = 2;
// wgListenPort is an actual WireGuard listen port
uint32 wgListenPort = 3;
string netBirdVersion = 4;
}