Compare commits

..

7 Commits

Author SHA1 Message Date
Zoltán Papp
3cb519c650 Revert WG changes and use the local proxy 2024-04-23 16:07:40 +02:00
Zoltán Papp
318d379658 Update wireguard-go 2024-04-22 19:13:07 +02:00
Zoltán Papp
b68a02acee Close turn connection
Without it the WG can not exit from the read loop
2024-04-18 15:59:50 +02:00
Zoltán Papp
b5c4802bb9 Apply new receiver functions 2024-04-16 16:01:25 +02:00
Zoltán Papp
28a9a2ef87 Move receiverCreator to new file 2024-04-15 14:12:24 +02:00
Zoltán Papp
b355c34b63 Configure wg with proper address 2024-04-12 18:29:35 +02:00
Zoltán Papp
d67f766b2e Initial code 2024-04-12 17:38:31 +02:00
53 changed files with 881 additions and 785 deletions

View File

@@ -33,10 +33,6 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v3
- name: Check for duplicate constants
if: matrix.os == 'ubuntu-latest'
run: |
! awk '/const \(/,/)/{print $0}' management/server/activity/codes.go | grep -o '= [0-9]*' | sort | uniq -d | grep .
- name: Install Go
uses: actions/setup-go@v4
with:

View File

@@ -6,6 +6,8 @@ import (
"fmt"
"io"
"io/fs"
"net/http"
_ "net/http/pprof"
"os"
"os/signal"
"path"
@@ -145,6 +147,13 @@ func init() {
upCmd.PersistentFlags().BoolVar(&rosenpassPermissive, rosenpassPermissiveFlag, false, "[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.")
upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer. If enabled, the SSH server will be permitted")
upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.")
go func() {
// Start the HTTP server on port 8080
http.ListenAndServe("localhost:8080", nil)
}()
// Your application code here
}
// SetupCloseHandler handles SIGTERM signal and exits with success

View File

@@ -143,7 +143,7 @@ func (m *Manager) AllowNetbird() error {
}
if rule := m.detectAllowNetbirdRule(rules); rule != nil {
log.Debugf("allow netbird rule already exists: %v", rule)
log.Debugf("allow netbird rule already exists: %#v", rule)
return nil
}

View File

@@ -138,6 +138,7 @@ type Engine struct {
signalProbe *Probe
relayProbe *Probe
wgProbe *Probe
turnRelay *relay.PermanentTurn
}
// Peer is an instance of the Connection Peer
@@ -199,7 +200,7 @@ func NewEngineWithProbes(
networkSerial: 0,
sshServerFunc: nbssh.DefaultSSHServer,
statusRecorder: statusRecorder,
wgProxyFactory: wgproxy.NewFactory(config.WgPort),
wgProxyFactory: &wgproxy.Factory{},
mgmProbe: mgmProbe,
signalProbe: signalProbe,
relayProbe: relayProbe,
@@ -452,10 +453,19 @@ func SignalOfferAnswer(offerAnswer peer.OfferAnswer, myKey wgtypes.Key, remoteKe
t = sProto.Body_OFFER
}
msg, err := signal.MarshalCredential(myKey, offerAnswer.WgListenPort, remoteKey, &signal.Credential{
UFrag: offerAnswer.IceCredentials.UFrag,
Pwd: offerAnswer.IceCredentials.Pwd,
}, t, offerAnswer.RosenpassPubKey, offerAnswer.RosenpassAddr)
msg, err := signal.MarshalCredential(
myKey,
offerAnswer.WgListenPort,
remoteKey, &signal.Credential{
UFrag: offerAnswer.IceCredentials.UFrag,
Pwd: offerAnswer.IceCredentials.Pwd,
},
t,
offerAnswer.RosenpassPubKey,
offerAnswer.RosenpassAddr,
offerAnswer.RelayedAddr.String(),
offerAnswer.RemoteAddr.String(),
)
if err != nil {
return err
}
@@ -483,6 +493,14 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
return err
}
turnRelay := relay.NewPermanentTurn(e.STUNs[0], e.TURNs[0])
err = turnRelay.Open()
if err != nil {
return fmt.Errorf("faile to open turn relay: %w", err)
}
e.turnRelay = turnRelay
//e.wgInterface.SetRelayConn(e.turnRelay.RelayConn())
// todo update signal
}
@@ -621,6 +639,7 @@ func (e *Engine) updateTURNs(turns []*mgmProto.ProtectedHostConfig) error {
var newTURNs []*stun.URI
log.Debugf("got TURNs update from Management Service, updating")
for _, turn := range turns {
log.Debugf("-----updated Turn %v, %s, %s", turn.HostConfig.Uri, turn.User, turn.Password)
url, err := stun.ParseURI(turn.HostConfig.Uri)
if err != nil {
return err
@@ -934,7 +953,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, e
RosenpassAddr: e.getRosenpassAddr(),
}
peerConn, err := peer.NewConn(config, e.statusRecorder, e.wgProxyFactory, e.mobileDep.TunAdapter, e.mobileDep.IFaceDiscover)
peerConn, err := peer.NewConn(config, e.statusRecorder, e.wgProxyFactory, e.mobileDep.TunAdapter, e.mobileDep.IFaceDiscover, e.turnRelay)
if err != nil {
return nil, err
}
@@ -1000,6 +1019,17 @@ func (e *Engine) receiveSignalEvents() {
rosenpassPubKey = msg.GetBody().GetRosenpassConfig().GetRosenpassPubKey()
rosenpassAddr = msg.GetBody().GetRosenpassConfig().GetRosenpassServerAddr()
}
relayedAddr, err := net.ResolveUDPAddr("udp", msg.GetBody().GetRelay().GetRelayedAddress())
if err != nil {
return err
}
remoteAddr, err := net.ResolveUDPAddr("udp", msg.GetBody().GetRelay().GetSrvRefAddress())
if err != nil {
return err
}
conn.OnRemoteOffer(peer.OfferAnswer{
IceCredentials: peer.IceCredentials{
UFrag: remoteCred.UFrag,
@@ -1009,6 +1039,8 @@ func (e *Engine) receiveSignalEvents() {
Version: msg.GetBody().GetNetBirdVersion(),
RosenpassPubKey: rosenpassPubKey,
RosenpassAddr: rosenpassAddr,
RelayedAddr: relayedAddr,
RemoteAddr: remoteAddr,
})
case sProto.Body_ANSWER:
remoteCred, err := signal.UnMarshalCredential(msg)
@@ -1024,6 +1056,17 @@ func (e *Engine) receiveSignalEvents() {
rosenpassPubKey = msg.GetBody().GetRosenpassConfig().GetRosenpassPubKey()
rosenpassAddr = msg.GetBody().GetRosenpassConfig().GetRosenpassServerAddr()
}
relayedAddr, err := net.ResolveUDPAddr("udp", msg.GetBody().GetRelay().GetRelayedAddress())
if err != nil {
return err
}
remoteAddr, err := net.ResolveUDPAddr("udp", msg.GetBody().GetRelay().GetSrvRefAddress())
if err != nil {
return err
}
conn.OnRemoteAnswer(peer.OfferAnswer{
IceCredentials: peer.IceCredentials{
UFrag: remoteCred.UFrag,
@@ -1033,6 +1076,8 @@ func (e *Engine) receiveSignalEvents() {
Version: msg.GetBody().GetNetBirdVersion(),
RosenpassPubKey: rosenpassPubKey,
RosenpassAddr: rosenpassAddr,
RelayedAddr: relayedAddr,
RemoteAddr: remoteAddr,
})
case sProto.Body_CANDIDATE:
candidate, err := ice.UnmarshalCandidate(msg.GetBody().Payload)
@@ -1043,7 +1088,6 @@ func (e *Engine) receiveSignalEvents() {
conn.OnRemoteCandidate(candidate)
case sProto.Body_MODE:
}
return nil
})
if err != nil {
@@ -1115,6 +1159,8 @@ func (e *Engine) close() {
log.Errorf("failed closing ebpf proxy: %s", err)
}
e.turnRelay.Close()
// stop/restore DNS first so dbus and friends don't complain because of a missing interface
if e.dnsServer != nil {
e.dnsServer.Stop()

View File

@@ -4,7 +4,6 @@ import (
"context"
"fmt"
"net"
"runtime"
"strings"
"sync"
"time"
@@ -14,6 +13,7 @@ import (
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/internal/relay"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/internal/wgproxy"
"github.com/netbirdio/netbird/iface"
@@ -93,6 +93,10 @@ type OfferAnswer struct {
// RosenpassAddr is the Rosenpass server address (IP:port) of the remote peer when receiving this message
// This value is the local Rosenpass server address when sending the message
RosenpassAddr string
// Turn Relay
RelayedAddr net.Addr
RemoteAddr net.Addr
}
// IceCredentials ICE protocol credentials struct
@@ -131,7 +135,7 @@ type Conn struct {
statusRecorder *Status
wgProxyFactory *wgproxy.Factory
wgProxy wgproxy.Proxy
wgProxy *wgproxy.WGUserSpaceProxy
remoteModeCh chan ModeMessage
meta meta
@@ -141,11 +145,11 @@ type Conn struct {
sentExtraSrflx bool
remoteEndpoint *net.UDPAddr
remoteConn *ice.Conn
connID nbnet.ConnectionID
beforeAddPeerHooks []BeforeAddPeerHookFunc
afterRemovePeerHooks []AfterRemovePeerHookFunc
turnRelay *relay.PermanentTurn
}
// meta holds meta information about a connection
@@ -176,7 +180,7 @@ func (conn *Conn) UpdateStunTurn(turnStun []*stun.URI) {
// NewConn creates a new not opened Conn to the remote peer.
// To establish a connection run Conn.Open
func NewConn(config ConnConfig, statusRecorder *Status, wgProxyFactory *wgproxy.Factory, adapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover) (*Conn, error) {
func NewConn(config ConnConfig, statusRecorder *Status, wgProxyFactory *wgproxy.Factory, adapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover, turnRelay *relay.PermanentTurn) (*Conn, error) {
return &Conn{
config: config,
mu: sync.Mutex{},
@@ -189,6 +193,7 @@ func NewConn(config ConnConfig, statusRecorder *Status, wgProxyFactory *wgproxy.
wgProxyFactory: wgProxyFactory,
adapter: adapter,
iFaceDiscover: iFaceDiscover,
turnRelay: turnRelay,
}, nil
}
@@ -212,7 +217,7 @@ func (conn *Conn) reCreateAgent() error {
MulticastDNSMode: ice.MulticastDNSModeDisabled,
NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6},
Urls: conn.config.StunTurn,
CandidateTypes: conn.candidateTypes(),
CandidateTypes: []ice.CandidateType{},
FailedTimeout: &failedTimeout,
InterfaceFilter: stdnet.InterfaceFilter(conn.config.InterfaceBlackList),
UDPMux: conn.config.UDPMux,
@@ -262,17 +267,6 @@ func (conn *Conn) reCreateAgent() error {
return nil
}
func (conn *Conn) candidateTypes() []ice.CandidateType {
if hasICEForceRelayConn() {
return []ice.CandidateType{ice.CandidateTypeRelay}
}
// TODO: remove this once we have refactored userspace proxy into the bind package
if runtime.GOOS == "ios" {
return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive}
}
return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive, ice.CandidateTypeRelay}
}
// Open opens connection to the remote peer starting ICE candidate gathering process.
// Blocks until connection has been closed or connection timeout.
// ConnStatus will be set accordingly
@@ -351,42 +345,56 @@ func (conn *Conn) Open() error {
log.Warnf("error while updating the state of peer %s,err: %v", conn.config.Key, err)
}
err = conn.agent.GatherCandidates()
if err != nil {
return err
}
// will block until connection succeeded
// but it won't release if ICE Agent went into Disconnected or Failed state,
// so we have to cancel it with the provided context once agent detected a broken connection
isControlling := conn.config.LocalKey > conn.config.Key
var remoteConn *ice.Conn
isControlling := conn.config.LocalKey < conn.config.Key
if isControlling {
remoteConn, err = conn.agent.Dial(conn.ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
log.Debugf("send punchole to: %s", remoteOfferAnswer.RemoteAddr.String())
err = conn.turnRelay.PunchHole(remoteOfferAnswer.RemoteAddr)
if err != nil {
log.Errorf("failed to punch hole: %v", err)
}
addr, ok := remoteOfferAnswer.RemoteAddr.(*net.UDPAddr)
if !ok {
return fmt.Errorf("failed to cast addr to udp addr")
}
addr.Port = remoteOfferAnswer.WgListenPort
conn.wgProxy = wgproxy.NewWGUserSpaceProxy(conn.config.LocalWgPort)
myNetConn := NewMyNetConn(conn.turnRelay.RelayConn(), addr)
endpoint, err := conn.wgProxy.AddTurnConn(myNetConn)
if err != nil {
return err
}
proxyedAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String())
log.Debugf("---- use this peer's tunr connection: %s", addr)
err = conn.config.WgConfig.WgInterface.UpdatePeer(conn.config.WgConfig.RemoteKey, conn.config.WgConfig.AllowedIps, defaultWgKeepAlive, proxyedAddr, conn.config.WgConfig.PreSharedKey)
if err != nil {
if conn.wgProxy != nil {
_ = conn.wgProxy.CloseConn()
}
// todo close
return err
}
} else {
remoteConn, err = conn.agent.Accept(conn.ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
}
if err != nil {
return err
}
addr, ok := remoteOfferAnswer.RelayedAddr.(*net.UDPAddr)
if !ok {
return fmt.Errorf("failed to cast addr to udp addr")
}
log.Debugf("---- use remote peer tunr connection: %s", addr)
// dynamically set remote WireGuard port is other side specified a different one from the default one
remoteWgPort := iface.DefaultWgPort
if remoteOfferAnswer.WgListenPort != 0 {
remoteWgPort = remoteOfferAnswer.WgListenPort
err := conn.config.WgConfig.WgInterface.UpdatePeer(conn.config.WgConfig.RemoteKey, conn.config.WgConfig.AllowedIps, defaultWgKeepAlive, addr, conn.config.WgConfig.PreSharedKey)
if err != nil {
if conn.wgProxy != nil {
_ = conn.wgProxy.CloseConn()
}
// todo close
return err
}
log.Infof("connected to peer %s, endpoint address: %s", conn.config.Key, addr.String())
}
conn.remoteConn = remoteConn
// the ice connection has been established successfully so we are ready to start the proxy
remoteAddr, err := conn.configureConnection(remoteConn, remoteWgPort, remoteOfferAnswer.RosenpassPubKey,
remoteOfferAnswer.RosenpassAddr)
if err != nil {
return err
}
log.Infof("connected to peer %s, endpoint address: %s", conn.config.Key, remoteAddr.String())
// wait until connection disconnected or has been closed externally (upper layer, e.g. engine)
select {
case <-conn.closeCh:
@@ -415,25 +423,8 @@ func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, rem
conn.mu.Lock()
defer conn.mu.Unlock()
pair, err := conn.agent.GetSelectedCandidatePair()
if err != nil {
return nil, err
}
var endpoint net.Addr
if isRelayCandidate(pair.Local) {
log.Debugf("setup relay connection")
conn.wgProxy = conn.wgProxyFactory.GetProxy()
endpoint, err = conn.wgProxy.AddTurnConn(remoteConn)
if err != nil {
return nil, err
}
} else {
// To support old version's with direct mode we attempt to punch an additional role with the remote WireGuard port
go conn.punchRemoteWGPort(pair, remoteWgPort)
endpoint = remoteConn.RemoteAddr()
}
endpoint = remoteConn.RemoteAddr()
endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String())
conn.remoteEndpoint = endpointUdpAddr
log.Debugf("Conn resolved IP for %s: %s", endpoint, endpointUdpAddr.IP)
@@ -445,7 +436,7 @@ func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, rem
}
}
err = conn.config.WgConfig.WgInterface.UpdatePeer(conn.config.WgConfig.RemoteKey, conn.config.WgConfig.AllowedIps, defaultWgKeepAlive, endpointUdpAddr, conn.config.WgConfig.PreSharedKey)
err := conn.config.WgConfig.WgInterface.UpdatePeer(conn.config.WgConfig.RemoteKey, conn.config.WgConfig.AllowedIps, defaultWgKeepAlive, endpointUdpAddr, conn.config.WgConfig.PreSharedKey)
if err != nil {
if conn.wgProxy != nil {
_ = conn.wgProxy.CloseConn()
@@ -454,31 +445,33 @@ func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, rem
}
conn.status = StatusConnected
rosenpassEnabled := false
if remoteRosenpassPubKey != nil {
rosenpassEnabled = true
}
/*
rosenpassEnabled := false
if remoteRosenpassPubKey != nil {
rosenpassEnabled = true
}
peerState := State{
PubKey: conn.config.Key,
ConnStatus: conn.status,
ConnStatusUpdate: time.Now(),
LocalIceCandidateType: pair.Local.Type().String(),
RemoteIceCandidateType: pair.Remote.Type().String(),
LocalIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Local.Address(), pair.Local.Port()),
RemoteIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Remote.Address(), pair.Remote.Port()),
Direct: !isRelayCandidate(pair.Local),
RosenpassEnabled: rosenpassEnabled,
Mux: new(sync.RWMutex),
}
if pair.Local.Type() == ice.CandidateTypeRelay || pair.Remote.Type() == ice.CandidateTypeRelay {
peerState.Relayed = true
}
peerState := State{
PubKey: conn.config.Key,
ConnStatus: conn.status,
ConnStatusUpdate: time.Now(),
LocalIceCandidateType: pair.Local.Type().String(),
RemoteIceCandidateType: pair.Remote.Type().String(),
LocalIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Local.Address(), pair.Local.Port()),
RemoteIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Remote.Address(), pair.Local.Port()),
Direct: !isRelayCandidate(pair.Local),
RosenpassEnabled: rosenpassEnabled,
Mux: new(sync.RWMutex),
}
if pair.Local.Type() == ice.CandidateTypeRelay || pair.Remote.Type() == ice.CandidateTypeRelay {
peerState.Relayed = true
}
err = conn.statusRecorder.UpdatePeerState(peerState)
if err != nil {
log.Warnf("unable to save peer's state, got error: %v", err)
}
err = conn.statusRecorder.UpdatePeerState(peerState)
if err != nil {
log.Warnf("unable to save peer's state, got error: %v", err)
}
*/
_, ipNet, err := net.ParseCIDR(conn.config.WgConfig.AllowedIps)
if err != nil {
@@ -680,6 +673,8 @@ func (conn *Conn) sendAnswer() error {
Version: version.NetbirdVersion(),
RosenpassPubKey: conn.config.RosenpassPubKey,
RosenpassAddr: conn.config.RosenpassAddr,
RelayedAddr: conn.turnRelay.RelayedAddress(),
RemoteAddr: conn.turnRelay.SrvRefAddr(),
})
if err != nil {
return err
@@ -703,6 +698,8 @@ func (conn *Conn) sendOffer() error {
Version: version.NetbirdVersion(),
RosenpassPubKey: conn.config.RosenpassPubKey,
RosenpassAddr: conn.config.RosenpassAddr,
RelayedAddr: conn.turnRelay.RelayedAddress(),
RemoteAddr: conn.turnRelay.SrvRefAddr(),
})
if err != nil {
return err
@@ -742,6 +739,10 @@ func (conn *Conn) Status() ConnStatus {
return conn.status
}
func (conn *Conn) OnRemoteRelayRequest(relayedAddr string, remoteIP string) {
}
// OnRemoteOffer handles an offer from the remote peer and returns true if the message was accepted, false otherwise
// doesn't block, discards the message if connection wasn't ready
func (conn *Conn) OnRemoteOffer(offer OfferAnswer) bool {

View File

@@ -0,0 +1,52 @@
package peer
import (
"net"
"time"
)
type MyNetConn struct {
remoteConn net.PacketConn
remoteAddr net.Addr
}
func NewMyNetConn(remoteConn net.PacketConn, remoteAddr net.Addr) net.Conn {
return &MyNetConn{
remoteConn: remoteConn,
remoteAddr: remoteAddr,
}
}
func (m *MyNetConn) Read(b []byte) (n int, err error) {
n, _, err = m.remoteConn.ReadFrom(b)
return
}
func (m *MyNetConn) Write(b []byte) (n int, err error) {
n, err = m.remoteConn.WriteTo(b, m.remoteAddr)
return
}
func (m *MyNetConn) Close() error {
return m.remoteConn.Close()
}
func (m *MyNetConn) LocalAddr() net.Addr {
return m.remoteConn.LocalAddr()
}
func (m *MyNetConn) RemoteAddr() net.Addr {
return m.remoteAddr
}
func (m *MyNetConn) SetDeadline(t time.Time) error {
return m.remoteConn.SetDeadline(t)
}
func (m *MyNetConn) SetReadDeadline(t time.Time) error {
return m.remoteConn.SetReadDeadline(t)
}
func (m *MyNetConn) SetWriteDeadline(t time.Time) error {
return m.remoteConn.SetWriteDeadline(t)
}

View File

@@ -0,0 +1,152 @@
package relay
import (
"fmt"
"math"
"net"
"sync"
"github.com/pion/logging"
"github.com/pion/stun/v2"
"github.com/pion/turn/v3"
log "github.com/sirupsen/logrus"
)
type PermanentTurn struct {
stunURI *stun.URI
turnURI *stun.URI
stunConn net.PacketConn
turnClient *turn.Client
turnClientListenLock sync.Mutex
relayConn net.PacketConn // represents the remote socket.
srvReflexiveAddress *net.UDPAddr
}
func NewPermanentTurn(stunURL, turnURL *stun.URI) *PermanentTurn {
return &PermanentTurn{
stunURI: stunURL,
turnURI: turnURL,
}
}
func (r *PermanentTurn) Open() error {
log.Debugf("Opening permanent turn connection")
stunConn, err := net.ListenPacket("udp4", "0.0.0.0:0")
if err != nil {
return err
}
r.stunConn = stunConn
cfg := &turn.ClientConfig{
STUNServerAddr: toURL(r.stunURI),
TURNServerAddr: toURL(r.turnURI),
Conn: stunConn,
Username: r.turnURI.Username,
Password: r.turnURI.Password,
LoggerFactory: logging.NewDefaultLoggerFactory(),
}
client, err := turn.NewClient(cfg)
if err != nil {
log.Errorf("failed to create turn client: %v", err)
return err
}
r.turnClient = client
err = r.turnClient.Listen()
if err != nil {
log.Errorf("failed to listen: %v", err)
}
//r.listen()
relayConn, err := client.Allocate()
if err != nil {
log.Errorf("failed to allocate relay connection: %v", err)
return err
}
r.relayConn = relayConn
srvReflexiveAddress, err := r.discoverPublicIP()
if err != nil {
log.Errorf("failed to discover public IP: %v", err)
return err
}
r.srvReflexiveAddress = srvReflexiveAddress
return nil
}
func (r *PermanentTurn) RelayedAddress() net.Addr {
return r.relayConn.LocalAddr()
}
func (r *PermanentTurn) SrvRefAddr() net.Addr {
return r.srvReflexiveAddress
}
func (r *PermanentTurn) PunchHole(mappedAddr net.Addr) error {
_, err := r.relayConn.WriteTo([]byte("Hello"), mappedAddr)
return err
}
func (r *PermanentTurn) RelayConn() net.PacketConn {
return r.relayConn
}
func (r *PermanentTurn) Close() {
r.turnClient.Close()
err := r.relayConn.Close()
if err != nil {
log.Errorf("failed to close relayConn: %s", err.Error())
}
err = r.stunConn.Close()
if err != nil {
log.Errorf("failed to close stunConn: %s", err.Error())
}
}
func (r *PermanentTurn) discoverPublicIP() (*net.UDPAddr, error) {
addr, err := r.turnClient.SendBindingRequest()
if err != nil {
log.Errorf("failed to send binding request: %v", err)
return nil, err
}
udpAddr, ok := addr.(*net.UDPAddr)
if !ok {
return nil, fmt.Errorf("failed to cast addr to udp addr")
}
return udpAddr, nil
}
func (r *PermanentTurn) listen() {
if !r.turnClientListenLock.TryLock() {
return
}
go func() {
defer r.turnClientListenLock.Unlock()
buf := make([]byte, math.MaxUint16)
for {
n, from, err := r.stunConn.ReadFrom(buf)
if err != nil {
log.Errorf("Failed to read from stun conn. Exiting loop %v", err)
break
}
_, err = r.turnClient.HandleInbound(buf[:n], from)
if err != nil {
log.Errorf("Failed to handle inbound turn message: %s. Exiting loop", err)
break
}
}
}()
}
func toURL(uri *stun.URI) string {
return fmt.Sprintf("%s:%d", uri.Host, uri.Port)
}

View File

@@ -0,0 +1,36 @@
package relay
import (
"os"
"testing"
"github.com/pion/stun/v2"
"github.com/netbirdio/netbird/util"
)
func TestMain(m *testing.M) {
_ = util.InitLog("trace", "console")
code := m.Run()
os.Exit(code)
}
func TestNewPermanentTurn(t *testing.T) {
turnURI, err := stun.ParseURI("turns:turn.netbird.io:443?transport=tcp")
if err != nil {
t.Errorf("failed to parse stun url: %v", err)
}
turnURI.Username = "1713006060"
turnURI.Password = "pO5Pfx15luZ92mW+FHPa6/LtJ7Y="
stunURI, err := stun.ParseURI("stun:stun.netbird.io:5555")
if err != nil {
t.Errorf("failed to parse stun url: %v", err)
}
turnRelay := NewPermanentTurn(stunURI, turnURI)
err = turnRelay.Open()
if err != nil {
t.Errorf("failed to open turn relay: %v", err)
}
}

View File

@@ -16,7 +16,6 @@ import (
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route"
nbnet "github.com/netbirdio/netbird/util/net"
"github.com/netbirdio/netbird/version"
)
@@ -69,10 +68,6 @@ func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface,
// Init sets up the routing
func (m *DefaultManager) Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
if nbnet.CustomRoutingDisabled() {
return nil, nil, nil
}
if err := cleanupRouting(); err != nil {
log.Warnf("Failed cleaning up routing: %v", err)
}
@@ -104,15 +99,11 @@ func (m *DefaultManager) Stop() {
if m.serverRouter != nil {
m.serverRouter.cleanUp()
}
if !nbnet.CustomRoutingDisabled() {
if err := cleanupRouting(); err != nil {
log.Errorf("Error cleaning up routing: %v", err)
} else {
log.Info("Routing cleanup complete")
}
if err := cleanupRouting(); err != nil {
log.Errorf("Error cleaning up routing: %v", err)
} else {
log.Info("Routing cleanup complete")
}
m.ctx = nil
}
@@ -219,11 +210,9 @@ func (m *DefaultManager) clientRoutes(initialRoutes []*route.Route) []*route.Rou
}
func isPrefixSupported(prefix netip.Prefix) bool {
if !nbnet.CustomRoutingDisabled() {
switch runtime.GOOS {
case "linux", "windows", "darwin":
return true
}
switch runtime.GOOS {
case "linux", "windows", "darwin":
return true
}
// If prefix is too small, lets assume it is a possible default prefix which is not yet supported

View File

@@ -4,14 +4,14 @@ package routemanager
import (
"bufio"
"context"
"errors"
"fmt"
"net"
"net/netip"
"os"
"strconv"
"strings"
"syscall"
"time"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
@@ -32,31 +32,19 @@ const (
rtTablesPath = "/etc/iproute2/rt_tables"
// ipv4ForwardingPath is the path to the file containing the IP forwarding setting.
ipv4ForwardingPath = "net.ipv4.ip_forward"
rpFilterPath = "net.ipv4.conf.all.rp_filter"
rpFilterInterfacePath = "net.ipv4.conf.%s.rp_filter"
srcValidMarkPath = "net.ipv4.conf.all.src_valid_mark"
ipv4ForwardingPath = "/proc/sys/net/ipv4/ip_forward"
)
var ErrTableIDExists = errors.New("ID exists with different name")
var routeManager = &RouteManager{}
// originalSysctl stores the original sysctl values before they are modified
var originalSysctl map[string]int
// determines whether to use the legacy routing setup
var isLegacy = os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled()
// sysctlFailed is used as an indicator to emit a warning when default routes are configured
var sysctlFailed bool
var isLegacy = os.Getenv("NB_USE_LEGACY_ROUTING") == "true"
type ruleParams struct {
priority int
fwmark int
tableID int
family int
priority int
invert bool
suppressPrefix int
description string
@@ -64,10 +52,10 @@ type ruleParams struct {
func getSetupRules() []ruleParams {
return []ruleParams{
{100, -1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V4, false, 0, "rule with suppress prefixlen v4"},
{100, -1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V6, false, 0, "rule with suppress prefixlen v6"},
{110, nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V4, true, -1, "rule v4 netbird"},
{110, nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V6, true, -1, "rule v6 netbird"},
{nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V4, -1, true, -1, "rule v4 netbird"},
{nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V6, -1, true, -1, "rule v6 netbird"},
{-1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V4, -1, false, 0, "rule with suppress prefixlen v4"},
{-1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V6, -1, false, 0, "rule with suppress prefixlen v6"},
}
}
@@ -81,6 +69,8 @@ func getSetupRules() []ruleParams {
// Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table.
// This table is where a default route or other specific routes received from the management server are configured,
// enabling VPN connectivity.
//
// The rules are inserted in reverse order, as rules are added from the bottom up in the rule list.
func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.BeforeAddPeerHookFunc, _ peer.AfterRemovePeerHookFunc, err error) {
if isLegacy {
log.Infof("Using legacy routing setup")
@@ -91,13 +81,6 @@ func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.Before
log.Errorf("Error adding routing table name: %v", err)
}
originalValues, err := setupSysctl(wgIface)
if err != nil {
log.Errorf("Error setting up sysctl: %v", err)
sysctlFailed = true
}
originalSysctl = originalValues
defer func() {
if err != nil {
if cleanErr := cleanupRouting(); cleanErr != nil {
@@ -140,17 +123,11 @@ func cleanupRouting() error {
rules := getSetupRules()
for _, rule := range rules {
if err := removeRule(rule); err != nil {
if err := removeAllRules(rule); err != nil && !errors.Is(err, syscall.EOPNOTSUPP) {
result = multierror.Append(result, fmt.Errorf("%s: %w", rule.description, err))
}
}
if err := cleanupSysctl(originalSysctl); err != nil {
result = multierror.Append(result, fmt.Errorf("cleanup sysctl: %w", err))
}
originalSysctl = nil
sysctlFailed = false
return result.ErrorOrNil()
}
@@ -167,10 +144,6 @@ func addVPNRoute(prefix netip.Prefix, intf string) error {
return genericAddVPNRoute(prefix, intf)
}
if sysctlFailed && (prefix == defaultv4 || prefix == defaultv6) {
log.Warnf("Default route is configured but sysctl operations failed, VPN traffic may not be routed correctly, consider using NB_USE_LEGACY_ROUTING=true or setting net.ipv4.conf.*.rp_filter to 2 (loose) or 0 (off)")
}
// No need to check if routes exist as main table takes precedence over the VPN table via Rule 1
// TODO remove this once we have ipv6 support
@@ -363,8 +336,22 @@ func flushRoutes(tableID, family int) error {
}
func enableIPForwarding() error {
_, err := setSysctl(ipv4ForwardingPath, 1, false)
return err
bytes, err := os.ReadFile(ipv4ForwardingPath)
if err != nil {
return fmt.Errorf("read file %s: %w", ipv4ForwardingPath, err)
}
// check if it is already enabled
// see more: https://github.com/netbirdio/netbird/issues/872
if len(bytes) > 0 && bytes[0] == 49 {
return nil
}
//nolint:gosec
if err := os.WriteFile(ipv4ForwardingPath, []byte("1"), 0644); err != nil {
return fmt.Errorf("write file %s: %w", ipv4ForwardingPath, err)
}
return nil
}
// entryExists checks if the specified ID or name already exists in the rt_tables file
@@ -442,7 +429,7 @@ func addRule(params ruleParams) error {
rule.Invert = params.invert
rule.SuppressPrefixlen = params.suppressPrefix
if err := netlink.RuleAdd(rule); err != nil && !errors.Is(err, syscall.EEXIST) && !errors.Is(err, syscall.EAFNOSUPPORT) {
if err := netlink.RuleAdd(rule); err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
return fmt.Errorf("add routing rule: %w", err)
}
@@ -459,13 +446,43 @@ func removeRule(params ruleParams) error {
rule.Priority = params.priority
rule.SuppressPrefixlen = params.suppressPrefix
if err := netlink.RuleDel(rule); err != nil && !errors.Is(err, syscall.ENOENT) && !errors.Is(err, syscall.EAFNOSUPPORT) {
if err := netlink.RuleDel(rule); err != nil {
return fmt.Errorf("remove routing rule: %w", err)
}
return nil
}
func removeAllRules(params ruleParams) error {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
done := make(chan error, 1)
go func() {
for {
if ctx.Err() != nil {
done <- ctx.Err()
return
}
if err := removeRule(params); err != nil {
if errors.Is(err, syscall.ENOENT) || errors.Is(err, syscall.EAFNOSUPPORT) {
done <- nil
return
}
done <- err
return
}
}
}()
select {
case <-ctx.Done():
return ctx.Err()
case err := <-done:
return err
}
}
// addNextHop adds the gateway and device to the route.
func addNextHop(addr netip.Addr, intf string, route *netlink.Route) error {
if addr.IsValid() {
@@ -492,83 +509,3 @@ func getAddressFamily(prefix netip.Prefix) int {
}
return netlink.FAMILY_V6
}
// setupSysctl configures sysctl settings for RP filtering and source validation.
func setupSysctl(wgIface *iface.WGIface) (map[string]int, error) {
keys := map[string]int{}
var result *multierror.Error
oldVal, err := setSysctl(srcValidMarkPath, 1, false)
if err != nil {
result = multierror.Append(result, err)
} else {
keys[srcValidMarkPath] = oldVal
}
oldVal, err = setSysctl(rpFilterPath, 2, true)
if err != nil {
result = multierror.Append(result, err)
} else {
keys[rpFilterPath] = oldVal
}
interfaces, err := net.Interfaces()
if err != nil {
result = multierror.Append(result, fmt.Errorf("list interfaces: %w", err))
}
for _, intf := range interfaces {
if intf.Name == "lo" || wgIface != nil && intf.Name == wgIface.Name() {
continue
}
i := fmt.Sprintf(rpFilterInterfacePath, intf.Name)
oldVal, err := setSysctl(i, 2, true)
if err != nil {
result = multierror.Append(result, err)
} else {
keys[i] = oldVal
}
}
return keys, result.ErrorOrNil()
}
// setSysctl sets a sysctl configuration, if onlyIfOne is true it will only set the new value if it's set to 1
func setSysctl(key string, desiredValue int, onlyIfOne bool) (int, error) {
path := fmt.Sprintf("/proc/sys/%s", strings.ReplaceAll(key, ".", "/"))
currentValue, err := os.ReadFile(path)
if err != nil {
return -1, fmt.Errorf("read sysctl %s: %w", key, err)
}
currentV, err := strconv.Atoi(strings.TrimSpace(string(currentValue)))
if err != nil && len(currentValue) > 0 {
return -1, fmt.Errorf("convert current desiredValue to int: %w", err)
}
if currentV == desiredValue || onlyIfOne && currentV != 1 {
return currentV, nil
}
//nolint:gosec
if err := os.WriteFile(path, []byte(strconv.Itoa(desiredValue)), 0644); err != nil {
return currentV, fmt.Errorf("write sysctl %s: %w", key, err)
}
log.Debugf("Set sysctl %s from %d to %d", key, currentV, desiredValue)
return currentV, nil
}
func cleanupSysctl(originalSettings map[string]int) error {
var result *multierror.Error
for key, value := range originalSettings {
_, err := setSysctl(key, value, false)
if err != nil {
result = multierror.Append(result, err)
}
}
return result.ErrorOrNil()
}

View File

@@ -61,7 +61,7 @@ func TestAddRemoveRoutes(t *testing.T) {
err = wgInterface.Create()
require.NoError(t, err, "should create testing wireguard interface")
_, _, err = setupRouting(nil, wgInterface)
_, _, err = setupRouting(nil, nil)
require.NoError(t, err)
t.Cleanup(func() {
assert.NoError(t, cleanupRouting())

View File

@@ -73,7 +73,7 @@ func addRoutePowershell(prefix netip.Prefix, nexthop netip.Addr, intf, intfIdx s
}
script := fmt.Sprintf(
`%s -AddressFamily "%s" -DestinationPrefix "%s" -Confirm:$False -ErrorAction Stop -PolicyStore ActiveStore`,
`%s -AddressFamily "%s" -DestinationPrefix "%s" -Confirm:$False -ErrorAction Stop`,
psCmd, addressFamily, destinationPrefix,
)

View File

@@ -230,7 +230,7 @@ func (p *WGEBPFProxy) prepareSenderRawSocket() (net.PacketConn, error) {
}
// Set the fwmark on the socket.
err = nbnet.SetSocketOpt(fd)
err = syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_MARK, nbnet.NetbirdFwmark)
if err != nil {
return nil, fmt.Errorf("setting fwmark failed: %w", err)
}

View File

@@ -76,7 +76,8 @@ func (p *WGUserSpaceProxy) proxyToRemote() {
continue
}
_, err = p.remoteConn.Write(buf[:n])
log.Debugf("read from local conn %d bytes and forward to relay", n)
n, err = p.remoteConn.Write(buf[:n])
if err != nil {
continue
}

6
go.mod
View File

@@ -60,7 +60,7 @@ require (
github.com/miekg/dns v1.1.43
github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/nadoo/ipset v0.5.0
github.com/netbirdio/management-integrations/integrations v0.0.0-20240415094251-369eb33c9b01
github.com/netbirdio/management-integrations/integrations v0.0.0-20240326083846-3682438fca98
github.com/okta/okta-sdk-golang/v2 v2.18.0
github.com/oschwald/maxminddb-golang v1.12.0
github.com/patrickmn/go-cache v2.1.0+incompatible
@@ -172,8 +172,8 @@ replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-2023
replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20240422165616-c6832bb477d5
replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6
replace github.com/pion/ice/v3 => github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e
replace github.com/pion/ice/v3 => github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e

5
go.sum
View File

@@ -383,14 +383,15 @@ github.com/nadoo/ipset v0.5.0 h1:5GJUAuZ7ITQQQGne5J96AmFjRtI8Avlbk6CabzYWVUc=
github.com/nadoo/ipset v0.5.0/go.mod h1:rYF5DQLRGGoQ8ZSWeK+6eX5amAuPqwFkWjhQlEITGJQ=
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c=
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q=
github.com/netbirdio/management-integrations/integrations v0.0.0-20240415094251-369eb33c9b01 h1:Fu9fq0ndfKVuFTEwbc8Etqui10BOkcMTv0UqcMy0RuY=
github.com/netbirdio/management-integrations/integrations v0.0.0-20240415094251-369eb33c9b01/go.mod h1:kxks50DrZnhW+oRTdHOkVOJbcTcyo766am8RBugo+Yc=
github.com/netbirdio/management-integrations/integrations v0.0.0-20240326083846-3682438fca98 h1:i6AtenTLu/CqhTmj0g1K/GWkkpMJMhQM6Vjs46x25nA=
github.com/netbirdio/management-integrations/integrations v0.0.0-20240326083846-3682438fca98/go.mod h1:kxks50DrZnhW+oRTdHOkVOJbcTcyo766am8RBugo+Yc=
github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0 h1:hirFRfx3grVA/9eEyjME5/z3nxdJlN9kfQpvWWPk32g=
github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949 h1:xbWM9BU6mwZZLHxEjxIX/V8Hv3HurQt4mReIE4mY4DM=
github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949/go.mod h1:AecygODWIsBquJCJFop8MEQcJbWFfw/1yWbVabNgpCM=
github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed h1:t0UADZUJDaaZgfKrt8JUPrOLL9Mg/ryjP85RAH53qgs=
github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
github.com/netbirdio/wireguard-go v0.0.0-20240422165616-c6832bb477d5/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8=
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno=
github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A=

View File

@@ -13,14 +13,6 @@ import (
wgConn "golang.zx2c4.com/wireguard/conn"
)
type receiverCreator struct {
iceBind *ICEBind
}
func (rc receiverCreator) CreateIPv4ReceiverFn(msgPool *sync.Pool, pc *ipv4.PacketConn, conn *net.UDPConn) wgConn.ReceiveFunc {
return rc.iceBind.createIPv4ReceiverFn(msgPool, pc, conn)
}
type ICEBind struct {
*wgConn.StdNetBind
@@ -28,6 +20,8 @@ type ICEBind struct {
transportNet transport.Net
udpMux *UniversalUDPMuxDefault
receiverCreator *receiverCreator
}
func NewICEBind(transportNet transport.Net) *ICEBind {
@@ -35,9 +29,9 @@ func NewICEBind(transportNet transport.Net) *ICEBind {
transportNet: transportNet,
}
rc := receiverCreator{
ib,
}
rc := newReceiverCreator(ib)
ib.receiverCreator = rc
ib.StdNetBind = wgConn.NewStdNetBindWithReceiverCreator(rc)
return ib
}
@@ -53,7 +47,11 @@ func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
return s.udpMux, nil
}
func (s *ICEBind) createIPv4ReceiverFn(ipv4MsgsPool *sync.Pool, pc *ipv4.PacketConn, conn *net.UDPConn) wgConn.ReceiveFunc {
func (s *ICEBind) SetTurnConn(conn interface{}) {
s.receiverCreator.setTurnConn(conn)
}
func (s *ICEBind) createIPv4ReceiverFn(ipv4MsgsPool *sync.Pool, pc *ipv4.PacketConn, conn *net.UDPConn, netConn net.PacketConn) wgConn.ReceiveFunc {
s.muUDPMux.Lock()
defer s.muUDPMux.Unlock()

View File

@@ -0,0 +1,38 @@
package bind
import (
"net"
"sync"
log "github.com/sirupsen/logrus"
"golang.org/x/net/ipv4"
wgConn "golang.zx2c4.com/wireguard/conn"
)
type receiverCreator struct {
iceBind *ICEBind
relayConn net.PacketConn
}
func newReceiverCreator(iceBind *ICEBind) *receiverCreator {
return &receiverCreator{
iceBind: iceBind,
}
}
func (rc *receiverCreator) CreateIPv4ReceiverFn(msgPool *sync.Pool, pc *ipv4.PacketConn, conn *net.UDPConn) wgConn.ReceiveFunc {
return rc.iceBind.createIPv4ReceiverFn(msgPool, pc, conn, nil)
}
func (rc *receiverCreator) CreateRelayReceiverFn(msgPool *sync.Pool) wgConn.ReceiveFunc {
if rc.relayConn == nil {
log.Debugf("-------rc.conn is nil")
return nil
}
return rc.iceBind.createIPv4ReceiverFn(msgPool, nil, nil, rc.relayConn)
}
func (rc *receiverCreator) setTurnConn(relayConn interface{}) {
log.Debug("------ SET TURN CONN")
rc.relayConn = relayConn.(net.PacketConn)
}

View File

@@ -150,3 +150,10 @@ func (w *WGIface) GetDevice() *DeviceWrapper {
func (w *WGIface) GetStats(peerKey string) (WGStats, error) {
return w.configurer.getStats(peerKey)
}
func (w *WGIface) SetRelayConn(conn interface{}) {
w.mu.Lock()
defer w.mu.Unlock()
w.tun.SetTurnConn(conn)
}

View File

@@ -85,23 +85,27 @@ func tunModuleIsLoaded() bool {
// WireGuardModuleIsLoaded check if we can load WireGuard mod (linux only)
func WireGuardModuleIsLoaded() bool {
return false
if os.Getenv(envDisableWireGuardKernel) == "true" {
log.Debugf("WireGuard kernel module disabled because the %s env is set to true", envDisableWireGuardKernel)
return false
}
/*
if os.Getenv(envDisableWireGuardKernel) == "true" {
log.Debugf("WireGuard kernel module disabled because the %s env is set to true", envDisableWireGuardKernel)
return false
}
if canCreateFakeWireGuardInterface() {
return true
}
if canCreateFakeWireGuardInterface() {
return true
}
loaded, err := tryToLoadModule("wireguard")
if err != nil {
log.Info(err)
return false
}
loaded, err := tryToLoadModule("wireguard")
if err != nil {
log.Info(err)
return false
}
return loaded
return loaded
*/
}
func canCreateFakeWireGuardInterface() bool {

View File

@@ -15,4 +15,5 @@ type wgTunDevice interface {
DeviceName() string
Close() error
Wrapper() *DeviceWrapper // todo eliminate this function
SetTurnConn(conn interface{})
}

View File

@@ -31,6 +31,11 @@ type tunKernelDevice struct {
udpMux *bind.UniversalUDPMuxDefault
}
func (t *tunKernelDevice) SetTurnConn(interface{}) {
//TODO implement me
panic("implement me")
}
func newTunDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net) wgTunDevice {
ctx, cancel := context.WithCancel(context.Background())
return &tunKernelDevice{

View File

@@ -30,6 +30,11 @@ type tunNetstackDevice struct {
configurer wgConfigurer
}
func (t *tunNetstackDevice) SetTurnConn(interface{}) {
//TODO implement me
panic("implement me")
}
func newTunNetstackDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net, listenAddress string) wgTunDevice {
return &tunNetstackDevice{
name: name,

View File

@@ -54,7 +54,7 @@ func (t *tunUSPDevice) Create() (wgConfigurer, error) {
t.device = device.NewDevice(
t.wrapper,
t.iceBind,
device.NewLogger(device.LogLevelSilent, "[netbird] "),
device.NewLogger(device.LogLevelError, "[netbird] "),
)
err = t.assignAddr()
@@ -70,6 +70,7 @@ func (t *tunUSPDevice) Create() (wgConfigurer, error) {
t.configurer.close()
return nil, err
}
log.Debugf("configuration done")
return t.configurer, nil
}
@@ -125,6 +126,14 @@ func (t *tunUSPDevice) Wrapper() *DeviceWrapper {
return t.wrapper
}
func (t *tunUSPDevice) SetTurnConn(conn interface{}) {
t.iceBind.SetTurnConn(conn)
err := t.device.BindUpdate()
if err != nil {
log.Errorf("failed to update bind: %v", err)
}
}
// assignAddr Adds IP address to the tunnel interface
func (t *tunUSPDevice) assignAddr() error {
link := newWGLink(t.name)

View File

@@ -10,6 +10,8 @@ import (
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
nbnet "github.com/netbirdio/netbird/util/net"
)
type wgKernelConfigurer struct {
@@ -29,7 +31,7 @@ func (c *wgKernelConfigurer) configureInterface(privateKey string, port int) err
if err != nil {
return err
}
fwmark := getFwmark()
fwmark := nbnet.NetbirdFwmark
config := wgtypes.Config{
PrivateKey: &key,
ReplacePeers: true,

View File

@@ -349,7 +349,7 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string {
}
func getFwmark() int {
if runtime.GOOS == "linux" && !nbnet.CustomRoutingDisabled() {
if runtime.GOOS == "linux" {
return nbnet.NetbirdFwmark
}
return 0

View File

@@ -251,7 +251,7 @@ var (
ctx, cancel := context.WithCancel(cmd.Context())
defer cancel()
httpAPIHandler, err := httpapi.APIHandler(ctx, accountManager, geo, *jwtValidator, appMetrics, httpAPIAuthCfg, integratedPeerValidator)
httpAPIHandler, err := httpapi.APIHandler(ctx, accountManager, geo, *jwtValidator, appMetrics, httpAPIAuthCfg)
if err != nil {
return fmt.Errorf("failed creating HTTP API handler: %v", err)
}

View File

@@ -1473,7 +1473,7 @@ func (am *DefaultAccountManager) handleNewUserAccount(domainAcc *Account, claims
// if domain already has a primary account, add regular user
if domainAcc != nil {
account = domainAcc
account.Users[claims.UserId] = NewRegularUser(claims.UserId, account.Id)
account.Users[claims.UserId] = NewRegularUser(claims.UserId)
err = am.Store.SaveAccount(account)
if err != nil {
return nil, err
@@ -1849,7 +1849,6 @@ func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(claims jwtclaims.Aut
}
func (am *DefaultAccountManager) onPeersInvalidated(accountID string) {
log.Debugf("validated peers has been invalidated for account %s", accountID)
updatedAccount, err := am.Store.GetAccount(accountID)
if err != nil {
log.Errorf("failed to get account %s: %v", accountID, err)
@@ -1862,10 +1861,9 @@ func (am *DefaultAccountManager) onPeersInvalidated(accountID string) {
func addAllGroup(account *Account) error {
if len(account.Groups) == 0 {
allGroup := &nbgroup.Group{
ID: xid.New().String(),
Name: "All",
Issued: nbgroup.GroupIssuedAPI,
AccountID: account.Id,
ID: xid.New().String(),
Name: "All",
Issued: nbgroup.GroupIssuedAPI,
}
for _, peer := range account.Peers {
allGroup.Peers = append(allGroup.Peers, peer.ID)
@@ -1909,7 +1907,7 @@ func newAccountWithId(accountID, userID, domain string) *Account {
routes := make(map[string]*route.Route)
setupKeys := map[string]*SetupKey{}
nameServersGroups := make(map[string]*nbdns.NameServerGroup)
users[userID] = NewOwnerUser(userID, accountID)
users[userID] = NewOwnerUser(userID)
dnsSettings := DNSSettings{
DisabledManagementGroups: make([]string, 0),
}

View File

@@ -11,134 +11,133 @@ type Code struct {
Code string
}
// Existing consts must not be changed, as this will break the compatibility with the existing data
const (
// PeerAddedByUser indicates that a user added a new peer to the system
PeerAddedByUser Activity = 0
PeerAddedByUser Activity = iota
// PeerAddedWithSetupKey indicates that a new peer joined the system using a setup key
PeerAddedWithSetupKey Activity = 1
PeerAddedWithSetupKey
// UserJoined indicates that a new user joined the account
UserJoined Activity = 2
UserJoined
// UserInvited indicates that a new user was invited to join the account
UserInvited Activity = 3
UserInvited
// AccountCreated indicates that a new account has been created
AccountCreated Activity = 4
AccountCreated
// PeerRemovedByUser indicates that a user removed a peer from the system
PeerRemovedByUser Activity = 5
PeerRemovedByUser
// RuleAdded indicates that a user added a new rule
RuleAdded Activity = 6
RuleAdded
// RuleUpdated indicates that a user updated a rule
RuleUpdated Activity = 7
RuleUpdated
// RuleRemoved indicates that a user removed a rule
RuleRemoved Activity = 8
RuleRemoved
// PolicyAdded indicates that a user added a new policy
PolicyAdded Activity = 9
PolicyAdded
// PolicyUpdated indicates that a user updated a policy
PolicyUpdated Activity = 10
PolicyUpdated
// PolicyRemoved indicates that a user removed a policy
PolicyRemoved Activity = 11
PolicyRemoved
// SetupKeyCreated indicates that a user created a new setup key
SetupKeyCreated Activity = 12
SetupKeyCreated
// SetupKeyUpdated indicates that a user updated a setup key
SetupKeyUpdated Activity = 13
SetupKeyUpdated
// SetupKeyRevoked indicates that a user revoked a setup key
SetupKeyRevoked Activity = 14
SetupKeyRevoked
// SetupKeyOverused indicates that setup key usage exhausted
SetupKeyOverused Activity = 15
SetupKeyOverused
// GroupCreated indicates that a user created a group
GroupCreated Activity = 16
GroupCreated
// GroupUpdated indicates that a user updated a group
GroupUpdated Activity = 17
GroupUpdated
// GroupAddedToPeer indicates that a user added group to a peer
GroupAddedToPeer Activity = 18
GroupAddedToPeer
// GroupRemovedFromPeer indicates that a user removed peer group
GroupRemovedFromPeer Activity = 19
GroupRemovedFromPeer
// GroupAddedToUser indicates that a user added group to a user
GroupAddedToUser Activity = 20
GroupAddedToUser
// GroupRemovedFromUser indicates that a user removed a group from a user
GroupRemovedFromUser Activity = 21
GroupRemovedFromUser
// UserRoleUpdated indicates that a user changed the role of a user
UserRoleUpdated Activity = 22
UserRoleUpdated
// GroupAddedToSetupKey indicates that a user added group to a setup key
GroupAddedToSetupKey Activity = 23
GroupAddedToSetupKey
// GroupRemovedFromSetupKey indicates that a user removed a group from a setup key
GroupRemovedFromSetupKey Activity = 24
GroupRemovedFromSetupKey
// GroupAddedToDisabledManagementGroups indicates that a user added a group to the DNS setting Disabled management groups
GroupAddedToDisabledManagementGroups Activity = 25
GroupAddedToDisabledManagementGroups
// GroupRemovedFromDisabledManagementGroups indicates that a user removed a group from the DNS setting Disabled management groups
GroupRemovedFromDisabledManagementGroups Activity = 26
GroupRemovedFromDisabledManagementGroups
// RouteCreated indicates that a user created a route
RouteCreated Activity = 27
RouteCreated
// RouteRemoved indicates that a user deleted a route
RouteRemoved Activity = 28
RouteRemoved
// RouteUpdated indicates that a user updated a route
RouteUpdated Activity = 29
RouteUpdated
// PeerSSHEnabled indicates that a user enabled SSH server on a peer
PeerSSHEnabled Activity = 30
PeerSSHEnabled
// PeerSSHDisabled indicates that a user disabled SSH server on a peer
PeerSSHDisabled Activity = 31
PeerSSHDisabled
// PeerRenamed indicates that a user renamed a peer
PeerRenamed Activity = 32
PeerRenamed
// PeerLoginExpirationEnabled indicates that a user enabled login expiration of a peer
PeerLoginExpirationEnabled Activity = 33
PeerLoginExpirationEnabled
// PeerLoginExpirationDisabled indicates that a user disabled login expiration of a peer
PeerLoginExpirationDisabled Activity = 34
PeerLoginExpirationDisabled
// NameserverGroupCreated indicates that a user created a nameservers group
NameserverGroupCreated Activity = 35
NameserverGroupCreated
// NameserverGroupDeleted indicates that a user deleted a nameservers group
NameserverGroupDeleted Activity = 36
NameserverGroupDeleted
// NameserverGroupUpdated indicates that a user updated a nameservers group
NameserverGroupUpdated Activity = 37
NameserverGroupUpdated
// AccountPeerLoginExpirationEnabled indicates that a user enabled peer login expiration for the account
AccountPeerLoginExpirationEnabled Activity = 38
AccountPeerLoginExpirationEnabled
// AccountPeerLoginExpirationDisabled indicates that a user disabled peer login expiration for the account
AccountPeerLoginExpirationDisabled Activity = 39
AccountPeerLoginExpirationDisabled
// AccountPeerLoginExpirationDurationUpdated indicates that a user updated peer login expiration duration for the account
AccountPeerLoginExpirationDurationUpdated Activity = 40
AccountPeerLoginExpirationDurationUpdated
// PersonalAccessTokenCreated indicates that a user created a personal access token
PersonalAccessTokenCreated Activity = 41
PersonalAccessTokenCreated
// PersonalAccessTokenDeleted indicates that a user deleted a personal access token
PersonalAccessTokenDeleted Activity = 42
PersonalAccessTokenDeleted
// ServiceUserCreated indicates that a user created a service user
ServiceUserCreated Activity = 43
ServiceUserCreated
// ServiceUserDeleted indicates that a user deleted a service user
ServiceUserDeleted Activity = 44
ServiceUserDeleted
// UserBlocked indicates that a user blocked another user
UserBlocked Activity = 45
UserBlocked
// UserUnblocked indicates that a user unblocked another user
UserUnblocked Activity = 46
UserUnblocked
// UserDeleted indicates that a user deleted another user
UserDeleted Activity = 47
UserDeleted
// GroupDeleted indicates that a user deleted group
GroupDeleted Activity = 48
GroupDeleted
// UserLoggedInPeer indicates that user logged in their peer with an interactive SSO login
UserLoggedInPeer Activity = 49
UserLoggedInPeer
// PeerLoginExpired indicates that the user peer login has been expired and peer disconnected
PeerLoginExpired Activity = 50
PeerLoginExpired
// DashboardLogin indicates that the user logged in to the dashboard
DashboardLogin Activity = 51
DashboardLogin
// IntegrationCreated indicates that the user created an integration
IntegrationCreated Activity = 52
IntegrationCreated
// IntegrationUpdated indicates that the user updated an integration
IntegrationUpdated Activity = 53
IntegrationUpdated
// IntegrationDeleted indicates that the user deleted an integration
IntegrationDeleted Activity = 54
IntegrationDeleted
// AccountPeerApprovalEnabled indicates that the user enabled peer approval for the account
AccountPeerApprovalEnabled Activity = 55
AccountPeerApprovalEnabled
// AccountPeerApprovalDisabled indicates that the user disabled peer approval for the account
AccountPeerApprovalDisabled Activity = 56
AccountPeerApprovalDisabled
// PeerApproved indicates that the peer has been approved
PeerApproved Activity = 57
PeerApproved
// PeerApprovalRevoked indicates that the peer approval has been revoked
PeerApprovalRevoked Activity = 58
PeerApprovalRevoked
// TransferredOwnerRole indicates that the user transferred the owner role of the account
TransferredOwnerRole Activity = 59
TransferredOwnerRole
// PostureCheckCreated indicates that the user created a posture check
PostureCheckCreated Activity = 60
PostureCheckCreated
// PostureCheckUpdated indicates that the user updated a posture check
PostureCheckUpdated Activity = 61
PostureCheckUpdated
// PostureCheckDeleted indicates that the user deleted a posture check
PostureCheckDeleted Activity = 62
PostureCheckDeleted
)
var activityMap = map[Activity]Code{

View File

@@ -54,7 +54,7 @@ func initAccountsTestData(account *server.Account, admin *server.User) *Accounts
func TestAccounts_AccountsHandler(t *testing.T) {
accountID := "test_account"
adminUser := server.NewAdminUser("test_user", "account_id")
adminUser := server.NewAdminUser("test_user")
sr := func(v string) *string { return &v }
br := func(v bool) *bool { return &v }

View File

@@ -34,7 +34,7 @@ var testingDNSSettingsAccount = &server.Account{
Id: testDNSSettingsAccountID,
Domain: "hotmail.com",
Users: map[string]*server.User{
testDNSSettingsUserID: server.NewAdminUser("test_user", "account_id"),
testDNSSettingsUserID: server.NewAdminUser("test_user"),
},
DNSSettings: baseExistingDNSSettings,
}

View File

@@ -196,7 +196,7 @@ func TestEvents_GetEvents(t *testing.T) {
},
}
accountID := "test_account"
adminUser := server.NewAdminUser("test_user", "account_id")
adminUser := server.NewAdminUser("test_user")
events := generateEvents(accountID, adminUser.Id)
handler := initEventsTestData(accountID, adminUser, events...)

View File

@@ -42,7 +42,7 @@ func initGeolocationTestData(t *testing.T) *GeolocationsHandler {
return &GeolocationsHandler{
accountManager: &mock_server.MockAccountManager{
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
user := server.NewAdminUser("test_user", "account_id")
user := server.NewAdminUser("test_user")
return &server.Account{
Id: claims.AccountId,
Users: map[string]*server.User{

View File

@@ -124,7 +124,7 @@ func TestGetGroup(t *testing.T) {
Name: "Group",
}
adminUser := server.NewAdminUser("test_user", "account_id")
adminUser := server.NewAdminUser("test_user")
p := initGroupTestData(adminUser, group)
for _, tc := range tt {
@@ -246,7 +246,7 @@ func TestWriteGroup(t *testing.T) {
},
}
adminUser := server.NewAdminUser("test_user", "account_id")
adminUser := server.NewAdminUser("test_user")
p := initGroupTestData(adminUser)
for _, tc := range tt {
@@ -324,7 +324,7 @@ func TestDeleteGroup(t *testing.T) {
},
}
adminUser := server.NewAdminUser("test_user", "account_id")
adminUser := server.NewAdminUser("test_user")
p := initGroupTestData(adminUser)
for _, tc := range tt {

View File

@@ -12,7 +12,6 @@ import (
s "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/http/middleware"
"github.com/netbirdio/netbird/management/server/integrated_validator"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/telemetry"
)
@@ -39,7 +38,7 @@ type emptyObject struct {
}
// APIHandler creates the Management service HTTP API handler registering all the available endpoints.
func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationManager *geolocation.Geolocation, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg AuthCfg, integratedValidator integrated_validator.IntegratedValidator) (http.Handler, error) {
func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationManager *geolocation.Geolocation, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg AuthCfg) (http.Handler, error) {
claimsExtractor := jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience),
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
@@ -76,7 +75,7 @@ func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationMa
AuthCfg: authCfg,
}
if _, err := integrations.RegisterHandlers(ctx, prefix, api.Router, accountManager, claimsExtractor, integratedValidator); err != nil {
if _, err := integrations.RegisterHandlers(ctx, prefix, api.Router, accountManager, claimsExtractor); err != nil {
return nil, fmt.Errorf("register integrations endpoints: %w", err)
}

View File

@@ -32,7 +32,7 @@ var testingNSAccount = &server.Account{
Id: testNSGroupAccountID,
Domain: "hotmail.com",
Users: map[string]*server.User{
"test_user": server.NewAdminUser("test_user", "account_id"),
"test_user": server.NewAdminUser("test_user"),
},
}

View File

@@ -59,7 +59,7 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
return "netbird.selfhosted"
},
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
user := server.NewAdminUser("test_user", "account_id")
user := server.NewAdminUser("test_user")
return &server.Account{
Id: claims.AccountId,
Domain: "hotmail.com",

View File

@@ -45,7 +45,7 @@ func initPoliciesTestData(policies ...*server.Policy) *Policies {
return nil
},
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
user := server.NewAdminUser("test_user", "account_id")
user := server.NewAdminUser("test_user")
return &server.Account{
Id: claims.AccountId,
Domain: "hotmail.com",

View File

@@ -62,7 +62,7 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH
return accountPostureChecks, nil
},
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
user := server.NewAdminUser("test_user", "account_id")
user := server.NewAdminUser("test_user")
return &server.Account{
Id: claims.AccountId,
Users: map[string]*server.User{

View File

@@ -75,7 +75,7 @@ var testingAccount = &server.Account{
},
},
Users: map[string]*server.User{
"test_user": server.NewAdminUser("test_user", "account_id"),
"test_user": server.NewAdminUser("test_user"),
},
}

View File

@@ -97,7 +97,7 @@ func TestSetupKeysHandlers(t *testing.T) {
defaultSetupKey := server.GenerateDefaultSetupKey()
defaultSetupKey.Id = existingSetupKeyID
adminUser := server.NewAdminUser("test_user", "account_id")
adminUser := server.NewAdminUser("test_user")
newSetupKey := server.GenerateSetupKey(newSetupKeyName, server.SetupKeyReusable, 0, []string{"group-1"},
server.SetupKeyUnlimitedUsage, true)

View File

@@ -95,18 +95,18 @@ func (wm *DefaultScheduler) Schedule(in time.Duration, ID string, job func() (ne
case <-ticker.C:
select {
case <-cancel:
log.Tracef("scheduled job %s was canceled, stop timer", ID)
log.Debugf("scheduled job %s was canceled, stop timer", ID)
ticker.Stop()
return
default:
log.Tracef("time to do a scheduled job %s", ID)
log.Debugf("time to do a scheduled job %s", ID)
}
runIn, reschedule := job()
if !reschedule {
wm.mu.Lock()
defer wm.mu.Unlock()
delete(wm.jobs, ID)
log.Tracef("job %s is not scheduled to run again", ID)
log.Debugf("job %s is not scheduled to run again", ID)
ticker.Stop()
return
}
@@ -115,7 +115,7 @@ func (wm *DefaultScheduler) Schedule(in time.Duration, ID string, job func() (ne
ticker.Reset(runIn)
}
case <-cancel:
log.Tracef("job %s was canceled, stopping timer", ID)
log.Debugf("job %s was canceled, stopping timer", ID)
ticker.Stop()
return
}

View File

@@ -4,7 +4,6 @@ import (
"errors"
"fmt"
"path/filepath"
"reflect"
"runtime"
"strings"
"sync"
@@ -135,152 +134,20 @@ func (s *SqliteStore) AcquireAccountLock(accountID string) (unlock func()) {
return unlock
}
func batchInsert(records interface{}, batchSize int, tx *gorm.DB) error {
// Get the reflect.Value of the records slice
v := reflect.ValueOf(records)
if v.Kind() != reflect.Slice {
return fmt.Errorf("provided input is not a slice")
}
// Insert records in batches
for i := 0; i < v.Len(); i += batchSize {
end := i + batchSize
if end > v.Len() {
end = v.Len()
}
// Use reflect.Slice to get a slice of the records for the current batch
batch := v.Slice(i, end).Interface()
if err := tx.CreateInBatches(batch, end-i).Debug().Error; err != nil {
return err
}
}
return nil
}
func (s *SqliteStore) SaveAccount(account *Account) error {
start := time.Now()
// operate over a fresh copy as we will modify its fields
accCopy := account.Copy()
accCopy.SetupKeysG = make([]SetupKey, 0, len(accCopy.SetupKeys))
for _, key := range accCopy.SetupKeys {
//we need an explicit reference to the account for gorm
key.AccountID = accCopy.Id
accCopy.SetupKeysG = append(accCopy.SetupKeysG, *key)
for _, key := range account.SetupKeys {
account.SetupKeysG = append(account.SetupKeysG, *key)
}
accCopy.PeersG = make([]nbpeer.Peer, 0, len(accCopy.Peers))
for id, peer := range accCopy.Peers {
for id, peer := range account.Peers {
peer.ID = id
//we need an explicit reference to the account for gorm
peer.AccountID = accCopy.Id
accCopy.PeersG = append(accCopy.PeersG, *peer)
account.PeersG = append(account.PeersG, *peer)
}
accCopy.UsersG = make([]User, 0, len(accCopy.Users))
for id, user := range accCopy.Users {
user.Id = id
//we need an explicit reference to the account for gorm
user.AccountID = accCopy.Id
user.PATsG = make([]PersonalAccessToken, 0, len(user.PATs))
for id, pat := range user.PATs {
pat.ID = id
user.PATsG = append(user.PATsG, *pat)
}
accCopy.UsersG = append(accCopy.UsersG, *user)
}
accCopy.GroupsG = make([]nbgroup.Group, 0, len(accCopy.Groups))
for id, group := range accCopy.Groups {
group.ID = id
//we need an explicit reference to the account for gorm
group.AccountID = accCopy.Id
accCopy.GroupsG = append(accCopy.GroupsG, *group)
}
accCopy.RoutesG = make([]route.Route, 0, len(accCopy.Routes))
for id, route := range accCopy.Routes {
route.ID = id
//we need an explicit reference to the account for gorm
route.AccountID = accCopy.Id
accCopy.RoutesG = append(accCopy.RoutesG, *route)
}
accCopy.NameServerGroupsG = make([]nbdns.NameServerGroup, 0, len(accCopy.NameServerGroups))
for id, ns := range accCopy.NameServerGroups {
ns.ID = id
//we need an explicit reference to the account for gorm
ns.AccountID = accCopy.Id
accCopy.NameServerGroupsG = append(accCopy.NameServerGroupsG, *ns)
}
err := s.db.Transaction(func(tx *gorm.DB) error {
result := tx.Select(clause.Associations).Delete(accCopy.Policies, "account_id = ?", accCopy.Id)
if result.Error != nil {
return result.Error
}
result = tx.Select(clause.Associations).Delete(accCopy.UsersG, "account_id = ?", accCopy.Id)
if result.Error != nil {
return result.Error
}
result = tx.Select(clause.Associations).Delete(accCopy)
if result.Error != nil {
return result.Error
}
result = tx.
Session(&gorm.Session{FullSaveAssociations: true}).
Clauses(clause.OnConflict{UpdateAll: true}).
Omit("PeersG", "GroupsG", "UsersG", "SetupKeysG", "RoutesG", "NameServerGroupsG").
Create(accCopy)
if result.Error != nil {
return result.Error
}
const batchSize = 500
err := batchInsert(accCopy.PeersG, batchSize, tx)
if err != nil {
return err
}
err = batchInsert(accCopy.UsersG, batchSize, tx)
if err != nil {
return err
}
err = batchInsert(accCopy.GroupsG, batchSize, tx)
if err != nil {
return err
}
err = batchInsert(accCopy.RoutesG, batchSize, tx)
if err != nil {
return err
}
err = batchInsert(accCopy.SetupKeysG, batchSize, tx)
if err != nil {
return err
}
return batchInsert(accCopy.NameServerGroupsG, batchSize, tx)
})
took := time.Since(start)
if s.metrics != nil {
s.metrics.StoreMetrics().CountPersistenceDuration(took)
}
log.Debugf("took %d ms to persist an account %s to the SQLite store", took.Milliseconds(), accCopy.Id)
return err
}
func (s *SqliteStore) DeleteAccount(account *Account) error {
start := time.Now()
account.UsersG = make([]User, 0, len(account.Users))
for id, user := range account.Users {
user.Id = id
//we need an explicit reference to an account as it is missing for some reason
user.AccountID = account.Id
user.PATsG = make([]PersonalAccessToken, 0, len(user.PATs))
for id, pat := range user.PATs {
pat.ID = id
user.PATsG = append(user.PATsG, *pat)
@@ -288,6 +155,58 @@ func (s *SqliteStore) DeleteAccount(account *Account) error {
account.UsersG = append(account.UsersG, *user)
}
for id, group := range account.Groups {
group.ID = id
account.GroupsG = append(account.GroupsG, *group)
}
for id, route := range account.Routes {
route.ID = id
account.RoutesG = append(account.RoutesG, *route)
}
for id, ns := range account.NameServerGroups {
ns.ID = id
account.NameServerGroupsG = append(account.NameServerGroupsG, *ns)
}
err := s.db.Transaction(func(tx *gorm.DB) error {
result := tx.Select(clause.Associations).Delete(account.Policies, "account_id = ?", account.Id)
if result.Error != nil {
return result.Error
}
result = tx.Select(clause.Associations).Delete(account.UsersG, "account_id = ?", account.Id)
if result.Error != nil {
return result.Error
}
result = tx.Select(clause.Associations).Delete(account)
if result.Error != nil {
return result.Error
}
result = tx.
Session(&gorm.Session{FullSaveAssociations: true}).
Clauses(clause.OnConflict{UpdateAll: true}).Create(account)
if result.Error != nil {
return result.Error
}
return nil
})
took := time.Since(start)
if s.metrics != nil {
s.metrics.StoreMetrics().CountPersistenceDuration(took)
}
log.Debugf("took %d ms to persist an account to the SQLite", took.Milliseconds())
return err
}
func (s *SqliteStore) DeleteAccount(account *Account) error {
start := time.Now()
err := s.db.Transaction(func(tx *gorm.DB) error {
result := tx.Select(clause.Associations).Delete(account.Policies, "account_id = ?", account.Id)
if result.Error != nil {

View File

@@ -2,12 +2,7 @@ package server
import (
"fmt"
nbdns "github.com/netbirdio/netbird/dns"
nbgroup "github.com/netbirdio/netbird/management/server/group"
route2 "github.com/netbirdio/netbird/route"
"math/rand"
"net"
"net/netip"
"path/filepath"
"runtime"
"testing"
@@ -34,141 +29,6 @@ func TestSqlite_NewStore(t *testing.T) {
t.Errorf("expected to create a new empty Accounts map when creating a new FileStore")
}
}
func TestSqlite_SaveAccount_Large(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store := newSqliteStore(t)
account := newAccountWithId("account_id", "testuser", "")
groupALL, err := account.GetGroupAll()
if err != nil {
t.Fatal(err)
}
setupKey := GenerateDefaultSetupKey()
account.SetupKeys[setupKey.Key] = setupKey
const numPerAccount = 2000
for n := 0; n < numPerAccount; n++ {
netIP := randomIPv4()
peerID := fmt.Sprintf("%s-peer-%d", account.Id, n)
peer := &nbpeer.Peer{
ID: peerID,
Key: peerID,
SetupKey: "",
IP: netIP,
Name: peerID,
DNSLabel: peerID,
UserID: userID,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()},
SSHEnabled: false,
}
account.Peers[peerID] = peer
group, _ := account.GetGroupAll()
group.Peers = append(group.Peers, peerID)
user := &User{
Id: fmt.Sprintf("%s-user-%d", account.Id, n),
AccountID: account.Id,
}
account.Users[user.Id] = user
route := &route2.Route{
ID: fmt.Sprintf("network-id-%d", n),
Description: "base route",
NetID: fmt.Sprintf("network-id-%d", n),
Network: netip.MustParsePrefix(netIP.String() + "/24"),
NetworkType: route2.IPv4Network,
Metric: 9999,
Masquerade: false,
Enabled: true,
Groups: []string{groupALL.ID},
}
account.Routes[route.ID] = route
group = &nbgroup.Group{
ID: fmt.Sprintf("group-id-%d", n),
AccountID: account.Id,
Name: fmt.Sprintf("group-id-%d", n),
Issued: "api",
Peers: nil,
}
account.Groups[group.ID] = group
nameserver := &nbdns.NameServerGroup{
ID: fmt.Sprintf("nameserver-id-%d", n),
AccountID: account.Id,
Name: fmt.Sprintf("nameserver-id-%d", n),
Description: "",
NameServers: []nbdns.NameServer{{IP: netip.MustParseAddr(netIP.String()), NSType: nbdns.UDPNameServerType}},
Groups: []string{group.ID},
Primary: false,
Domains: nil,
Enabled: false,
SearchDomainsEnabled: false,
}
account.NameServerGroups[nameserver.ID] = nameserver
setupKey := GenerateDefaultSetupKey()
account.SetupKeys[setupKey.Key] = setupKey
}
err = store.SaveAccount(account)
require.NoError(t, err)
if len(store.GetAllAccounts()) != 1 {
t.Errorf("expecting 1 Accounts to be stored after SaveAccount()")
}
a, err := store.GetAccount(account.Id)
if a == nil {
t.Errorf("expecting Account to be stored after SaveAccount(): %v", err)
}
if a != nil && len(a.Policies) != 1 {
t.Errorf("expecting Account to have one policy stored after SaveAccount(), got %d", len(a.Policies))
}
if a != nil && len(a.Policies[0].Rules) != 1 {
t.Errorf("expecting Account to have one policy rule stored after SaveAccount(), got %d", len(a.Policies[0].Rules))
return
}
if a != nil && len(a.Peers) != numPerAccount {
t.Errorf("expecting Account to have %d peers stored after SaveAccount(), got %d",
numPerAccount, len(a.Peers))
return
}
if a != nil && len(a.Users) != numPerAccount+1 {
t.Errorf("expecting Account to have %d users stored after SaveAccount(), got %d",
numPerAccount+1, len(a.Users))
return
}
if a != nil && len(a.Routes) != numPerAccount {
t.Errorf("expecting Account to have %d routes stored after SaveAccount(), got %d",
numPerAccount, len(a.Routes))
return
}
if a != nil && len(a.NameServerGroups) != numPerAccount {
t.Errorf("expecting Account to have %d NameServerGroups stored after SaveAccount(), got %d",
numPerAccount, len(a.NameServerGroups))
return
}
if a != nil && len(a.NameServerGroups) != numPerAccount {
t.Errorf("expecting Account to have %d NameServerGroups stored after SaveAccount(), got %d",
numPerAccount, len(a.NameServerGroups))
return
}
if a != nil && len(a.SetupKeys) != numPerAccount+1 {
t.Errorf("expecting Account to have %d SetupKeys stored after SaveAccount(), got %d",
numPerAccount+1, len(a.SetupKeys))
return
}
}
func TestSqlite_SaveAccount(t *testing.T) {
if runtime.GOOS == "windows" {
@@ -188,12 +48,6 @@ func TestSqlite_SaveAccount(t *testing.T) {
Name: "peer name",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
}
admin := account.Users["testuser"]
admin.PATs = map[string]*PersonalAccessToken{"testtoken": {
ID: "testtoken",
Name: "test token",
HashedToken: "hashed token",
}}
err := store.SaveAccount(account)
require.NoError(t, err)
@@ -256,7 +110,7 @@ func TestSqlite_DeleteAccount(t *testing.T) {
store := newSqliteStore(t)
testUserID := "testuser"
user := NewAdminUser(testUserID, "account_id")
user := NewAdminUser(testUserID)
user.PATs = map[string]*PersonalAccessToken{"testtoken": {
ID: "testtoken",
Name: "test token",
@@ -539,12 +393,3 @@ func newAccount(store Store, id int) error {
return store.SaveAccount(account)
}
func randomIPv4() net.IP {
rand.New(rand.NewSource(time.Now().UnixNano()))
b := make([]byte, 4)
for i := range b {
b[i] = byte(rand.Intn(256))
}
return net.IP(b)
}

View File

@@ -180,11 +180,9 @@ func (u *User) Copy() *User {
}
// NewUser creates a new user
func NewUser(ID string, role UserRole, isServiceUser bool, nonDeletable bool, serviceUserName string, autoGroups []string, issued string,
accountID string) *User {
func NewUser(id string, role UserRole, isServiceUser bool, nonDeletable bool, serviceUserName string, autoGroups []string, issued string) *User {
return &User{
Id: ID,
AccountID: accountID,
Id: id,
Role: role,
IsServiceUser: isServiceUser,
NonDeletable: nonDeletable,
@@ -196,26 +194,22 @@ func NewUser(ID string, role UserRole, isServiceUser bool, nonDeletable bool, se
}
// NewRegularUser creates a new user with role UserRoleUser
func NewRegularUser(ID, accountID string) *User {
return NewUser(ID, UserRoleUser, false, false, "", []string{}, UserIssuedAPI,
accountID)
func NewRegularUser(id string) *User {
return NewUser(id, UserRoleUser, false, false, "", []string{}, UserIssuedAPI)
}
// NewAdminUser creates a new user with role UserRoleAdmin
func NewAdminUser(ID, accountID string) *User {
return NewUser(ID, UserRoleAdmin, false, false, "", []string{}, UserIssuedAPI,
accountID)
func NewAdminUser(id string) *User {
return NewUser(id, UserRoleAdmin, false, false, "", []string{}, UserIssuedAPI)
}
// NewOwnerUser creates a new user with role UserRoleOwner
func NewOwnerUser(ID, accountID string) *User {
return NewUser(ID, UserRoleOwner, false, false, "", []string{}, UserIssuedAPI,
accountID)
func NewOwnerUser(id string) *User {
return NewUser(id, UserRoleOwner, false, false, "", []string{}, UserIssuedAPI)
}
// createServiceUser creates a new service user under the given account.
func (am *DefaultAccountManager) createServiceUser(accountID string, initiatorUserID string, role UserRole,
serviceUserName string, nonDeletable bool, autoGroups []string) (*UserInfo, error) {
func (am *DefaultAccountManager) createServiceUser(accountID string, initiatorUserID string, role UserRole, serviceUserName string, nonDeletable bool, autoGroups []string) (*UserInfo, error) {
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
@@ -237,7 +231,7 @@ func (am *DefaultAccountManager) createServiceUser(accountID string, initiatorUs
}
newUserID := uuid.New().String()
newUser := NewUser(newUserID, role, true, nonDeletable, serviceUserName, autoGroups, UserIssuedAPI, accountID)
newUser := NewUser(newUserID, role, true, nonDeletable, serviceUserName, autoGroups, UserIssuedAPI)
log.Debugf("New User: %v", newUser)
account.Users[newUserID] = newUser

View File

@@ -679,8 +679,8 @@ func TestDefaultAccountManager_GetUser(t *testing.T) {
func TestDefaultAccountManager_ListUsers(t *testing.T) {
store := newStore(t)
account := newAccountWithId(mockAccountID, mockUserID, "")
account.Users["normal_user1"] = NewRegularUser("normal_user1", mockAccountID)
account.Users["normal_user2"] = NewRegularUser("normal_user2", mockAccountID)
account.Users["normal_user1"] = NewRegularUser("normal_user1")
account.Users["normal_user2"] = NewRegularUser("normal_user2")
err := store.SaveAccount(account)
if err != nil {
@@ -760,7 +760,7 @@ func TestDefaultAccountManager_ListUsers_DashboardPermissions(t *testing.T) {
t.Run(testCase.name, func(t *testing.T) {
store := newStore(t)
account := newAccountWithId(mockAccountID, mockUserID, "")
account.Users["normal_user1"] = NewUser("normal_user1", testCase.role, false, false, "", []string{}, UserIssuedAPI, mockAccountID)
account.Users["normal_user1"] = NewUser("normal_user1", testCase.role, false, false, "", []string{}, UserIssuedAPI)
account.Settings.RegularUsersViewBlocked = testCase.limitedViewSettings
delete(account.Users, mockUserID)
@@ -844,10 +844,10 @@ func TestDefaultAccountManager_ExternalCache(t *testing.T) {
func TestUser_IsAdmin(t *testing.T) {
user := NewAdminUser(mockUserID, mockAccountID)
user := NewAdminUser(mockUserID)
assert.True(t, user.HasAdminPower())
user = NewRegularUser(mockUserID, mockAccountID)
user = NewRegularUser(mockUserID)
assert.False(t, user.HasAdminPower())
}
@@ -1055,8 +1055,8 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) {
}
// create other users
account.Users[regularUserID] = NewRegularUser(regularUserID, account.Id)
account.Users[adminUserID] = NewAdminUser(adminUserID, account.Id)
account.Users[regularUserID] = NewRegularUser(regularUserID)
account.Users[adminUserID] = NewAdminUser(adminUserID)
account.Users[serviceUserID] = &User{IsServiceUser: true, Id: serviceUserID, Role: UserRoleAdmin, ServiceUserName: "service"}
err = manager.Store.SaveAccount(account)
if err != nil {

View File

@@ -56,7 +56,7 @@ func UnMarshalCredential(msg *proto.Message) (*Credential, error) {
// MarshalCredential marshal a Credential instance and returns a Message object
func MarshalCredential(myKey wgtypes.Key, myPort int, remoteKey wgtypes.Key, credential *Credential, t proto.Body_Type,
rosenpassPubKey []byte, rosenpassAddr string) (*proto.Message, error) {
rosenpassPubKey []byte, rosenpassAddr, relayedAddress, serverRefIP string) (*proto.Message, error) {
return &proto.Message{
Key: myKey.PublicKey().String(),
RemoteKey: remoteKey.String(),
@@ -69,6 +69,10 @@ func MarshalCredential(myKey wgtypes.Key, myPort int, remoteKey wgtypes.Key, cre
RosenpassPubKey: rosenpassPubKey,
RosenpassServerAddr: rosenpassAddr,
},
Relay: &proto.Relay{
RelayedAddress: relayedAddress,
SrvRefAddress: serverRefIP,
},
},
}, nil
}

View File

@@ -215,16 +215,21 @@ type Body struct {
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Type Body_Type `protobuf:"varint,1,opt,name=type,proto3,enum=signalexchange.Body_Type" json:"type,omitempty"`
Payload string `protobuf:"bytes,2,opt,name=payload,proto3" json:"payload,omitempty"`
Type Body_Type `protobuf:"varint,1,opt,name=type,proto3,enum=signalexchange.Body_Type" json:"type,omitempty"`
// these will be set in OFFER, ANSWER, CANDIDATE only
Payload string `protobuf:"bytes,2,opt,name=payload,proto3" json:"payload,omitempty"`
// wgListenPort is an actual WireGuard listen port
WgListenPort uint32 `protobuf:"varint,3,opt,name=wgListenPort,proto3" json:"wgListenPort,omitempty"`
// these will be set in OFFER, ANSWER, CANDIDATE only
WgListenPort uint32 `protobuf:"varint,3,opt,name=wgListenPort,proto3" json:"wgListenPort,omitempty"`
// these will be set in OFFER, ANSWER, CANDIDATE only
NetBirdVersion string `protobuf:"bytes,4,opt,name=netBirdVersion,proto3" json:"netBirdVersion,omitempty"`
Mode *Mode `protobuf:"bytes,5,opt,name=mode,proto3" json:"mode,omitempty"`
// featuresSupported list of supported features by the client of this protocol
FeaturesSupported []uint32 `protobuf:"varint,6,rep,packed,name=featuresSupported,proto3" json:"featuresSupported,omitempty"`
// RosenpassConfig is a Rosenpass config of the remote peer our peer tries to connect to
// is this optional or mandatory?
RosenpassConfig *RosenpassConfig `protobuf:"bytes,7,opt,name=rosenpassConfig,proto3" json:"rosenpassConfig,omitempty"`
Relay *Relay `protobuf:"bytes,8,opt,name=relay,proto3" json:"relay,omitempty"`
}
func (x *Body) Reset() {
@@ -308,13 +313,18 @@ func (x *Body) GetRosenpassConfig() *RosenpassConfig {
return nil
}
func (x *Body) GetRelay() *Relay {
if x != nil {
return x.Relay
}
return nil
}
// Mode indicates a connection mode
type Mode struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Direct *bool `protobuf:"varint,1,opt,name=direct,proto3,oneof" json:"direct,omitempty"`
}
func (x *Mode) Reset() {
@@ -349,11 +359,59 @@ func (*Mode) Descriptor() ([]byte, []int) {
return file_signalexchange_proto_rawDescGZIP(), []int{3}
}
func (x *Mode) GetDirect() bool {
if x != nil && x.Direct != nil {
return *x.Direct
type Relay struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
RelayedAddress string `protobuf:"bytes,1,opt,name=relayedAddress,proto3" json:"relayedAddress,omitempty"`
SrvRefAddress string `protobuf:"bytes,2,opt,name=srvRefAddress,proto3" json:"srvRefAddress,omitempty"`
}
func (x *Relay) Reset() {
*x = Relay{}
if protoimpl.UnsafeEnabled {
mi := &file_signalexchange_proto_msgTypes[4]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
return false
}
func (x *Relay) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*Relay) ProtoMessage() {}
func (x *Relay) ProtoReflect() protoreflect.Message {
mi := &file_signalexchange_proto_msgTypes[4]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use Relay.ProtoReflect.Descriptor instead.
func (*Relay) Descriptor() ([]byte, []int) {
return file_signalexchange_proto_rawDescGZIP(), []int{4}
}
func (x *Relay) GetRelayedAddress() string {
if x != nil {
return x.RelayedAddress
}
return ""
}
func (x *Relay) GetSrvRefAddress() string {
if x != nil {
return x.SrvRefAddress
}
return ""
}
type RosenpassConfig struct {
@@ -369,7 +427,7 @@ type RosenpassConfig struct {
func (x *RosenpassConfig) Reset() {
*x = RosenpassConfig{}
if protoimpl.UnsafeEnabled {
mi := &file_signalexchange_proto_msgTypes[4]
mi := &file_signalexchange_proto_msgTypes[5]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -382,7 +440,7 @@ func (x *RosenpassConfig) String() string {
func (*RosenpassConfig) ProtoMessage() {}
func (x *RosenpassConfig) ProtoReflect() protoreflect.Message {
mi := &file_signalexchange_proto_msgTypes[4]
mi := &file_signalexchange_proto_msgTypes[5]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -395,7 +453,7 @@ func (x *RosenpassConfig) ProtoReflect() protoreflect.Message {
// Deprecated: Use RosenpassConfig.ProtoReflect.Descriptor instead.
func (*RosenpassConfig) Descriptor() ([]byte, []int) {
return file_signalexchange_proto_rawDescGZIP(), []int{4}
return file_signalexchange_proto_rawDescGZIP(), []int{5}
}
func (x *RosenpassConfig) GetRosenpassPubKey() []byte {
@@ -431,7 +489,7 @@ var file_signalexchange_proto_rawDesc = []byte{
0x52, 0x09, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x4b, 0x65, 0x79, 0x12, 0x28, 0x0a, 0x04, 0x62,
0x6f, 0x64, 0x79, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x73, 0x69, 0x67, 0x6e,
0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x42, 0x6f, 0x64, 0x79, 0x52,
0x04, 0x62, 0x6f, 0x64, 0x79, 0x22, 0xf6, 0x02, 0x0a, 0x04, 0x42, 0x6f, 0x64, 0x79, 0x12, 0x2d,
0x04, 0x62, 0x6f, 0x64, 0x79, 0x22, 0xa3, 0x03, 0x0a, 0x04, 0x42, 0x6f, 0x64, 0x79, 0x12, 0x2d,
0x0a, 0x04, 0x74, 0x79, 0x70, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x19, 0x2e, 0x73,
0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x42, 0x6f,
0x64, 0x79, 0x2e, 0x54, 0x79, 0x70, 0x65, 0x52, 0x04, 0x74, 0x79, 0x70, 0x65, 0x12, 0x18, 0x0a,
@@ -451,33 +509,39 @@ var file_signalexchange_proto_rawDesc = []byte{
0x20, 0x01, 0x28, 0x0b, 0x32, 0x1f, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63,
0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x52, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x43,
0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73,
0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0x36, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x09,
0x0a, 0x05, 0x4f, 0x46, 0x46, 0x45, 0x52, 0x10, 0x00, 0x12, 0x0a, 0x0a, 0x06, 0x41, 0x4e, 0x53,
0x57, 0x45, 0x52, 0x10, 0x01, 0x12, 0x0d, 0x0a, 0x09, 0x43, 0x41, 0x4e, 0x44, 0x49, 0x44, 0x41,
0x54, 0x45, 0x10, 0x02, 0x12, 0x08, 0x0a, 0x04, 0x4d, 0x4f, 0x44, 0x45, 0x10, 0x04, 0x22, 0x2e,
0x0a, 0x04, 0x4d, 0x6f, 0x64, 0x65, 0x12, 0x1b, 0x0a, 0x06, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74,
0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x48, 0x00, 0x52, 0x06, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74,
0x88, 0x01, 0x01, 0x42, 0x09, 0x0a, 0x07, 0x5f, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x22, 0x6d,
0x0a, 0x0f, 0x52, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x43, 0x6f, 0x6e, 0x66, 0x69,
0x67, 0x12, 0x28, 0x0a, 0x0f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x75,
0x62, 0x4b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0f, 0x72, 0x6f, 0x73, 0x65,
0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x30, 0x0a, 0x13, 0x72,
0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64,
0x64, 0x72, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70,
0x61, 0x73, 0x73, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64, 0x64, 0x72, 0x32, 0xb9, 0x01,
0x0a, 0x0e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x45, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65,
0x12, 0x4c, 0x0a, 0x04, 0x53, 0x65, 0x6e, 0x64, 0x12, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61,
0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70,
0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x20, 0x2e, 0x73, 0x69, 0x67,
0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72,
0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x59,
0x0a, 0x0d, 0x43, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x12,
0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65,
0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67,
0x65, 0x1a, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e,
0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73,
0x61, 0x67, 0x65, 0x22, 0x00, 0x28, 0x01, 0x30, 0x01, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72,
0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x2b, 0x0a, 0x05, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x18,
0x08, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78,
0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x52, 0x65, 0x6c, 0x61, 0x79, 0x52, 0x05, 0x72, 0x65,
0x6c, 0x61, 0x79, 0x22, 0x36, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x09, 0x0a, 0x05, 0x4f,
0x46, 0x46, 0x45, 0x52, 0x10, 0x00, 0x12, 0x0a, 0x0a, 0x06, 0x41, 0x4e, 0x53, 0x57, 0x45, 0x52,
0x10, 0x01, 0x12, 0x0d, 0x0a, 0x09, 0x43, 0x41, 0x4e, 0x44, 0x49, 0x44, 0x41, 0x54, 0x45, 0x10,
0x02, 0x12, 0x08, 0x0a, 0x04, 0x4d, 0x4f, 0x44, 0x45, 0x10, 0x04, 0x22, 0x06, 0x0a, 0x04, 0x4d,
0x6f, 0x64, 0x65, 0x22, 0x55, 0x0a, 0x05, 0x52, 0x65, 0x6c, 0x61, 0x79, 0x12, 0x26, 0x0a, 0x0e,
0x72, 0x65, 0x6c, 0x61, 0x79, 0x65, 0x64, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x01,
0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x65, 0x64, 0x41, 0x64, 0x64,
0x72, 0x65, 0x73, 0x73, 0x12, 0x24, 0x0a, 0x0d, 0x73, 0x72, 0x76, 0x52, 0x65, 0x66, 0x41, 0x64,
0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x73, 0x72, 0x76,
0x52, 0x65, 0x66, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x22, 0x6d, 0x0a, 0x0f, 0x52, 0x6f,
0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x28, 0x0a,
0x0f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79,
0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73,
0x73, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x30, 0x0a, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e,
0x70, 0x61, 0x73, 0x73, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64, 0x64, 0x72, 0x18, 0x02,
0x20, 0x01, 0x28, 0x09, 0x52, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x53,
0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64, 0x64, 0x72, 0x32, 0xb9, 0x01, 0x0a, 0x0e, 0x53, 0x69,
0x67, 0x6e, 0x61, 0x6c, 0x45, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x12, 0x4c, 0x0a, 0x04,
0x53, 0x65, 0x6e, 0x64, 0x12, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63,
0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d,
0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65,
0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65,
0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x59, 0x0a, 0x0d, 0x43, 0x6f,
0x6e, 0x6e, 0x65, 0x63, 0x74, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x12, 0x20, 0x2e, 0x73, 0x69,
0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63,
0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x20, 0x2e,
0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45,
0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22,
0x00, 0x28, 0x01, 0x30, 0x01, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62,
0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}
var (
@@ -493,29 +557,31 @@ func file_signalexchange_proto_rawDescGZIP() []byte {
}
var file_signalexchange_proto_enumTypes = make([]protoimpl.EnumInfo, 1)
var file_signalexchange_proto_msgTypes = make([]protoimpl.MessageInfo, 5)
var file_signalexchange_proto_msgTypes = make([]protoimpl.MessageInfo, 6)
var file_signalexchange_proto_goTypes = []interface{}{
(Body_Type)(0), // 0: signalexchange.Body.Type
(*EncryptedMessage)(nil), // 1: signalexchange.EncryptedMessage
(*Message)(nil), // 2: signalexchange.Message
(*Body)(nil), // 3: signalexchange.Body
(*Mode)(nil), // 4: signalexchange.Mode
(*RosenpassConfig)(nil), // 5: signalexchange.RosenpassConfig
(*Relay)(nil), // 5: signalexchange.Relay
(*RosenpassConfig)(nil), // 6: signalexchange.RosenpassConfig
}
var file_signalexchange_proto_depIdxs = []int32{
3, // 0: signalexchange.Message.body:type_name -> signalexchange.Body
0, // 1: signalexchange.Body.type:type_name -> signalexchange.Body.Type
4, // 2: signalexchange.Body.mode:type_name -> signalexchange.Mode
5, // 3: signalexchange.Body.rosenpassConfig:type_name -> signalexchange.RosenpassConfig
1, // 4: signalexchange.SignalExchange.Send:input_type -> signalexchange.EncryptedMessage
1, // 5: signalexchange.SignalExchange.ConnectStream:input_type -> signalexchange.EncryptedMessage
1, // 6: signalexchange.SignalExchange.Send:output_type -> signalexchange.EncryptedMessage
1, // 7: signalexchange.SignalExchange.ConnectStream:output_type -> signalexchange.EncryptedMessage
6, // [6:8] is the sub-list for method output_type
4, // [4:6] is the sub-list for method input_type
4, // [4:4] is the sub-list for extension type_name
4, // [4:4] is the sub-list for extension extendee
0, // [0:4] is the sub-list for field type_name
6, // 3: signalexchange.Body.rosenpassConfig:type_name -> signalexchange.RosenpassConfig
5, // 4: signalexchange.Body.relay:type_name -> signalexchange.Relay
1, // 5: signalexchange.SignalExchange.Send:input_type -> signalexchange.EncryptedMessage
1, // 6: signalexchange.SignalExchange.ConnectStream:input_type -> signalexchange.EncryptedMessage
1, // 7: signalexchange.SignalExchange.Send:output_type -> signalexchange.EncryptedMessage
1, // 8: signalexchange.SignalExchange.ConnectStream:output_type -> signalexchange.EncryptedMessage
7, // [7:9] is the sub-list for method output_type
5, // [5:7] is the sub-list for method input_type
5, // [5:5] is the sub-list for extension type_name
5, // [5:5] is the sub-list for extension extendee
0, // [0:5] is the sub-list for field type_name
}
func init() { file_signalexchange_proto_init() }
@@ -573,6 +639,18 @@ func file_signalexchange_proto_init() {
}
}
file_signalexchange_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*Relay); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_signalexchange_proto_msgTypes[5].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*RosenpassConfig); i {
case 0:
return &v.state
@@ -585,14 +663,13 @@ func file_signalexchange_proto_init() {
}
}
}
file_signalexchange_proto_msgTypes[3].OneofWrappers = []interface{}{}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_signalexchange_proto_rawDesc,
NumEnums: 1,
NumMessages: 5,
NumMessages: 6,
NumExtensions: 0,
NumServices: 1,
},

View File

@@ -49,22 +49,33 @@ message Body {
MODE = 4;
}
Type type = 1;
// these will be set in OFFER, ANSWER, CANDIDATE only
string payload = 2;
// wgListenPort is an actual WireGuard listen port
// these will be set in OFFER, ANSWER, CANDIDATE only
uint32 wgListenPort = 3;
// these will be set in OFFER, ANSWER, CANDIDATE only
string netBirdVersion = 4;
Mode mode = 5;
// featuresSupported list of supported features by the client of this protocol
repeated uint32 featuresSupported = 6;
// RosenpassConfig is a Rosenpass config of the remote peer our peer tries to connect to
// is this optional or mandatory?
RosenpassConfig rosenpassConfig = 7;
Relay relay = 8;
}
// Mode indicates a connection mode
message Mode {
optional bool direct = 1;
}
message Relay {
string relayedAddress = 1;
string srvRefAddress = 2;
}
message RosenpassConfig {

View File

@@ -49,10 +49,6 @@ func RemoveDialerHooks() {
// DialContext wraps the net.Dialer's DialContext method to use the custom connection
func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
if CustomRoutingDisabled() {
return d.Dialer.DialContext(ctx, network, address)
}
var resolver *net.Resolver
if d.Resolver != nil {
resolver = d.Resolver
@@ -127,10 +123,6 @@ func callDialerHooks(ctx context.Context, connID ConnectionID, address string, r
}
func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) {
if CustomRoutingDisabled() {
return net.DialUDP(network, laddr, raddr)
}
dialer := NewDialer()
dialer.LocalAddr = laddr
@@ -151,10 +143,6 @@ func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) {
}
func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) {
if CustomRoutingDisabled() {
return net.DialTCP(network, laddr, raddr)
}
dialer := NewDialer()
dialer.LocalAddr = laddr

View File

@@ -8,7 +8,6 @@ import (
"net"
"sync"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
)
@@ -53,10 +52,6 @@ func RemoveListenerHooks() {
// ListenPacket listens on the network address and returns a PacketConn
// which includes support for write hooks.
func (l *ListenerConfig) ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) {
if CustomRoutingDisabled() {
return l.ListenConfig.ListenPacket(ctx, network, address)
}
pc, err := l.ListenConfig.ListenPacket(ctx, network, address)
if err != nil {
return nil, fmt.Errorf("listen packet: %w", err)
@@ -149,11 +144,7 @@ func closeConn(id ConnectionID, conn net.PacketConn) error {
// ListenUDP listens on the network address and returns a transport.UDPConn
// which includes support for write and close hooks.
func ListenUDP(network string, laddr *net.UDPAddr) (transport.UDPConn, error) {
if CustomRoutingDisabled() {
return net.ListenUDP(network, laddr)
}
func ListenUDP(network string, laddr *net.UDPAddr) (*UDPConn, error) {
conn, err := NewListener().ListenPacket(context.Background(), network, laddr.String())
if err != nil {
return nil, fmt.Errorf("listen UDP: %w", err)

View File

@@ -1,16 +1,10 @@
package net
import (
"os"
"github.com/google/uuid"
)
import "github.com/google/uuid"
const (
// NetbirdFwmark is the fwmark value used by Netbird via wireguard
NetbirdFwmark = 0x1BD00
envDisableCustomRouting = "NB_DISABLE_CUSTOM_ROUTING"
)
// ConnectionID provides a globally unique identifier for network connections.
@@ -21,7 +15,3 @@ type ConnectionID string
func GenerateConnID() ConnectionID {
return ConnectionID(uuid.NewString())
}
func CustomRoutingDisabled() bool {
return os.Getenv(envDisableCustomRouting) == "true"
}

View File

@@ -21,7 +21,7 @@ func SetRawSocketMark(conn syscall.RawConn) error {
var setErr error
err := conn.Control(func(fd uintptr) {
setErr = SetSocketOpt(int(fd))
setErr = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_MARK, NetbirdFwmark)
})
if err != nil {
return fmt.Errorf("control: %w", err)
@@ -33,11 +33,3 @@ func SetRawSocketMark(conn syscall.RawConn) error {
return nil
}
func SetSocketOpt(fd int) error {
if CustomRoutingDisabled() {
return nil
}
return syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_MARK, NetbirdFwmark)
}