Compare commits

..

29 Commits

Author SHA1 Message Date
Misha Bragin
515ce9e3af Update management/server/sqlite_store.go 2024-04-17 20:55:32 +02:00
Misha Bragin
89383b7f01 Update management/server/sqlite_store.go 2024-04-17 20:55:01 +02:00
Misha Bragin
db34162733 Update management/server/sqlite_store.go 2024-04-17 20:54:14 +02:00
Misha Bragin
bd761e2177 Update management/server/sqlite_store.go 2024-04-17 20:53:32 +02:00
Misha Bragin
4e1b95a4c6 Update management/server/sqlite_store.go 2024-04-17 20:53:24 +02:00
Misha Bragin
05993af7bf Update management/server/sqlite_store.go 2024-04-17 20:53:11 +02:00
braginini
9d1cb00570 Fix setup keys test 2024-04-17 20:27:55 +02:00
braginini
543731df45 Fix setup keys test 2024-04-17 19:58:24 +02:00
braginini
e6628ec231 Fix setup keys 2024-04-17 19:48:09 +02:00
braginini
41d4dd2aff reduce log level of scheduler to trace 2024-04-17 19:34:59 +02:00
braginini
30bed57711 Fix account deletion 2024-04-17 19:12:53 +02:00
braginini
6960b68322 Add pats to test save account 2024-04-17 19:07:17 +02:00
braginini
3b3aa18148 Store setup keys and ns groups in a batch 2024-04-17 18:32:13 +02:00
braginini
93045f3e3a Fix rand lint issue 2024-04-17 18:07:02 +02:00
braginini
fd3c1dea8e Add save large account test 2024-04-17 18:02:10 +02:00
braginini
48aff7a26e Fix test compilation errors 2024-04-17 17:39:28 +02:00
braginini
83dfe8e3a3 Fix test compilation errors 2024-04-17 17:27:23 +02:00
braginini
38e10af2d9 Add accountID reference 2024-04-17 17:16:56 +02:00
braginini
99854a126a Add comments 2024-04-17 17:08:01 +02:00
braginini
a75f982fcd Copy account when storing to avoid reference issues 2024-04-17 17:03:21 +02:00
braginini
e7a6483912 Optimize all other objects storing in SQLite 2024-04-17 12:35:41 +02:00
braginini
30ede299b8 Optimize peer storing in SQLite 2024-04-17 11:50:33 +02:00
Viktor Liu
e3b76448f3 Fix ICE endpoint remote port in status command (#1851) 2024-04-16 14:01:59 +02:00
Viktor Liu
e0de86d6c9 Use fixed activity codes (#1846)
* Add duplicate constants check
2024-04-15 14:15:46 +02:00
Zoltan Papp
5204d07811 Pass integrated validator for API (#1814)
Pass integrated validator for API handler
2024-04-15 12:08:38 +02:00
Viktor Liu
5ea24ba56e Add sysctl opts to prevent reverse path filtering from dropping fwmark packets (#1839) 2024-04-12 17:53:07 +02:00
Viktor Liu
d30cf8706a Allow disabling custom routing (#1840) 2024-04-12 16:53:11 +02:00
Viktor Liu
15a2feb723 Use fixed preference for rules (#1836) 2024-04-12 16:07:03 +02:00
Viktor Liu
91b2f9fc51 Use route active store (#1834) 2024-04-12 15:22:40 +02:00
52 changed files with 765 additions and 1120 deletions

View File

@@ -33,6 +33,10 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v3
- name: Check for duplicate constants
if: matrix.os == 'ubuntu-latest'
run: |
! awk '/const \(/,/)/{print $0}' management/server/activity/codes.go | grep -o '= [0-9]*' | sort | uniq -d | grep .
- name: Install Go
uses: actions/setup-go@v4
with:

View File

@@ -143,7 +143,7 @@ func (m *Manager) AllowNetbird() error {
}
if rule := m.detectAllowNetbirdRule(rules); rule != nil {
log.Debugf("allow netbird rule already exists: %#v", rule)
log.Debugf("allow netbird rule already exists: %v", rule)
return nil
}

View File

@@ -138,7 +138,6 @@ type Engine struct {
signalProbe *Probe
relayProbe *Probe
wgProbe *Probe
turnRelay *relay.PermanentTurn
}
// Peer is an instance of the Connection Peer
@@ -200,7 +199,7 @@ func NewEngineWithProbes(
networkSerial: 0,
sshServerFunc: nbssh.DefaultSSHServer,
statusRecorder: statusRecorder,
wgProxyFactory: &wgproxy.Factory{},
wgProxyFactory: wgproxy.NewFactory(config.WgPort),
mgmProbe: mgmProbe,
signalProbe: signalProbe,
relayProbe: relayProbe,
@@ -453,19 +452,10 @@ func SignalOfferAnswer(offerAnswer peer.OfferAnswer, myKey wgtypes.Key, remoteKe
t = sProto.Body_OFFER
}
msg, err := signal.MarshalCredential(
myKey,
offerAnswer.WgListenPort,
remoteKey, &signal.Credential{
UFrag: offerAnswer.IceCredentials.UFrag,
Pwd: offerAnswer.IceCredentials.Pwd,
},
t,
offerAnswer.RosenpassPubKey,
offerAnswer.RosenpassAddr,
offerAnswer.RelayedAddr.String(),
offerAnswer.RemoteAddr.String(),
)
msg, err := signal.MarshalCredential(myKey, offerAnswer.WgListenPort, remoteKey, &signal.Credential{
UFrag: offerAnswer.IceCredentials.UFrag,
Pwd: offerAnswer.IceCredentials.Pwd,
}, t, offerAnswer.RosenpassPubKey, offerAnswer.RosenpassAddr)
if err != nil {
return err
}
@@ -493,14 +483,6 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
return err
}
turnRelay := relay.NewPermanentTurn(e.STUNs[0], e.TURNs[0])
err = turnRelay.Open()
if err != nil {
return fmt.Errorf("faile to open turn relay: %w", err)
}
e.turnRelay = turnRelay
e.wgInterface.SetRelayConn(e.turnRelay.RelayConn())
// todo update signal
}
@@ -639,7 +621,6 @@ func (e *Engine) updateTURNs(turns []*mgmProto.ProtectedHostConfig) error {
var newTURNs []*stun.URI
log.Debugf("got TURNs update from Management Service, updating")
for _, turn := range turns {
log.Debugf("-----updated Turn %v, %s, %s", turn.HostConfig.Uri, turn.User, turn.Password)
url, err := stun.ParseURI(turn.HostConfig.Uri)
if err != nil {
return err
@@ -649,6 +630,7 @@ func (e *Engine) updateTURNs(turns []*mgmProto.ProtectedHostConfig) error {
newTURNs = append(newTURNs, url)
}
e.TURNs = newTURNs
return nil
}
@@ -952,7 +934,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, e
RosenpassAddr: e.getRosenpassAddr(),
}
peerConn, err := peer.NewConn(config, e.statusRecorder, e.wgProxyFactory, e.mobileDep.TunAdapter, e.mobileDep.IFaceDiscover, e.turnRelay)
peerConn, err := peer.NewConn(config, e.statusRecorder, e.wgProxyFactory, e.mobileDep.TunAdapter, e.mobileDep.IFaceDiscover)
if err != nil {
return nil, err
}
@@ -1018,17 +1000,6 @@ func (e *Engine) receiveSignalEvents() {
rosenpassPubKey = msg.GetBody().GetRosenpassConfig().GetRosenpassPubKey()
rosenpassAddr = msg.GetBody().GetRosenpassConfig().GetRosenpassServerAddr()
}
relayedAddr, err := net.ResolveUDPAddr("udp", msg.GetBody().GetRelay().GetRelayedAddress())
if err != nil {
return err
}
remoteAddr, err := net.ResolveUDPAddr("udp", msg.GetBody().GetRelay().GetSrvRefAddress())
if err != nil {
return err
}
conn.OnRemoteOffer(peer.OfferAnswer{
IceCredentials: peer.IceCredentials{
UFrag: remoteCred.UFrag,
@@ -1038,8 +1009,6 @@ func (e *Engine) receiveSignalEvents() {
Version: msg.GetBody().GetNetBirdVersion(),
RosenpassPubKey: rosenpassPubKey,
RosenpassAddr: rosenpassAddr,
RelayedAddr: relayedAddr,
RemoteAddr: remoteAddr,
})
case sProto.Body_ANSWER:
remoteCred, err := signal.UnMarshalCredential(msg)
@@ -1055,17 +1024,6 @@ func (e *Engine) receiveSignalEvents() {
rosenpassPubKey = msg.GetBody().GetRosenpassConfig().GetRosenpassPubKey()
rosenpassAddr = msg.GetBody().GetRosenpassConfig().GetRosenpassServerAddr()
}
relayedAddr, err := net.ResolveUDPAddr("udp", msg.GetBody().GetRelay().GetRelayedAddress())
if err != nil {
return err
}
remoteAddr, err := net.ResolveUDPAddr("udp", msg.GetBody().GetRelay().GetSrvRefAddress())
if err != nil {
return err
}
conn.OnRemoteAnswer(peer.OfferAnswer{
IceCredentials: peer.IceCredentials{
UFrag: remoteCred.UFrag,
@@ -1075,8 +1033,6 @@ func (e *Engine) receiveSignalEvents() {
Version: msg.GetBody().GetNetBirdVersion(),
RosenpassPubKey: rosenpassPubKey,
RosenpassAddr: rosenpassAddr,
RelayedAddr: relayedAddr,
RemoteAddr: remoteAddr,
})
case sProto.Body_CANDIDATE:
candidate, err := ice.UnmarshalCandidate(msg.GetBody().Payload)
@@ -1087,6 +1043,7 @@ func (e *Engine) receiveSignalEvents() {
conn.OnRemoteCandidate(candidate)
case sProto.Body_MODE:
}
return nil
})
if err != nil {
@@ -1158,8 +1115,6 @@ func (e *Engine) close() {
log.Errorf("failed closing ebpf proxy: %s", err)
}
e.turnRelay.Close()
// stop/restore DNS first so dbus and friends don't complain because of a missing interface
if e.dnsServer != nil {
e.dnsServer.Stop()

View File

@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"net"
"runtime"
"strings"
"sync"
"time"
@@ -13,7 +14,6 @@ import (
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/internal/relay"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/internal/wgproxy"
"github.com/netbirdio/netbird/iface"
@@ -93,10 +93,6 @@ type OfferAnswer struct {
// RosenpassAddr is the Rosenpass server address (IP:port) of the remote peer when receiving this message
// This value is the local Rosenpass server address when sending the message
RosenpassAddr string
// Turn Relay
RelayedAddr net.Addr
RemoteAddr net.Addr
}
// IceCredentials ICE protocol credentials struct
@@ -145,11 +141,11 @@ type Conn struct {
sentExtraSrflx bool
remoteEndpoint *net.UDPAddr
remoteConn *ice.Conn
connID nbnet.ConnectionID
beforeAddPeerHooks []BeforeAddPeerHookFunc
afterRemovePeerHooks []AfterRemovePeerHookFunc
turnRelay *relay.PermanentTurn
}
// meta holds meta information about a connection
@@ -180,7 +176,7 @@ func (conn *Conn) UpdateStunTurn(turnStun []*stun.URI) {
// NewConn creates a new not opened Conn to the remote peer.
// To establish a connection run Conn.Open
func NewConn(config ConnConfig, statusRecorder *Status, wgProxyFactory *wgproxy.Factory, adapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover, turnRelay *relay.PermanentTurn) (*Conn, error) {
func NewConn(config ConnConfig, statusRecorder *Status, wgProxyFactory *wgproxy.Factory, adapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover) (*Conn, error) {
return &Conn{
config: config,
mu: sync.Mutex{},
@@ -193,7 +189,6 @@ func NewConn(config ConnConfig, statusRecorder *Status, wgProxyFactory *wgproxy.
wgProxyFactory: wgProxyFactory,
adapter: adapter,
iFaceDiscover: iFaceDiscover,
turnRelay: turnRelay,
}, nil
}
@@ -217,7 +212,7 @@ func (conn *Conn) reCreateAgent() error {
MulticastDNSMode: ice.MulticastDNSModeDisabled,
NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6},
Urls: conn.config.StunTurn,
CandidateTypes: []ice.CandidateType{},
CandidateTypes: conn.candidateTypes(),
FailedTimeout: &failedTimeout,
InterfaceFilter: stdnet.InterfaceFilter(conn.config.InterfaceBlackList),
UDPMux: conn.config.UDPMux,
@@ -267,6 +262,17 @@ func (conn *Conn) reCreateAgent() error {
return nil
}
func (conn *Conn) candidateTypes() []ice.CandidateType {
if hasICEForceRelayConn() {
return []ice.CandidateType{ice.CandidateTypeRelay}
}
// TODO: remove this once we have refactored userspace proxy into the bind package
if runtime.GOOS == "ios" {
return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive}
}
return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive, ice.CandidateTypeRelay}
}
// Open opens connection to the remote peer starting ICE candidate gathering process.
// Blocks until connection has been closed or connection timeout.
// ConnStatus will be set accordingly
@@ -345,53 +351,42 @@ func (conn *Conn) Open() error {
log.Warnf("error while updating the state of peer %s,err: %v", conn.config.Key, err)
}
isControlling := conn.config.LocalKey < conn.config.Key
if isControlling {
log.Debugf("---- use this peer's tunr connection")
err = conn.turnRelay.PunchHole(remoteOfferAnswer.RemoteAddr)
if err != nil {
log.Errorf("failed to punch hole: %v", err)
}
addr, ok := remoteOfferAnswer.RemoteAddr.(*net.UDPAddr)
if !ok {
return fmt.Errorf("failed to cast addr to udp addr")
}
addr.Port = remoteOfferAnswer.WgListenPort
err := conn.config.WgConfig.WgInterface.UpdatePeer(conn.config.WgConfig.RemoteKey, conn.config.WgConfig.AllowedIps, defaultWgKeepAlive, addr, conn.config.WgConfig.PreSharedKey)
if err != nil {
if conn.wgProxy != nil {
_ = conn.wgProxy.CloseConn()
}
// todo close
return err
}
} else {
log.Debugf("---- use remote peer tunr connection")
addr, ok := remoteOfferAnswer.RelayedAddr.(*net.UDPAddr)
if !ok {
return fmt.Errorf("failed to cast addr to udp addr")
}
err := conn.config.WgConfig.WgInterface.UpdatePeer(conn.config.WgConfig.RemoteKey, conn.config.WgConfig.AllowedIps, defaultWgKeepAlive, addr, conn.config.WgConfig.PreSharedKey)
if err != nil {
if conn.wgProxy != nil {
_ = conn.wgProxy.CloseConn()
}
// todo close
return err
}
// the ice connection has been established successfully so we are ready to start the proxy
/*
remoteAddr, err := conn.configureConnection(remoteOfferAnswer.RelayedAddr, remoteWgPort, remoteOfferAnswer.RosenpassPubKey,
remoteOfferAnswer.RosenpassAddr)
if err != nil {
return err
}
*/
log.Infof("connected to peer %s, endpoint address: %s", conn.config.Key, addr.String())
err = conn.agent.GatherCandidates()
if err != nil {
return err
}
// will block until connection succeeded
// but it won't release if ICE Agent went into Disconnected or Failed state,
// so we have to cancel it with the provided context once agent detected a broken connection
isControlling := conn.config.LocalKey > conn.config.Key
var remoteConn *ice.Conn
if isControlling {
remoteConn, err = conn.agent.Dial(conn.ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
} else {
remoteConn, err = conn.agent.Accept(conn.ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
}
if err != nil {
return err
}
// dynamically set remote WireGuard port is other side specified a different one from the default one
remoteWgPort := iface.DefaultWgPort
if remoteOfferAnswer.WgListenPort != 0 {
remoteWgPort = remoteOfferAnswer.WgListenPort
}
conn.remoteConn = remoteConn
// the ice connection has been established successfully so we are ready to start the proxy
remoteAddr, err := conn.configureConnection(remoteConn, remoteWgPort, remoteOfferAnswer.RosenpassPubKey,
remoteOfferAnswer.RosenpassAddr)
if err != nil {
return err
}
log.Infof("connected to peer %s, endpoint address: %s", conn.config.Key, remoteAddr.String())
// wait until connection disconnected or has been closed externally (upper layer, e.g. engine)
select {
case <-conn.closeCh:
@@ -420,8 +415,25 @@ func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, rem
conn.mu.Lock()
defer conn.mu.Unlock()
pair, err := conn.agent.GetSelectedCandidatePair()
if err != nil {
return nil, err
}
var endpoint net.Addr
endpoint = remoteConn.RemoteAddr()
if isRelayCandidate(pair.Local) {
log.Debugf("setup relay connection")
conn.wgProxy = conn.wgProxyFactory.GetProxy()
endpoint, err = conn.wgProxy.AddTurnConn(remoteConn)
if err != nil {
return nil, err
}
} else {
// To support old version's with direct mode we attempt to punch an additional role with the remote WireGuard port
go conn.punchRemoteWGPort(pair, remoteWgPort)
endpoint = remoteConn.RemoteAddr()
}
endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String())
conn.remoteEndpoint = endpointUdpAddr
log.Debugf("Conn resolved IP for %s: %s", endpoint, endpointUdpAddr.IP)
@@ -433,7 +445,7 @@ func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, rem
}
}
err := conn.config.WgConfig.WgInterface.UpdatePeer(conn.config.WgConfig.RemoteKey, conn.config.WgConfig.AllowedIps, defaultWgKeepAlive, endpointUdpAddr, conn.config.WgConfig.PreSharedKey)
err = conn.config.WgConfig.WgInterface.UpdatePeer(conn.config.WgConfig.RemoteKey, conn.config.WgConfig.AllowedIps, defaultWgKeepAlive, endpointUdpAddr, conn.config.WgConfig.PreSharedKey)
if err != nil {
if conn.wgProxy != nil {
_ = conn.wgProxy.CloseConn()
@@ -442,33 +454,31 @@ func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, rem
}
conn.status = StatusConnected
/*
rosenpassEnabled := false
if remoteRosenpassPubKey != nil {
rosenpassEnabled = true
}
rosenpassEnabled := false
if remoteRosenpassPubKey != nil {
rosenpassEnabled = true
}
peerState := State{
PubKey: conn.config.Key,
ConnStatus: conn.status,
ConnStatusUpdate: time.Now(),
LocalIceCandidateType: pair.Local.Type().String(),
RemoteIceCandidateType: pair.Remote.Type().String(),
LocalIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Local.Address(), pair.Local.Port()),
RemoteIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Remote.Address(), pair.Local.Port()),
Direct: !isRelayCandidate(pair.Local),
RosenpassEnabled: rosenpassEnabled,
Mux: new(sync.RWMutex),
}
if pair.Local.Type() == ice.CandidateTypeRelay || pair.Remote.Type() == ice.CandidateTypeRelay {
peerState.Relayed = true
}
peerState := State{
PubKey: conn.config.Key,
ConnStatus: conn.status,
ConnStatusUpdate: time.Now(),
LocalIceCandidateType: pair.Local.Type().String(),
RemoteIceCandidateType: pair.Remote.Type().String(),
LocalIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Local.Address(), pair.Local.Port()),
RemoteIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Remote.Address(), pair.Remote.Port()),
Direct: !isRelayCandidate(pair.Local),
RosenpassEnabled: rosenpassEnabled,
Mux: new(sync.RWMutex),
}
if pair.Local.Type() == ice.CandidateTypeRelay || pair.Remote.Type() == ice.CandidateTypeRelay {
peerState.Relayed = true
}
err = conn.statusRecorder.UpdatePeerState(peerState)
if err != nil {
log.Warnf("unable to save peer's state, got error: %v", err)
}
*/
err = conn.statusRecorder.UpdatePeerState(peerState)
if err != nil {
log.Warnf("unable to save peer's state, got error: %v", err)
}
_, ipNet, err := net.ParseCIDR(conn.config.WgConfig.AllowedIps)
if err != nil {
@@ -670,8 +680,6 @@ func (conn *Conn) sendAnswer() error {
Version: version.NetbirdVersion(),
RosenpassPubKey: conn.config.RosenpassPubKey,
RosenpassAddr: conn.config.RosenpassAddr,
RelayedAddr: conn.turnRelay.RelayedAddress(),
RemoteAddr: conn.turnRelay.SrvRefAddr(),
})
if err != nil {
return err
@@ -695,8 +703,6 @@ func (conn *Conn) sendOffer() error {
Version: version.NetbirdVersion(),
RosenpassPubKey: conn.config.RosenpassPubKey,
RosenpassAddr: conn.config.RosenpassAddr,
RelayedAddr: conn.turnRelay.RelayedAddress(),
RemoteAddr: conn.turnRelay.SrvRefAddr(),
})
if err != nil {
return err
@@ -736,10 +742,6 @@ func (conn *Conn) Status() ConnStatus {
return conn.status
}
func (conn *Conn) OnRemoteRelayRequest(relayedAddr string, remoteIP string) {
}
// OnRemoteOffer handles an offer from the remote peer and returns true if the message was accepted, false otherwise
// doesn't block, discards the message if connection wasn't ready
func (conn *Conn) OnRemoteOffer(offer OfferAnswer) bool {

View File

@@ -1,143 +0,0 @@
package relay
import (
"fmt"
"io"
"net"
"os"
"time"
log "github.com/sirupsen/logrus"
)
const (
bufferSize = 800
testFile = "/tmp/1MB"
)
type Speed struct {
}
func NewSpeed() *Speed {
return &Speed{}
}
func (s *Speed) ReceiveFileFromAddr(remoteAddr net.Addr) error {
pc, err := net.ListenPacket("udp4", "0.0.0.0:0")
if err != nil {
log.Errorf("failed to lisen: %s", err.Error())
return err
}
defer pc.Close()
log.Debugf("--- sending initial message to: %s", remoteAddr.String())
_, err = pc.WriteTo([]byte("hey, I am the receiver"), remoteAddr)
if err != nil {
log.Errorf("failed to send initial msg: %s", err.Error())
return err
}
return s.receiveFile(pc)
}
func (s *Speed) ReceiveFileFromPC(pc net.PacketConn) error {
return s.receiveFile(pc)
}
func (s *Speed) receiveFile(pc net.PacketConn) error {
log.Debugf("--- start to receive file...")
file, err := os.OpenFile(fmt.Sprintf("%s.cp", testFile), os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
log.Errorf("failed to open file: %s", err.Error())
return err
}
_ = file.Truncate(0)
defer file.Close()
buffer := make([]byte, bufferSize)
for {
n, addr, err := pc.ReadFrom(buffer)
if err != nil {
log.Errorf("failed to read from connection: %s", err.Error())
return err
}
n, err = file.Write(buffer[:n])
if err != nil {
log.Errorf("failed to write to file: %s", err.Error())
return err
}
_, err = pc.WriteTo([]byte("ack"), addr)
if err != nil {
log.Errorf("failed to send ack: %s", err.Error())
}
log.Debugf("received %d bytes from %s", n, addr)
}
}
func (s *Speed) SendFileToPC(relayConn net.PacketConn) error {
buf := make([]byte, bufferSize)
log.Debugf("--- wait for initial message")
n, rAddr, err := relayConn.ReadFrom(buf)
if err != nil {
log.Errorf("failed to read from connection: %s", err.Error())
return err
}
log.Errorf("received initial msg %d bytes (%s), addr %s", n, string(buf[:n]), rAddr.String())
return s.sendFile(relayConn, rAddr)
}
func (s *Speed) SendFileToAddr(addr net.Addr) error {
pc, err := net.ListenPacket("udp4", "0.0.0.0:0")
if err != nil {
log.Errorf("failed to lisen: %s", err.Error())
return err
}
defer pc.Close()
return s.sendFile(pc, addr)
}
func (s *Speed) sendFile(conn net.PacketConn, rAddr net.Addr) error {
log.Debugf("--- start to send file...")
file, err := os.Open(testFile)
if err != nil {
// Handle error
return nil
}
defer file.Close()
buf := make([]byte, bufferSize)
start := time.Now()
sent := 0
for {
n, err := file.Read(buf)
if err != nil && err != io.EOF {
return err
}
if n == 0 {
break
}
n, err = conn.WriteTo(buf[:n], rAddr)
if err != nil {
log.Errorf("failed to write to connection: %s", err.Error())
return err
}
sent += n
log.Debugf("sent %d bytes, (%d) to %s", n, sent, rAddr.String())
// wait for ack
_, _, err = conn.ReadFrom(make([]byte, bufferSize))
if err != nil {
log.Errorf("failed to read from connection: %s", err.Error())
return err
}
}
elapsed := time.Since(start)
log.Infof("sent %d bytes, troughtput: %f MB/s", sent, float64(sent)/1024/1024/elapsed.Seconds())
return nil
}

View File

@@ -1,198 +0,0 @@
package relay
import (
"fmt"
"math"
"net"
"sync"
"github.com/pion/logging"
"github.com/pion/stun/v2"
"github.com/pion/turn/v3"
log "github.com/sirupsen/logrus"
)
type PermanentTurn struct {
stunURI *stun.URI
turnURI *stun.URI
stunConn net.PacketConn
turnClient *turn.Client
turnClientListenLock sync.Mutex
relayConn net.PacketConn // represents the remote socket.
srvReflexiveAddress *net.UDPAddr
}
func NewPermanentTurn(stunURL, turnURL *stun.URI) *PermanentTurn {
return &PermanentTurn{
stunURI: stunURL,
turnURI: turnURL,
}
}
func (r *PermanentTurn) Open() error {
stunConn, err := net.ListenPacket("udp4", "0.0.0.0:0")
if err != nil {
return err
}
r.stunConn = stunConn
cfg := &turn.ClientConfig{
STUNServerAddr: toURL(r.stunURI),
TURNServerAddr: toURL(r.turnURI),
Conn: stunConn,
Username: r.turnURI.Username,
Password: r.turnURI.Password,
LoggerFactory: logging.NewDefaultLoggerFactory(),
}
client, err := turn.NewClient(cfg)
if err != nil {
log.Errorf("failed to create turn client: %v", err)
return err
}
r.turnClient = client
err = r.turnClient.Listen()
if err != nil {
log.Errorf("failed to listen turn client: %v", err)
return err
}
relayConn, err := client.Allocate()
if err != nil {
log.Errorf("failed to allocate relay connection: %v", err)
return err
}
r.relayConn = relayConn
srvReflexiveAddress, err := r.discoverPublicIPByStun()
if err != nil {
log.Errorf("failed to discover public IP: %v", err)
return err
}
r.srvReflexiveAddress = srvReflexiveAddress
return nil
}
func (r *PermanentTurn) RelayedAddress() net.Addr {
return r.relayConn.LocalAddr()
}
func (r *PermanentTurn) SrvRefAddr() net.Addr {
return r.srvReflexiveAddress
}
func (r *PermanentTurn) PunchHole(mappedAddr net.Addr) error {
/*
err := r.turnClient.CreatePermission(mappedAddr)
if err != nil {
log.Errorf("---- failed to create permission: %v", err)
return err
}
msg, err := stun.Build(stun.BindingRequest, stun.TransactionID,
stun.Fingerprint,
)
if err != nil {
log.Errorf("--- failed to build stun message: %v", err)
return nil
}
_, err = r.relayConn.WriteTo(msg.Raw, mappedAddr)
if err != nil {
log.Errorf("failed to write to relay conn: %v", err)
return err
}
*/
_, err := r.relayConn.WriteTo([]byte("Hello"), mappedAddr)
return err
}
func (r *PermanentTurn) RelayConn() net.PacketConn {
return r.relayConn
}
func (r *PermanentTurn) Close() {
r.turnClient.Close()
err := r.relayConn.Close()
if err != nil {
log.Errorf("failed to close relayConn: %s", err.Error())
}
err = r.stunConn.Close()
if err != nil {
log.Errorf("failed to close stunConn: %s", err.Error())
}
}
func (r *PermanentTurn) discoverPublicIP() (*net.UDPAddr, error) {
addr, err := r.turnClient.SendBindingRequest()
if err != nil {
log.Errorf("failed to send binding request: %v", err)
return nil, err
}
udpAddr, ok := addr.(*net.UDPAddr)
if !ok {
return nil, fmt.Errorf("failed to cast addr to udp addr")
}
return udpAddr, nil
}
func (r *PermanentTurn) discoverPublicIPByStun() (*net.UDPAddr, error) {
c, err := stun.DialURI(r.stunURI, &stun.DialConfig{})
if err != nil {
panic(err)
}
message := stun.MustBuild(stun.TransactionID, stun.BindingRequest)
var addr *net.UDPAddr
err = c.Do(message, func(res stun.Event) {
if res.Error != nil {
panic(res.Error)
}
var xorAddr stun.XORMappedAddress
if err := xorAddr.GetFrom(res.Message); err != nil {
log.Errorf("failed to get xor address: %v", err)
return
}
addr = &net.UDPAddr{
IP: xorAddr.IP,
Port: xorAddr.Port,
}
})
if err != nil {
return nil, err
}
return addr, nil
}
func (r *PermanentTurn) listen() {
if !r.turnClientListenLock.TryLock() {
return
}
go func() {
defer r.turnClientListenLock.Unlock()
buf := make([]byte, math.MaxUint16)
for {
n, from, err := r.stunConn.ReadFrom(buf)
if err != nil {
log.Errorf("Failed to read from stun conn. Exiting loop %v", err)
break
}
_, err = r.turnClient.HandleInbound(buf[:n], from)
if err != nil {
log.Errorf("Failed to handle inbound turn message: %s. Exiting loop", err)
break
}
}
}()
}
func toURL(uri *stun.URI) string {
return fmt.Sprintf("%s:%d", uri.Host, uri.Port)
}

View File

@@ -1,137 +0,0 @@
package relay
import (
"os"
"sync"
"testing"
"github.com/pion/stun/v2"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/util"
)
const (
userName = "1714092678"
password = "8PEprGKo+UARpYpQOulNz3H24dI="
)
func TestMain(m *testing.M) {
_ = util.InitLog("trace", "console")
code := m.Run()
os.Exit(code)
}
func TestMyTurnUpload(t *testing.T) {
turnURI, err := stun.ParseURI("turn:api.stage.netbird.io:3478?transport=udp")
if err != nil {
t.Fatalf("failed to parse stun url: %v", err)
}
turnURI.Username = userName
turnURI.Password = password
stunURI, err := stun.ParseURI("stun:api.stage.netbird.io:3478")
if err != nil {
t.Fatalf("failed to parse stun url: %v", err)
}
turnRelayA := NewPermanentTurn(stunURI, turnURI)
err = turnRelayA.Open()
if err != nil {
t.Fatalf("failed to open turn relay: %v", err)
}
defer turnRelayA.Close()
turnRelayB := NewPermanentTurn(stunURI, turnURI)
peerBAddr, err := turnRelayB.discoverPublicIPByStun()
if err != nil {
t.Fatalf("failed to discover public ip: %v", err)
}
err = turnRelayA.PunchHole(peerBAddr)
if err != nil {
t.Fatalf("failed to punch hole: %v", err)
}
// at this point, the relayed side should be established
wg := sync.WaitGroup{}
wg.Add(2)
speedB := NewSpeed()
go func() {
err := speedB.ReceiveFileFromAddr(turnRelayA.relayConn.LocalAddr())
if err != nil {
log.Errorf("failed to receive file: %v", err)
}
wg.Done()
}()
speedA := NewSpeed()
go func() {
err := speedA.SendFileToPC(turnRelayA.relayConn)
if err != nil {
log.Errorf("failed to send file: %v", err)
}
log.Debugf("file sent")
wg.Done()
}()
wg.Wait()
}
func TestMyTurnDownload(t *testing.T) {
turnURI, err := stun.ParseURI("turn:api.stage.netbird.io:3478?transport=udp")
if err != nil {
t.Fatalf("failed to parse stun url: %v", err)
}
turnURI.Username = "1714016034"
turnURI.Password = "oDpL6tDu0d+xcO3rQnHoEvbcS/Q="
stunURI, err := stun.ParseURI("stun:api.stage.netbird.io:3478")
if err != nil {
t.Fatalf("failed to parse stun url: %v", err)
}
turnRelayA := NewPermanentTurn(stunURI, turnURI)
err = turnRelayA.Open()
if err != nil {
t.Fatalf("failed to open turn relay: %v", err)
}
defer turnRelayA.Close()
turnRelayB := NewPermanentTurn(stunURI, turnURI)
peerBAddr, err := turnRelayB.discoverPublicIPByStun()
if err != nil {
t.Fatalf("failed to discover public ip: %v", err)
}
err = turnRelayA.PunchHole(peerBAddr)
if err != nil {
t.Fatalf("failed to punch hole: %v", err)
}
// at this point, the relayed side should be established
wg := sync.WaitGroup{}
wg.Add(2)
speedB := NewSpeed()
go func() {
err := speedB.SendFileToAddr(turnRelayA.relayConn.LocalAddr())
if err != nil {
log.Errorf("failed to receive file: %v", err)
}
wg.Done()
}()
speedA := NewSpeed()
go func() {
err := speedA.ReceiveFileFromPC(turnRelayA.relayConn)
if err != nil {
log.Errorf("failed to send file: %v", err)
}
log.Debugf("file sent")
wg.Done()
}()
wg.Wait()
}

View File

@@ -16,6 +16,7 @@ import (
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route"
nbnet "github.com/netbirdio/netbird/util/net"
"github.com/netbirdio/netbird/version"
)
@@ -68,6 +69,10 @@ func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface,
// Init sets up the routing
func (m *DefaultManager) Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
if nbnet.CustomRoutingDisabled() {
return nil, nil, nil
}
if err := cleanupRouting(); err != nil {
log.Warnf("Failed cleaning up routing: %v", err)
}
@@ -99,11 +104,15 @@ func (m *DefaultManager) Stop() {
if m.serverRouter != nil {
m.serverRouter.cleanUp()
}
if err := cleanupRouting(); err != nil {
log.Errorf("Error cleaning up routing: %v", err)
} else {
log.Info("Routing cleanup complete")
if !nbnet.CustomRoutingDisabled() {
if err := cleanupRouting(); err != nil {
log.Errorf("Error cleaning up routing: %v", err)
} else {
log.Info("Routing cleanup complete")
}
}
m.ctx = nil
}
@@ -210,9 +219,11 @@ func (m *DefaultManager) clientRoutes(initialRoutes []*route.Route) []*route.Rou
}
func isPrefixSupported(prefix netip.Prefix) bool {
switch runtime.GOOS {
case "linux", "windows", "darwin":
return true
if !nbnet.CustomRoutingDisabled() {
switch runtime.GOOS {
case "linux", "windows", "darwin":
return true
}
}
// If prefix is too small, lets assume it is a possible default prefix which is not yet supported

View File

@@ -4,14 +4,14 @@ package routemanager
import (
"bufio"
"context"
"errors"
"fmt"
"net"
"net/netip"
"os"
"strconv"
"strings"
"syscall"
"time"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
@@ -32,19 +32,31 @@ const (
rtTablesPath = "/etc/iproute2/rt_tables"
// ipv4ForwardingPath is the path to the file containing the IP forwarding setting.
ipv4ForwardingPath = "/proc/sys/net/ipv4/ip_forward"
ipv4ForwardingPath = "net.ipv4.ip_forward"
rpFilterPath = "net.ipv4.conf.all.rp_filter"
rpFilterInterfacePath = "net.ipv4.conf.%s.rp_filter"
srcValidMarkPath = "net.ipv4.conf.all.src_valid_mark"
)
var ErrTableIDExists = errors.New("ID exists with different name")
var routeManager = &RouteManager{}
var isLegacy = os.Getenv("NB_USE_LEGACY_ROUTING") == "true"
// originalSysctl stores the original sysctl values before they are modified
var originalSysctl map[string]int
// determines whether to use the legacy routing setup
var isLegacy = os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled()
// sysctlFailed is used as an indicator to emit a warning when default routes are configured
var sysctlFailed bool
type ruleParams struct {
priority int
fwmark int
tableID int
family int
priority int
invert bool
suppressPrefix int
description string
@@ -52,10 +64,10 @@ type ruleParams struct {
func getSetupRules() []ruleParams {
return []ruleParams{
{nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V4, -1, true, -1, "rule v4 netbird"},
{nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V6, -1, true, -1, "rule v6 netbird"},
{-1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V4, -1, false, 0, "rule with suppress prefixlen v4"},
{-1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V6, -1, false, 0, "rule with suppress prefixlen v6"},
{100, -1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V4, false, 0, "rule with suppress prefixlen v4"},
{100, -1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V6, false, 0, "rule with suppress prefixlen v6"},
{110, nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V4, true, -1, "rule v4 netbird"},
{110, nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V6, true, -1, "rule v6 netbird"},
}
}
@@ -69,8 +81,6 @@ func getSetupRules() []ruleParams {
// Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table.
// This table is where a default route or other specific routes received from the management server are configured,
// enabling VPN connectivity.
//
// The rules are inserted in reverse order, as rules are added from the bottom up in the rule list.
func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.BeforeAddPeerHookFunc, _ peer.AfterRemovePeerHookFunc, err error) {
if isLegacy {
log.Infof("Using legacy routing setup")
@@ -81,6 +91,13 @@ func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.Before
log.Errorf("Error adding routing table name: %v", err)
}
originalValues, err := setupSysctl(wgIface)
if err != nil {
log.Errorf("Error setting up sysctl: %v", err)
sysctlFailed = true
}
originalSysctl = originalValues
defer func() {
if err != nil {
if cleanErr := cleanupRouting(); cleanErr != nil {
@@ -123,11 +140,17 @@ func cleanupRouting() error {
rules := getSetupRules()
for _, rule := range rules {
if err := removeAllRules(rule); err != nil && !errors.Is(err, syscall.EOPNOTSUPP) {
if err := removeRule(rule); err != nil {
result = multierror.Append(result, fmt.Errorf("%s: %w", rule.description, err))
}
}
if err := cleanupSysctl(originalSysctl); err != nil {
result = multierror.Append(result, fmt.Errorf("cleanup sysctl: %w", err))
}
originalSysctl = nil
sysctlFailed = false
return result.ErrorOrNil()
}
@@ -144,6 +167,10 @@ func addVPNRoute(prefix netip.Prefix, intf string) error {
return genericAddVPNRoute(prefix, intf)
}
if sysctlFailed && (prefix == defaultv4 || prefix == defaultv6) {
log.Warnf("Default route is configured but sysctl operations failed, VPN traffic may not be routed correctly, consider using NB_USE_LEGACY_ROUTING=true or setting net.ipv4.conf.*.rp_filter to 2 (loose) or 0 (off)")
}
// No need to check if routes exist as main table takes precedence over the VPN table via Rule 1
// TODO remove this once we have ipv6 support
@@ -336,22 +363,8 @@ func flushRoutes(tableID, family int) error {
}
func enableIPForwarding() error {
bytes, err := os.ReadFile(ipv4ForwardingPath)
if err != nil {
return fmt.Errorf("read file %s: %w", ipv4ForwardingPath, err)
}
// check if it is already enabled
// see more: https://github.com/netbirdio/netbird/issues/872
if len(bytes) > 0 && bytes[0] == 49 {
return nil
}
//nolint:gosec
if err := os.WriteFile(ipv4ForwardingPath, []byte("1"), 0644); err != nil {
return fmt.Errorf("write file %s: %w", ipv4ForwardingPath, err)
}
return nil
_, err := setSysctl(ipv4ForwardingPath, 1, false)
return err
}
// entryExists checks if the specified ID or name already exists in the rt_tables file
@@ -429,7 +442,7 @@ func addRule(params ruleParams) error {
rule.Invert = params.invert
rule.SuppressPrefixlen = params.suppressPrefix
if err := netlink.RuleAdd(rule); err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
if err := netlink.RuleAdd(rule); err != nil && !errors.Is(err, syscall.EEXIST) && !errors.Is(err, syscall.EAFNOSUPPORT) {
return fmt.Errorf("add routing rule: %w", err)
}
@@ -446,43 +459,13 @@ func removeRule(params ruleParams) error {
rule.Priority = params.priority
rule.SuppressPrefixlen = params.suppressPrefix
if err := netlink.RuleDel(rule); err != nil {
if err := netlink.RuleDel(rule); err != nil && !errors.Is(err, syscall.ENOENT) && !errors.Is(err, syscall.EAFNOSUPPORT) {
return fmt.Errorf("remove routing rule: %w", err)
}
return nil
}
func removeAllRules(params ruleParams) error {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
done := make(chan error, 1)
go func() {
for {
if ctx.Err() != nil {
done <- ctx.Err()
return
}
if err := removeRule(params); err != nil {
if errors.Is(err, syscall.ENOENT) || errors.Is(err, syscall.EAFNOSUPPORT) {
done <- nil
return
}
done <- err
return
}
}
}()
select {
case <-ctx.Done():
return ctx.Err()
case err := <-done:
return err
}
}
// addNextHop adds the gateway and device to the route.
func addNextHop(addr netip.Addr, intf string, route *netlink.Route) error {
if addr.IsValid() {
@@ -509,3 +492,83 @@ func getAddressFamily(prefix netip.Prefix) int {
}
return netlink.FAMILY_V6
}
// setupSysctl configures sysctl settings for RP filtering and source validation.
func setupSysctl(wgIface *iface.WGIface) (map[string]int, error) {
keys := map[string]int{}
var result *multierror.Error
oldVal, err := setSysctl(srcValidMarkPath, 1, false)
if err != nil {
result = multierror.Append(result, err)
} else {
keys[srcValidMarkPath] = oldVal
}
oldVal, err = setSysctl(rpFilterPath, 2, true)
if err != nil {
result = multierror.Append(result, err)
} else {
keys[rpFilterPath] = oldVal
}
interfaces, err := net.Interfaces()
if err != nil {
result = multierror.Append(result, fmt.Errorf("list interfaces: %w", err))
}
for _, intf := range interfaces {
if intf.Name == "lo" || wgIface != nil && intf.Name == wgIface.Name() {
continue
}
i := fmt.Sprintf(rpFilterInterfacePath, intf.Name)
oldVal, err := setSysctl(i, 2, true)
if err != nil {
result = multierror.Append(result, err)
} else {
keys[i] = oldVal
}
}
return keys, result.ErrorOrNil()
}
// setSysctl sets a sysctl configuration, if onlyIfOne is true it will only set the new value if it's set to 1
func setSysctl(key string, desiredValue int, onlyIfOne bool) (int, error) {
path := fmt.Sprintf("/proc/sys/%s", strings.ReplaceAll(key, ".", "/"))
currentValue, err := os.ReadFile(path)
if err != nil {
return -1, fmt.Errorf("read sysctl %s: %w", key, err)
}
currentV, err := strconv.Atoi(strings.TrimSpace(string(currentValue)))
if err != nil && len(currentValue) > 0 {
return -1, fmt.Errorf("convert current desiredValue to int: %w", err)
}
if currentV == desiredValue || onlyIfOne && currentV != 1 {
return currentV, nil
}
//nolint:gosec
if err := os.WriteFile(path, []byte(strconv.Itoa(desiredValue)), 0644); err != nil {
return currentV, fmt.Errorf("write sysctl %s: %w", key, err)
}
log.Debugf("Set sysctl %s from %d to %d", key, currentV, desiredValue)
return currentV, nil
}
func cleanupSysctl(originalSettings map[string]int) error {
var result *multierror.Error
for key, value := range originalSettings {
_, err := setSysctl(key, value, false)
if err != nil {
result = multierror.Append(result, err)
}
}
return result.ErrorOrNil()
}

View File

@@ -61,7 +61,7 @@ func TestAddRemoveRoutes(t *testing.T) {
err = wgInterface.Create()
require.NoError(t, err, "should create testing wireguard interface")
_, _, err = setupRouting(nil, nil)
_, _, err = setupRouting(nil, wgInterface)
require.NoError(t, err)
t.Cleanup(func() {
assert.NoError(t, cleanupRouting())

View File

@@ -73,7 +73,7 @@ func addRoutePowershell(prefix netip.Prefix, nexthop netip.Addr, intf, intfIdx s
}
script := fmt.Sprintf(
`%s -AddressFamily "%s" -DestinationPrefix "%s" -Confirm:$False -ErrorAction Stop`,
`%s -AddressFamily "%s" -DestinationPrefix "%s" -Confirm:$False -ErrorAction Stop -PolicyStore ActiveStore`,
psCmd, addressFamily, destinationPrefix,
)

View File

@@ -230,7 +230,7 @@ func (p *WGEBPFProxy) prepareSenderRawSocket() (net.PacketConn, error) {
}
// Set the fwmark on the socket.
err = syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_MARK, nbnet.NetbirdFwmark)
err = nbnet.SetSocketOpt(fd)
if err != nil {
return nil, fmt.Errorf("setting fwmark failed: %w", err)
}

4
go.mod
View File

@@ -60,7 +60,7 @@ require (
github.com/miekg/dns v1.1.43
github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/nadoo/ipset v0.5.0
github.com/netbirdio/management-integrations/integrations v0.0.0-20240326083846-3682438fca98
github.com/netbirdio/management-integrations/integrations v0.0.0-20240415094251-369eb33c9b01
github.com/okta/okta-sdk-golang/v2 v2.18.0
github.com/oschwald/maxminddb-golang v1.12.0
github.com/patrickmn/go-cache v2.1.0+incompatible
@@ -172,7 +172,7 @@ replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-2023
replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20240422165616-c6832bb477d5
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed
replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6

8
go.sum
View File

@@ -383,14 +383,14 @@ github.com/nadoo/ipset v0.5.0 h1:5GJUAuZ7ITQQQGne5J96AmFjRtI8Avlbk6CabzYWVUc=
github.com/nadoo/ipset v0.5.0/go.mod h1:rYF5DQLRGGoQ8ZSWeK+6eX5amAuPqwFkWjhQlEITGJQ=
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c=
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q=
github.com/netbirdio/management-integrations/integrations v0.0.0-20240326083846-3682438fca98 h1:i6AtenTLu/CqhTmj0g1K/GWkkpMJMhQM6Vjs46x25nA=
github.com/netbirdio/management-integrations/integrations v0.0.0-20240326083846-3682438fca98/go.mod h1:kxks50DrZnhW+oRTdHOkVOJbcTcyo766am8RBugo+Yc=
github.com/netbirdio/management-integrations/integrations v0.0.0-20240415094251-369eb33c9b01 h1:Fu9fq0ndfKVuFTEwbc8Etqui10BOkcMTv0UqcMy0RuY=
github.com/netbirdio/management-integrations/integrations v0.0.0-20240415094251-369eb33c9b01/go.mod h1:kxks50DrZnhW+oRTdHOkVOJbcTcyo766am8RBugo+Yc=
github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0 h1:hirFRfx3grVA/9eEyjME5/z3nxdJlN9kfQpvWWPk32g=
github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949 h1:xbWM9BU6mwZZLHxEjxIX/V8Hv3HurQt4mReIE4mY4DM=
github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949/go.mod h1:AecygODWIsBquJCJFop8MEQcJbWFfw/1yWbVabNgpCM=
github.com/netbirdio/wireguard-go v0.0.0-20240422165616-c6832bb477d5 h1:m48qfB2ILlFx3oZlw7aEeD+V6vXnMb0hNwmDCtdcgv0=
github.com/netbirdio/wireguard-go v0.0.0-20240422165616-c6832bb477d5/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed h1:t0UADZUJDaaZgfKrt8JUPrOLL9Mg/ryjP85RAH53qgs=
github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8=
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno=
github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A=

View File

@@ -13,6 +13,14 @@ import (
wgConn "golang.zx2c4.com/wireguard/conn"
)
type receiverCreator struct {
iceBind *ICEBind
}
func (rc receiverCreator) CreateIPv4ReceiverFn(msgPool *sync.Pool, pc *ipv4.PacketConn, conn *net.UDPConn) wgConn.ReceiveFunc {
return rc.iceBind.createIPv4ReceiverFn(msgPool, pc, conn)
}
type ICEBind struct {
*wgConn.StdNetBind
@@ -20,8 +28,6 @@ type ICEBind struct {
transportNet transport.Net
udpMux *UniversalUDPMuxDefault
receiverCreator *receiverCreator
}
func NewICEBind(transportNet transport.Net) *ICEBind {
@@ -29,9 +35,9 @@ func NewICEBind(transportNet transport.Net) *ICEBind {
transportNet: transportNet,
}
rc := newReceiverCreator(ib)
ib.receiverCreator = rc
rc := receiverCreator{
ib,
}
ib.StdNetBind = wgConn.NewStdNetBindWithReceiverCreator(rc)
return ib
}
@@ -47,22 +53,16 @@ func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
return s.udpMux, nil
}
func (s *ICEBind) SetTurnConn(conn interface{}) {
s.receiverCreator.setTurnConn(conn)
}
func (s *ICEBind) createIPv4ReceiverFn(ipv4MsgsPool *sync.Pool, pc *ipv4.PacketConn, conn *net.UDPConn, netConn net.PacketConn) wgConn.ReceiveFunc {
func (s *ICEBind) createIPv4ReceiverFn(ipv4MsgsPool *sync.Pool, pc *ipv4.PacketConn, conn *net.UDPConn) wgConn.ReceiveFunc {
s.muUDPMux.Lock()
defer s.muUDPMux.Unlock()
if conn != nil {
s.udpMux = NewUniversalUDPMuxDefault(
UniversalUDPMuxParams{
UDPConn: conn,
Net: s.transportNet,
},
)
}
s.udpMux = NewUniversalUDPMuxDefault(
UniversalUDPMuxParams{
UDPConn: conn,
Net: s.transportNet,
},
)
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
msgs := ipv4MsgsPool.Get().(*[]ipv4.Message)
defer ipv4MsgsPool.Put(msgs)
@@ -71,40 +71,17 @@ func (s *ICEBind) createIPv4ReceiverFn(ipv4MsgsPool *sync.Pool, pc *ipv4.PacketC
}
var numMsgs int
if runtime.GOOS == "linux" {
if netConn != nil {
log.Debugf("----read from turn conn...")
msg := &(*msgs)[0]
msg.N, msg.Addr, err = netConn.ReadFrom(msg.Buffers[0])
if err != nil {
return 0, err
}
log.Debugf("----msg address is: %s, size: %d", msg.Addr.String(), msg.N)
numMsgs = 1
} else {
log.Debugf("----read from pc...")
numMsgs, err = pc.ReadBatch(*msgs, 0)
if err != nil {
return 0, err
}
numMsgs, err = pc.ReadBatch(*msgs, 0)
if err != nil {
return 0, err
}
} else {
if netConn != nil {
log.Debugf("----read from turn conn...")
msg := &(*msgs)[0]
msg.N, msg.Addr, err = netConn.ReadFrom(msg.Buffers[0])
if err != nil {
return 0, err
}
log.Debugf("----msg address is: %s, size: %d", msg.Addr.String(), msg.N)
numMsgs = 1
} else {
msg := &(*msgs)[0]
msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
if err != nil {
return 0, err
}
numMsgs = 1
msg := &(*msgs)[0]
msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
if err != nil {
return 0, err
}
numMsgs = 1
}
for i := 0; i < numMsgs; i++ {
msg := &(*msgs)[i]
@@ -118,10 +95,7 @@ func (s *ICEBind) createIPv4ReceiverFn(ipv4MsgsPool *sync.Pool, pc *ipv4.PacketC
}
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
ep := &wgConn.StdNetEndpoint{
AddrPort: addrPort,
Conn: netConn,
}
ep := &wgConn.StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
wgConn.GetSrcFromControl(msg.OOB[:msg.NN], ep)
eps[i] = ep
}

View File

@@ -1,38 +0,0 @@
package bind
import (
"net"
"sync"
log "github.com/sirupsen/logrus"
"golang.org/x/net/ipv4"
wgConn "golang.zx2c4.com/wireguard/conn"
)
type receiverCreator struct {
iceBind *ICEBind
relayConn net.PacketConn
}
func newReceiverCreator(iceBind *ICEBind) *receiverCreator {
return &receiverCreator{
iceBind: iceBind,
}
}
func (rc *receiverCreator) CreateIPv4ReceiverFn(msgPool *sync.Pool, pc *ipv4.PacketConn, conn *net.UDPConn) wgConn.ReceiveFunc {
return rc.iceBind.createIPv4ReceiverFn(msgPool, pc, conn, nil)
}
func (rc *receiverCreator) CreateRelayReceiverFn(msgPool *sync.Pool) wgConn.ReceiveFunc {
if rc.relayConn == nil {
log.Debugf("-------rc.conn is nil")
return nil
}
return rc.iceBind.createIPv4ReceiverFn(msgPool, nil, nil, rc.relayConn)
}
func (rc *receiverCreator) setTurnConn(relayConn interface{}) {
log.Debug("------ SET TURN CONN")
rc.relayConn = relayConn.(net.PacketConn)
}

View File

@@ -150,10 +150,3 @@ func (w *WGIface) GetDevice() *DeviceWrapper {
func (w *WGIface) GetStats(peerKey string) (WGStats, error) {
return w.configurer.getStats(peerKey)
}
func (w *WGIface) SetRelayConn(conn interface{}) {
w.mu.Lock()
defer w.mu.Unlock()
w.tun.SetTurnConn(conn)
}

View File

@@ -85,27 +85,23 @@ func tunModuleIsLoaded() bool {
// WireGuardModuleIsLoaded check if we can load WireGuard mod (linux only)
func WireGuardModuleIsLoaded() bool {
return false
/*
if os.Getenv(envDisableWireGuardKernel) == "true" {
log.Debugf("WireGuard kernel module disabled because the %s env is set to true", envDisableWireGuardKernel)
return false
}
if os.Getenv(envDisableWireGuardKernel) == "true" {
log.Debugf("WireGuard kernel module disabled because the %s env is set to true", envDisableWireGuardKernel)
return false
}
if canCreateFakeWireGuardInterface() {
return true
}
if canCreateFakeWireGuardInterface() {
return true
}
loaded, err := tryToLoadModule("wireguard")
if err != nil {
log.Info(err)
return false
}
loaded, err := tryToLoadModule("wireguard")
if err != nil {
log.Info(err)
return false
}
return loaded
*/
return loaded
}
func canCreateFakeWireGuardInterface() bool {

View File

@@ -15,5 +15,4 @@ type wgTunDevice interface {
DeviceName() string
Close() error
Wrapper() *DeviceWrapper // todo eliminate this function
SetTurnConn(conn interface{})
}

View File

@@ -28,14 +28,6 @@ type tunDevice struct {
configurer wgConfigurer
}
func (t *tunDevice) SetTurnConn(conn interface{}) {
t.iceBind.SetTurnConn(conn)
err := t.device.BindUpdate()
if err != nil {
log.Errorf("failed to update bind: %v", err)
}
}
func newTunDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net) wgTunDevice {
return &tunDevice{
name: name,

View File

@@ -31,11 +31,6 @@ type tunKernelDevice struct {
udpMux *bind.UniversalUDPMuxDefault
}
func (t *tunKernelDevice) SetTurnConn(interface{}) {
//TODO implement me
panic("implement me")
}
func newTunDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net) wgTunDevice {
ctx, cancel := context.WithCancel(context.Background())
return &tunKernelDevice{

View File

@@ -30,11 +30,6 @@ type tunNetstackDevice struct {
configurer wgConfigurer
}
func (t *tunNetstackDevice) SetTurnConn(interface{}) {
//TODO implement me
panic("implement me")
}
func newTunNetstackDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net, listenAddress string) wgTunDevice {
return &tunNetstackDevice{
name: name,

View File

@@ -54,7 +54,7 @@ func (t *tunUSPDevice) Create() (wgConfigurer, error) {
t.device = device.NewDevice(
t.wrapper,
t.iceBind,
device.NewLogger(device.LogLevelError, "[netbird] "),
device.NewLogger(device.LogLevelSilent, "[netbird] "),
)
err = t.assignAddr()
@@ -70,7 +70,6 @@ func (t *tunUSPDevice) Create() (wgConfigurer, error) {
t.configurer.close()
return nil, err
}
log.Debugf("configuration done")
return t.configurer, nil
}
@@ -126,14 +125,6 @@ func (t *tunUSPDevice) Wrapper() *DeviceWrapper {
return t.wrapper
}
func (t *tunUSPDevice) SetTurnConn(conn interface{}) {
t.iceBind.SetTurnConn(conn)
err := t.device.BindUpdate()
if err != nil {
log.Errorf("failed to update bind: %v", err)
}
}
// assignAddr Adds IP address to the tunnel interface
func (t *tunUSPDevice) assignAddr() error {
link := newWGLink(t.name)

View File

@@ -10,8 +10,6 @@ import (
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
nbnet "github.com/netbirdio/netbird/util/net"
)
type wgKernelConfigurer struct {
@@ -31,7 +29,7 @@ func (c *wgKernelConfigurer) configureInterface(privateKey string, port int) err
if err != nil {
return err
}
fwmark := nbnet.NetbirdFwmark
fwmark := getFwmark()
config := wgtypes.Config{
PrivateKey: &key,
ReplacePeers: true,

View File

@@ -349,7 +349,7 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string {
}
func getFwmark() int {
if runtime.GOOS == "linux" {
if runtime.GOOS == "linux" && !nbnet.CustomRoutingDisabled() {
return nbnet.NetbirdFwmark
}
return 0

View File

@@ -251,7 +251,7 @@ var (
ctx, cancel := context.WithCancel(cmd.Context())
defer cancel()
httpAPIHandler, err := httpapi.APIHandler(ctx, accountManager, geo, *jwtValidator, appMetrics, httpAPIAuthCfg)
httpAPIHandler, err := httpapi.APIHandler(ctx, accountManager, geo, *jwtValidator, appMetrics, httpAPIAuthCfg, integratedPeerValidator)
if err != nil {
return fmt.Errorf("failed creating HTTP API handler: %v", err)
}

View File

@@ -1473,7 +1473,7 @@ func (am *DefaultAccountManager) handleNewUserAccount(domainAcc *Account, claims
// if domain already has a primary account, add regular user
if domainAcc != nil {
account = domainAcc
account.Users[claims.UserId] = NewRegularUser(claims.UserId)
account.Users[claims.UserId] = NewRegularUser(claims.UserId, account.Id)
err = am.Store.SaveAccount(account)
if err != nil {
return nil, err
@@ -1849,6 +1849,7 @@ func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(claims jwtclaims.Aut
}
func (am *DefaultAccountManager) onPeersInvalidated(accountID string) {
log.Debugf("validated peers has been invalidated for account %s", accountID)
updatedAccount, err := am.Store.GetAccount(accountID)
if err != nil {
log.Errorf("failed to get account %s: %v", accountID, err)
@@ -1861,9 +1862,10 @@ func (am *DefaultAccountManager) onPeersInvalidated(accountID string) {
func addAllGroup(account *Account) error {
if len(account.Groups) == 0 {
allGroup := &nbgroup.Group{
ID: xid.New().String(),
Name: "All",
Issued: nbgroup.GroupIssuedAPI,
ID: xid.New().String(),
Name: "All",
Issued: nbgroup.GroupIssuedAPI,
AccountID: account.Id,
}
for _, peer := range account.Peers {
allGroup.Peers = append(allGroup.Peers, peer.ID)
@@ -1907,7 +1909,7 @@ func newAccountWithId(accountID, userID, domain string) *Account {
routes := make(map[string]*route.Route)
setupKeys := map[string]*SetupKey{}
nameServersGroups := make(map[string]*nbdns.NameServerGroup)
users[userID] = NewOwnerUser(userID)
users[userID] = NewOwnerUser(userID, accountID)
dnsSettings := DNSSettings{
DisabledManagementGroups: make([]string, 0),
}

View File

@@ -11,133 +11,134 @@ type Code struct {
Code string
}
// Existing consts must not be changed, as this will break the compatibility with the existing data
const (
// PeerAddedByUser indicates that a user added a new peer to the system
PeerAddedByUser Activity = iota
PeerAddedByUser Activity = 0
// PeerAddedWithSetupKey indicates that a new peer joined the system using a setup key
PeerAddedWithSetupKey
PeerAddedWithSetupKey Activity = 1
// UserJoined indicates that a new user joined the account
UserJoined
UserJoined Activity = 2
// UserInvited indicates that a new user was invited to join the account
UserInvited
UserInvited Activity = 3
// AccountCreated indicates that a new account has been created
AccountCreated
AccountCreated Activity = 4
// PeerRemovedByUser indicates that a user removed a peer from the system
PeerRemovedByUser
PeerRemovedByUser Activity = 5
// RuleAdded indicates that a user added a new rule
RuleAdded
RuleAdded Activity = 6
// RuleUpdated indicates that a user updated a rule
RuleUpdated
RuleUpdated Activity = 7
// RuleRemoved indicates that a user removed a rule
RuleRemoved
RuleRemoved Activity = 8
// PolicyAdded indicates that a user added a new policy
PolicyAdded
PolicyAdded Activity = 9
// PolicyUpdated indicates that a user updated a policy
PolicyUpdated
PolicyUpdated Activity = 10
// PolicyRemoved indicates that a user removed a policy
PolicyRemoved
PolicyRemoved Activity = 11
// SetupKeyCreated indicates that a user created a new setup key
SetupKeyCreated
SetupKeyCreated Activity = 12
// SetupKeyUpdated indicates that a user updated a setup key
SetupKeyUpdated
SetupKeyUpdated Activity = 13
// SetupKeyRevoked indicates that a user revoked a setup key
SetupKeyRevoked
SetupKeyRevoked Activity = 14
// SetupKeyOverused indicates that setup key usage exhausted
SetupKeyOverused
SetupKeyOverused Activity = 15
// GroupCreated indicates that a user created a group
GroupCreated
GroupCreated Activity = 16
// GroupUpdated indicates that a user updated a group
GroupUpdated
GroupUpdated Activity = 17
// GroupAddedToPeer indicates that a user added group to a peer
GroupAddedToPeer
GroupAddedToPeer Activity = 18
// GroupRemovedFromPeer indicates that a user removed peer group
GroupRemovedFromPeer
GroupRemovedFromPeer Activity = 19
// GroupAddedToUser indicates that a user added group to a user
GroupAddedToUser
GroupAddedToUser Activity = 20
// GroupRemovedFromUser indicates that a user removed a group from a user
GroupRemovedFromUser
GroupRemovedFromUser Activity = 21
// UserRoleUpdated indicates that a user changed the role of a user
UserRoleUpdated
UserRoleUpdated Activity = 22
// GroupAddedToSetupKey indicates that a user added group to a setup key
GroupAddedToSetupKey
GroupAddedToSetupKey Activity = 23
// GroupRemovedFromSetupKey indicates that a user removed a group from a setup key
GroupRemovedFromSetupKey
GroupRemovedFromSetupKey Activity = 24
// GroupAddedToDisabledManagementGroups indicates that a user added a group to the DNS setting Disabled management groups
GroupAddedToDisabledManagementGroups
GroupAddedToDisabledManagementGroups Activity = 25
// GroupRemovedFromDisabledManagementGroups indicates that a user removed a group from the DNS setting Disabled management groups
GroupRemovedFromDisabledManagementGroups
GroupRemovedFromDisabledManagementGroups Activity = 26
// RouteCreated indicates that a user created a route
RouteCreated
RouteCreated Activity = 27
// RouteRemoved indicates that a user deleted a route
RouteRemoved
RouteRemoved Activity = 28
// RouteUpdated indicates that a user updated a route
RouteUpdated
RouteUpdated Activity = 29
// PeerSSHEnabled indicates that a user enabled SSH server on a peer
PeerSSHEnabled
PeerSSHEnabled Activity = 30
// PeerSSHDisabled indicates that a user disabled SSH server on a peer
PeerSSHDisabled
PeerSSHDisabled Activity = 31
// PeerRenamed indicates that a user renamed a peer
PeerRenamed
PeerRenamed Activity = 32
// PeerLoginExpirationEnabled indicates that a user enabled login expiration of a peer
PeerLoginExpirationEnabled
PeerLoginExpirationEnabled Activity = 33
// PeerLoginExpirationDisabled indicates that a user disabled login expiration of a peer
PeerLoginExpirationDisabled
PeerLoginExpirationDisabled Activity = 34
// NameserverGroupCreated indicates that a user created a nameservers group
NameserverGroupCreated
NameserverGroupCreated Activity = 35
// NameserverGroupDeleted indicates that a user deleted a nameservers group
NameserverGroupDeleted
NameserverGroupDeleted Activity = 36
// NameserverGroupUpdated indicates that a user updated a nameservers group
NameserverGroupUpdated
NameserverGroupUpdated Activity = 37
// AccountPeerLoginExpirationEnabled indicates that a user enabled peer login expiration for the account
AccountPeerLoginExpirationEnabled
AccountPeerLoginExpirationEnabled Activity = 38
// AccountPeerLoginExpirationDisabled indicates that a user disabled peer login expiration for the account
AccountPeerLoginExpirationDisabled
AccountPeerLoginExpirationDisabled Activity = 39
// AccountPeerLoginExpirationDurationUpdated indicates that a user updated peer login expiration duration for the account
AccountPeerLoginExpirationDurationUpdated
AccountPeerLoginExpirationDurationUpdated Activity = 40
// PersonalAccessTokenCreated indicates that a user created a personal access token
PersonalAccessTokenCreated
PersonalAccessTokenCreated Activity = 41
// PersonalAccessTokenDeleted indicates that a user deleted a personal access token
PersonalAccessTokenDeleted
PersonalAccessTokenDeleted Activity = 42
// ServiceUserCreated indicates that a user created a service user
ServiceUserCreated
ServiceUserCreated Activity = 43
// ServiceUserDeleted indicates that a user deleted a service user
ServiceUserDeleted
ServiceUserDeleted Activity = 44
// UserBlocked indicates that a user blocked another user
UserBlocked
UserBlocked Activity = 45
// UserUnblocked indicates that a user unblocked another user
UserUnblocked
UserUnblocked Activity = 46
// UserDeleted indicates that a user deleted another user
UserDeleted
UserDeleted Activity = 47
// GroupDeleted indicates that a user deleted group
GroupDeleted
GroupDeleted Activity = 48
// UserLoggedInPeer indicates that user logged in their peer with an interactive SSO login
UserLoggedInPeer
UserLoggedInPeer Activity = 49
// PeerLoginExpired indicates that the user peer login has been expired and peer disconnected
PeerLoginExpired
PeerLoginExpired Activity = 50
// DashboardLogin indicates that the user logged in to the dashboard
DashboardLogin
DashboardLogin Activity = 51
// IntegrationCreated indicates that the user created an integration
IntegrationCreated
IntegrationCreated Activity = 52
// IntegrationUpdated indicates that the user updated an integration
IntegrationUpdated
IntegrationUpdated Activity = 53
// IntegrationDeleted indicates that the user deleted an integration
IntegrationDeleted
IntegrationDeleted Activity = 54
// AccountPeerApprovalEnabled indicates that the user enabled peer approval for the account
AccountPeerApprovalEnabled
AccountPeerApprovalEnabled Activity = 55
// AccountPeerApprovalDisabled indicates that the user disabled peer approval for the account
AccountPeerApprovalDisabled
AccountPeerApprovalDisabled Activity = 56
// PeerApproved indicates that the peer has been approved
PeerApproved
PeerApproved Activity = 57
// PeerApprovalRevoked indicates that the peer approval has been revoked
PeerApprovalRevoked
PeerApprovalRevoked Activity = 58
// TransferredOwnerRole indicates that the user transferred the owner role of the account
TransferredOwnerRole
TransferredOwnerRole Activity = 59
// PostureCheckCreated indicates that the user created a posture check
PostureCheckCreated
PostureCheckCreated Activity = 60
// PostureCheckUpdated indicates that the user updated a posture check
PostureCheckUpdated
PostureCheckUpdated Activity = 61
// PostureCheckDeleted indicates that the user deleted a posture check
PostureCheckDeleted
PostureCheckDeleted Activity = 62
)
var activityMap = map[Activity]Code{

View File

@@ -54,7 +54,7 @@ func initAccountsTestData(account *server.Account, admin *server.User) *Accounts
func TestAccounts_AccountsHandler(t *testing.T) {
accountID := "test_account"
adminUser := server.NewAdminUser("test_user")
adminUser := server.NewAdminUser("test_user", "account_id")
sr := func(v string) *string { return &v }
br := func(v bool) *bool { return &v }

View File

@@ -34,7 +34,7 @@ var testingDNSSettingsAccount = &server.Account{
Id: testDNSSettingsAccountID,
Domain: "hotmail.com",
Users: map[string]*server.User{
testDNSSettingsUserID: server.NewAdminUser("test_user"),
testDNSSettingsUserID: server.NewAdminUser("test_user", "account_id"),
},
DNSSettings: baseExistingDNSSettings,
}

View File

@@ -196,7 +196,7 @@ func TestEvents_GetEvents(t *testing.T) {
},
}
accountID := "test_account"
adminUser := server.NewAdminUser("test_user")
adminUser := server.NewAdminUser("test_user", "account_id")
events := generateEvents(accountID, adminUser.Id)
handler := initEventsTestData(accountID, adminUser, events...)

View File

@@ -42,7 +42,7 @@ func initGeolocationTestData(t *testing.T) *GeolocationsHandler {
return &GeolocationsHandler{
accountManager: &mock_server.MockAccountManager{
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
user := server.NewAdminUser("test_user")
user := server.NewAdminUser("test_user", "account_id")
return &server.Account{
Id: claims.AccountId,
Users: map[string]*server.User{

View File

@@ -124,7 +124,7 @@ func TestGetGroup(t *testing.T) {
Name: "Group",
}
adminUser := server.NewAdminUser("test_user")
adminUser := server.NewAdminUser("test_user", "account_id")
p := initGroupTestData(adminUser, group)
for _, tc := range tt {
@@ -246,7 +246,7 @@ func TestWriteGroup(t *testing.T) {
},
}
adminUser := server.NewAdminUser("test_user")
adminUser := server.NewAdminUser("test_user", "account_id")
p := initGroupTestData(adminUser)
for _, tc := range tt {
@@ -324,7 +324,7 @@ func TestDeleteGroup(t *testing.T) {
},
}
adminUser := server.NewAdminUser("test_user")
adminUser := server.NewAdminUser("test_user", "account_id")
p := initGroupTestData(adminUser)
for _, tc := range tt {

View File

@@ -12,6 +12,7 @@ import (
s "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/http/middleware"
"github.com/netbirdio/netbird/management/server/integrated_validator"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/telemetry"
)
@@ -38,7 +39,7 @@ type emptyObject struct {
}
// APIHandler creates the Management service HTTP API handler registering all the available endpoints.
func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationManager *geolocation.Geolocation, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg AuthCfg) (http.Handler, error) {
func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationManager *geolocation.Geolocation, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg AuthCfg, integratedValidator integrated_validator.IntegratedValidator) (http.Handler, error) {
claimsExtractor := jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience),
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
@@ -75,7 +76,7 @@ func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationMa
AuthCfg: authCfg,
}
if _, err := integrations.RegisterHandlers(ctx, prefix, api.Router, accountManager, claimsExtractor); err != nil {
if _, err := integrations.RegisterHandlers(ctx, prefix, api.Router, accountManager, claimsExtractor, integratedValidator); err != nil {
return nil, fmt.Errorf("register integrations endpoints: %w", err)
}

View File

@@ -32,7 +32,7 @@ var testingNSAccount = &server.Account{
Id: testNSGroupAccountID,
Domain: "hotmail.com",
Users: map[string]*server.User{
"test_user": server.NewAdminUser("test_user"),
"test_user": server.NewAdminUser("test_user", "account_id"),
},
}

View File

@@ -59,7 +59,7 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
return "netbird.selfhosted"
},
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
user := server.NewAdminUser("test_user")
user := server.NewAdminUser("test_user", "account_id")
return &server.Account{
Id: claims.AccountId,
Domain: "hotmail.com",

View File

@@ -45,7 +45,7 @@ func initPoliciesTestData(policies ...*server.Policy) *Policies {
return nil
},
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
user := server.NewAdminUser("test_user")
user := server.NewAdminUser("test_user", "account_id")
return &server.Account{
Id: claims.AccountId,
Domain: "hotmail.com",

View File

@@ -62,7 +62,7 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH
return accountPostureChecks, nil
},
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
user := server.NewAdminUser("test_user")
user := server.NewAdminUser("test_user", "account_id")
return &server.Account{
Id: claims.AccountId,
Users: map[string]*server.User{

View File

@@ -75,7 +75,7 @@ var testingAccount = &server.Account{
},
},
Users: map[string]*server.User{
"test_user": server.NewAdminUser("test_user"),
"test_user": server.NewAdminUser("test_user", "account_id"),
},
}

View File

@@ -97,7 +97,7 @@ func TestSetupKeysHandlers(t *testing.T) {
defaultSetupKey := server.GenerateDefaultSetupKey()
defaultSetupKey.Id = existingSetupKeyID
adminUser := server.NewAdminUser("test_user")
adminUser := server.NewAdminUser("test_user", "account_id")
newSetupKey := server.GenerateSetupKey(newSetupKeyName, server.SetupKeyReusable, 0, []string{"group-1"},
server.SetupKeyUnlimitedUsage, true)

View File

@@ -95,18 +95,18 @@ func (wm *DefaultScheduler) Schedule(in time.Duration, ID string, job func() (ne
case <-ticker.C:
select {
case <-cancel:
log.Debugf("scheduled job %s was canceled, stop timer", ID)
log.Tracef("scheduled job %s was canceled, stop timer", ID)
ticker.Stop()
return
default:
log.Debugf("time to do a scheduled job %s", ID)
log.Tracef("time to do a scheduled job %s", ID)
}
runIn, reschedule := job()
if !reschedule {
wm.mu.Lock()
defer wm.mu.Unlock()
delete(wm.jobs, ID)
log.Debugf("job %s is not scheduled to run again", ID)
log.Tracef("job %s is not scheduled to run again", ID)
ticker.Stop()
return
}
@@ -115,7 +115,7 @@ func (wm *DefaultScheduler) Schedule(in time.Duration, ID string, job func() (ne
ticker.Reset(runIn)
}
case <-cancel:
log.Debugf("job %s was canceled, stopping timer", ID)
log.Tracef("job %s was canceled, stopping timer", ID)
ticker.Stop()
return
}

View File

@@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"path/filepath"
"reflect"
"runtime"
"strings"
"sync"
@@ -134,72 +135,139 @@ func (s *SqliteStore) AcquireAccountLock(accountID string) (unlock func()) {
return unlock
}
func batchInsert(records interface{}, batchSize int, tx *gorm.DB) error {
// Get the reflect.Value of the records slice
v := reflect.ValueOf(records)
if v.Kind() != reflect.Slice {
return fmt.Errorf("provided input is not a slice")
}
// Insert records in batches
for i := 0; i < v.Len(); i += batchSize {
end := i + batchSize
if end > v.Len() {
end = v.Len()
}
// Use reflect.Slice to get a slice of the records for the current batch
batch := v.Slice(i, end).Interface()
if err := tx.CreateInBatches(batch, end-i).Debug().Error; err != nil {
return err
}
}
return nil
}
func (s *SqliteStore) SaveAccount(account *Account) error {
start := time.Now()
for _, key := range account.SetupKeys {
account.SetupKeysG = append(account.SetupKeysG, *key)
// operate over a fresh copy as we will modify its fields
accCopy := account.Copy()
accCopy.SetupKeysG = make([]SetupKey, 0, len(accCopy.SetupKeys))
for _, key := range accCopy.SetupKeys {
//we need an explicit reference to the account for gorm
key.AccountID = accCopy.Id
accCopy.SetupKeysG = append(accCopy.SetupKeysG, *key)
}
for id, peer := range account.Peers {
accCopy.PeersG = make([]nbpeer.Peer, 0, len(accCopy.Peers))
for id, peer := range accCopy.Peers {
peer.ID = id
account.PeersG = append(account.PeersG, *peer)
//we need an explicit reference to the account for gorm
peer.AccountID = accCopy.Id
accCopy.PeersG = append(accCopy.PeersG, *peer)
}
for id, user := range account.Users {
accCopy.UsersG = make([]User, 0, len(accCopy.Users))
for id, user := range accCopy.Users {
user.Id = id
//we need an explicit reference to the account for gorm
user.AccountID = accCopy.Id
user.PATsG = make([]PersonalAccessToken, 0, len(user.PATs))
for id, pat := range user.PATs {
pat.ID = id
user.PATsG = append(user.PATsG, *pat)
}
account.UsersG = append(account.UsersG, *user)
accCopy.UsersG = append(accCopy.UsersG, *user)
}
for id, group := range account.Groups {
accCopy.GroupsG = make([]nbgroup.Group, 0, len(accCopy.Groups))
for id, group := range accCopy.Groups {
group.ID = id
account.GroupsG = append(account.GroupsG, *group)
//we need an explicit reference to the account for gorm
group.AccountID = accCopy.Id
accCopy.GroupsG = append(accCopy.GroupsG, *group)
}
for id, route := range account.Routes {
accCopy.RoutesG = make([]route.Route, 0, len(accCopy.Routes))
for id, route := range accCopy.Routes {
route.ID = id
account.RoutesG = append(account.RoutesG, *route)
//we need an explicit reference to the account for gorm
route.AccountID = accCopy.Id
accCopy.RoutesG = append(accCopy.RoutesG, *route)
}
for id, ns := range account.NameServerGroups {
accCopy.NameServerGroupsG = make([]nbdns.NameServerGroup, 0, len(accCopy.NameServerGroups))
for id, ns := range accCopy.NameServerGroups {
ns.ID = id
account.NameServerGroupsG = append(account.NameServerGroupsG, *ns)
//we need an explicit reference to the account for gorm
ns.AccountID = accCopy.Id
accCopy.NameServerGroupsG = append(accCopy.NameServerGroupsG, *ns)
}
err := s.db.Transaction(func(tx *gorm.DB) error {
result := tx.Select(clause.Associations).Delete(account.Policies, "account_id = ?", account.Id)
result := tx.Select(clause.Associations).Delete(accCopy.Policies, "account_id = ?", accCopy.Id)
if result.Error != nil {
return result.Error
}
result = tx.Select(clause.Associations).Delete(account.UsersG, "account_id = ?", account.Id)
result = tx.Select(clause.Associations).Delete(accCopy.UsersG, "account_id = ?", accCopy.Id)
if result.Error != nil {
return result.Error
}
result = tx.Select(clause.Associations).Delete(account)
result = tx.Select(clause.Associations).Delete(accCopy)
if result.Error != nil {
return result.Error
}
result = tx.
Session(&gorm.Session{FullSaveAssociations: true}).
Clauses(clause.OnConflict{UpdateAll: true}).Create(account)
Clauses(clause.OnConflict{UpdateAll: true}).
Omit("PeersG", "GroupsG", "UsersG", "SetupKeysG", "RoutesG", "NameServerGroupsG").
Create(accCopy)
if result.Error != nil {
return result.Error
}
return nil
const batchSize = 500
err := batchInsert(accCopy.PeersG, batchSize, tx)
if err != nil {
return err
}
err = batchInsert(accCopy.UsersG, batchSize, tx)
if err != nil {
return err
}
err = batchInsert(accCopy.GroupsG, batchSize, tx)
if err != nil {
return err
}
err = batchInsert(accCopy.RoutesG, batchSize, tx)
if err != nil {
return err
}
err = batchInsert(accCopy.SetupKeysG, batchSize, tx)
if err != nil {
return err
}
return batchInsert(accCopy.NameServerGroupsG, batchSize, tx)
})
took := time.Since(start)
if s.metrics != nil {
s.metrics.StoreMetrics().CountPersistenceDuration(took)
}
log.Debugf("took %d ms to persist an account to the SQLite", took.Milliseconds())
log.Debugf("took %d ms to persist an account %s to the SQLite store", took.Milliseconds(), accCopy.Id)
return err
}
@@ -207,6 +275,19 @@ func (s *SqliteStore) SaveAccount(account *Account) error {
func (s *SqliteStore) DeleteAccount(account *Account) error {
start := time.Now()
account.UsersG = make([]User, 0, len(account.Users))
for id, user := range account.Users {
user.Id = id
//we need an explicit reference to an account as it is missing for some reason
user.AccountID = account.Id
user.PATsG = make([]PersonalAccessToken, 0, len(user.PATs))
for id, pat := range user.PATs {
pat.ID = id
user.PATsG = append(user.PATsG, *pat)
}
account.UsersG = append(account.UsersG, *user)
}
err := s.db.Transaction(func(tx *gorm.DB) error {
result := tx.Select(clause.Associations).Delete(account.Policies, "account_id = ?", account.Id)
if result.Error != nil {

View File

@@ -2,7 +2,12 @@ package server
import (
"fmt"
nbdns "github.com/netbirdio/netbird/dns"
nbgroup "github.com/netbirdio/netbird/management/server/group"
route2 "github.com/netbirdio/netbird/route"
"math/rand"
"net"
"net/netip"
"path/filepath"
"runtime"
"testing"
@@ -29,6 +34,141 @@ func TestSqlite_NewStore(t *testing.T) {
t.Errorf("expected to create a new empty Accounts map when creating a new FileStore")
}
}
func TestSqlite_SaveAccount_Large(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store := newSqliteStore(t)
account := newAccountWithId("account_id", "testuser", "")
groupALL, err := account.GetGroupAll()
if err != nil {
t.Fatal(err)
}
setupKey := GenerateDefaultSetupKey()
account.SetupKeys[setupKey.Key] = setupKey
const numPerAccount = 2000
for n := 0; n < numPerAccount; n++ {
netIP := randomIPv4()
peerID := fmt.Sprintf("%s-peer-%d", account.Id, n)
peer := &nbpeer.Peer{
ID: peerID,
Key: peerID,
SetupKey: "",
IP: netIP,
Name: peerID,
DNSLabel: peerID,
UserID: userID,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()},
SSHEnabled: false,
}
account.Peers[peerID] = peer
group, _ := account.GetGroupAll()
group.Peers = append(group.Peers, peerID)
user := &User{
Id: fmt.Sprintf("%s-user-%d", account.Id, n),
AccountID: account.Id,
}
account.Users[user.Id] = user
route := &route2.Route{
ID: fmt.Sprintf("network-id-%d", n),
Description: "base route",
NetID: fmt.Sprintf("network-id-%d", n),
Network: netip.MustParsePrefix(netIP.String() + "/24"),
NetworkType: route2.IPv4Network,
Metric: 9999,
Masquerade: false,
Enabled: true,
Groups: []string{groupALL.ID},
}
account.Routes[route.ID] = route
group = &nbgroup.Group{
ID: fmt.Sprintf("group-id-%d", n),
AccountID: account.Id,
Name: fmt.Sprintf("group-id-%d", n),
Issued: "api",
Peers: nil,
}
account.Groups[group.ID] = group
nameserver := &nbdns.NameServerGroup{
ID: fmt.Sprintf("nameserver-id-%d", n),
AccountID: account.Id,
Name: fmt.Sprintf("nameserver-id-%d", n),
Description: "",
NameServers: []nbdns.NameServer{{IP: netip.MustParseAddr(netIP.String()), NSType: nbdns.UDPNameServerType}},
Groups: []string{group.ID},
Primary: false,
Domains: nil,
Enabled: false,
SearchDomainsEnabled: false,
}
account.NameServerGroups[nameserver.ID] = nameserver
setupKey := GenerateDefaultSetupKey()
account.SetupKeys[setupKey.Key] = setupKey
}
err = store.SaveAccount(account)
require.NoError(t, err)
if len(store.GetAllAccounts()) != 1 {
t.Errorf("expecting 1 Accounts to be stored after SaveAccount()")
}
a, err := store.GetAccount(account.Id)
if a == nil {
t.Errorf("expecting Account to be stored after SaveAccount(): %v", err)
}
if a != nil && len(a.Policies) != 1 {
t.Errorf("expecting Account to have one policy stored after SaveAccount(), got %d", len(a.Policies))
}
if a != nil && len(a.Policies[0].Rules) != 1 {
t.Errorf("expecting Account to have one policy rule stored after SaveAccount(), got %d", len(a.Policies[0].Rules))
return
}
if a != nil && len(a.Peers) != numPerAccount {
t.Errorf("expecting Account to have %d peers stored after SaveAccount(), got %d",
numPerAccount, len(a.Peers))
return
}
if a != nil && len(a.Users) != numPerAccount+1 {
t.Errorf("expecting Account to have %d users stored after SaveAccount(), got %d",
numPerAccount+1, len(a.Users))
return
}
if a != nil && len(a.Routes) != numPerAccount {
t.Errorf("expecting Account to have %d routes stored after SaveAccount(), got %d",
numPerAccount, len(a.Routes))
return
}
if a != nil && len(a.NameServerGroups) != numPerAccount {
t.Errorf("expecting Account to have %d NameServerGroups stored after SaveAccount(), got %d",
numPerAccount, len(a.NameServerGroups))
return
}
if a != nil && len(a.NameServerGroups) != numPerAccount {
t.Errorf("expecting Account to have %d NameServerGroups stored after SaveAccount(), got %d",
numPerAccount, len(a.NameServerGroups))
return
}
if a != nil && len(a.SetupKeys) != numPerAccount+1 {
t.Errorf("expecting Account to have %d SetupKeys stored after SaveAccount(), got %d",
numPerAccount+1, len(a.SetupKeys))
return
}
}
func TestSqlite_SaveAccount(t *testing.T) {
if runtime.GOOS == "windows" {
@@ -48,6 +188,12 @@ func TestSqlite_SaveAccount(t *testing.T) {
Name: "peer name",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
}
admin := account.Users["testuser"]
admin.PATs = map[string]*PersonalAccessToken{"testtoken": {
ID: "testtoken",
Name: "test token",
HashedToken: "hashed token",
}}
err := store.SaveAccount(account)
require.NoError(t, err)
@@ -110,7 +256,7 @@ func TestSqlite_DeleteAccount(t *testing.T) {
store := newSqliteStore(t)
testUserID := "testuser"
user := NewAdminUser(testUserID)
user := NewAdminUser(testUserID, "account_id")
user.PATs = map[string]*PersonalAccessToken{"testtoken": {
ID: "testtoken",
Name: "test token",
@@ -393,3 +539,12 @@ func newAccount(store Store, id int) error {
return store.SaveAccount(account)
}
func randomIPv4() net.IP {
rand.New(rand.NewSource(time.Now().UnixNano()))
b := make([]byte, 4)
for i := range b {
b[i] = byte(rand.Intn(256))
}
return net.IP(b)
}

View File

@@ -180,9 +180,11 @@ func (u *User) Copy() *User {
}
// NewUser creates a new user
func NewUser(id string, role UserRole, isServiceUser bool, nonDeletable bool, serviceUserName string, autoGroups []string, issued string) *User {
func NewUser(ID string, role UserRole, isServiceUser bool, nonDeletable bool, serviceUserName string, autoGroups []string, issued string,
accountID string) *User {
return &User{
Id: id,
Id: ID,
AccountID: accountID,
Role: role,
IsServiceUser: isServiceUser,
NonDeletable: nonDeletable,
@@ -194,22 +196,26 @@ func NewUser(id string, role UserRole, isServiceUser bool, nonDeletable bool, se
}
// NewRegularUser creates a new user with role UserRoleUser
func NewRegularUser(id string) *User {
return NewUser(id, UserRoleUser, false, false, "", []string{}, UserIssuedAPI)
func NewRegularUser(ID, accountID string) *User {
return NewUser(ID, UserRoleUser, false, false, "", []string{}, UserIssuedAPI,
accountID)
}
// NewAdminUser creates a new user with role UserRoleAdmin
func NewAdminUser(id string) *User {
return NewUser(id, UserRoleAdmin, false, false, "", []string{}, UserIssuedAPI)
func NewAdminUser(ID, accountID string) *User {
return NewUser(ID, UserRoleAdmin, false, false, "", []string{}, UserIssuedAPI,
accountID)
}
// NewOwnerUser creates a new user with role UserRoleOwner
func NewOwnerUser(id string) *User {
return NewUser(id, UserRoleOwner, false, false, "", []string{}, UserIssuedAPI)
func NewOwnerUser(ID, accountID string) *User {
return NewUser(ID, UserRoleOwner, false, false, "", []string{}, UserIssuedAPI,
accountID)
}
// createServiceUser creates a new service user under the given account.
func (am *DefaultAccountManager) createServiceUser(accountID string, initiatorUserID string, role UserRole, serviceUserName string, nonDeletable bool, autoGroups []string) (*UserInfo, error) {
func (am *DefaultAccountManager) createServiceUser(accountID string, initiatorUserID string, role UserRole,
serviceUserName string, nonDeletable bool, autoGroups []string) (*UserInfo, error) {
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
@@ -231,7 +237,7 @@ func (am *DefaultAccountManager) createServiceUser(accountID string, initiatorUs
}
newUserID := uuid.New().String()
newUser := NewUser(newUserID, role, true, nonDeletable, serviceUserName, autoGroups, UserIssuedAPI)
newUser := NewUser(newUserID, role, true, nonDeletable, serviceUserName, autoGroups, UserIssuedAPI, accountID)
log.Debugf("New User: %v", newUser)
account.Users[newUserID] = newUser

View File

@@ -679,8 +679,8 @@ func TestDefaultAccountManager_GetUser(t *testing.T) {
func TestDefaultAccountManager_ListUsers(t *testing.T) {
store := newStore(t)
account := newAccountWithId(mockAccountID, mockUserID, "")
account.Users["normal_user1"] = NewRegularUser("normal_user1")
account.Users["normal_user2"] = NewRegularUser("normal_user2")
account.Users["normal_user1"] = NewRegularUser("normal_user1", mockAccountID)
account.Users["normal_user2"] = NewRegularUser("normal_user2", mockAccountID)
err := store.SaveAccount(account)
if err != nil {
@@ -760,7 +760,7 @@ func TestDefaultAccountManager_ListUsers_DashboardPermissions(t *testing.T) {
t.Run(testCase.name, func(t *testing.T) {
store := newStore(t)
account := newAccountWithId(mockAccountID, mockUserID, "")
account.Users["normal_user1"] = NewUser("normal_user1", testCase.role, false, false, "", []string{}, UserIssuedAPI)
account.Users["normal_user1"] = NewUser("normal_user1", testCase.role, false, false, "", []string{}, UserIssuedAPI, mockAccountID)
account.Settings.RegularUsersViewBlocked = testCase.limitedViewSettings
delete(account.Users, mockUserID)
@@ -844,10 +844,10 @@ func TestDefaultAccountManager_ExternalCache(t *testing.T) {
func TestUser_IsAdmin(t *testing.T) {
user := NewAdminUser(mockUserID)
user := NewAdminUser(mockUserID, mockAccountID)
assert.True(t, user.HasAdminPower())
user = NewRegularUser(mockUserID)
user = NewRegularUser(mockUserID, mockAccountID)
assert.False(t, user.HasAdminPower())
}
@@ -1055,8 +1055,8 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) {
}
// create other users
account.Users[regularUserID] = NewRegularUser(regularUserID)
account.Users[adminUserID] = NewAdminUser(adminUserID)
account.Users[regularUserID] = NewRegularUser(regularUserID, account.Id)
account.Users[adminUserID] = NewAdminUser(adminUserID, account.Id)
account.Users[serviceUserID] = &User{IsServiceUser: true, Id: serviceUserID, Role: UserRoleAdmin, ServiceUserName: "service"}
err = manager.Store.SaveAccount(account)
if err != nil {

View File

@@ -56,7 +56,7 @@ func UnMarshalCredential(msg *proto.Message) (*Credential, error) {
// MarshalCredential marshal a Credential instance and returns a Message object
func MarshalCredential(myKey wgtypes.Key, myPort int, remoteKey wgtypes.Key, credential *Credential, t proto.Body_Type,
rosenpassPubKey []byte, rosenpassAddr, relayedAddress, serverRefIP string) (*proto.Message, error) {
rosenpassPubKey []byte, rosenpassAddr string) (*proto.Message, error) {
return &proto.Message{
Key: myKey.PublicKey().String(),
RemoteKey: remoteKey.String(),
@@ -69,10 +69,6 @@ func MarshalCredential(myKey wgtypes.Key, myPort int, remoteKey wgtypes.Key, cre
RosenpassPubKey: rosenpassPubKey,
RosenpassServerAddr: rosenpassAddr,
},
Relay: &proto.Relay{
RelayedAddress: relayedAddress,
SrvRefAddress: serverRefIP,
},
},
}, nil
}

View File

@@ -215,21 +215,16 @@ type Body struct {
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Type Body_Type `protobuf:"varint,1,opt,name=type,proto3,enum=signalexchange.Body_Type" json:"type,omitempty"`
// these will be set in OFFER, ANSWER, CANDIDATE only
Payload string `protobuf:"bytes,2,opt,name=payload,proto3" json:"payload,omitempty"`
Type Body_Type `protobuf:"varint,1,opt,name=type,proto3,enum=signalexchange.Body_Type" json:"type,omitempty"`
Payload string `protobuf:"bytes,2,opt,name=payload,proto3" json:"payload,omitempty"`
// wgListenPort is an actual WireGuard listen port
// these will be set in OFFER, ANSWER, CANDIDATE only
WgListenPort uint32 `protobuf:"varint,3,opt,name=wgListenPort,proto3" json:"wgListenPort,omitempty"`
// these will be set in OFFER, ANSWER, CANDIDATE only
WgListenPort uint32 `protobuf:"varint,3,opt,name=wgListenPort,proto3" json:"wgListenPort,omitempty"`
NetBirdVersion string `protobuf:"bytes,4,opt,name=netBirdVersion,proto3" json:"netBirdVersion,omitempty"`
Mode *Mode `protobuf:"bytes,5,opt,name=mode,proto3" json:"mode,omitempty"`
// featuresSupported list of supported features by the client of this protocol
FeaturesSupported []uint32 `protobuf:"varint,6,rep,packed,name=featuresSupported,proto3" json:"featuresSupported,omitempty"`
// RosenpassConfig is a Rosenpass config of the remote peer our peer tries to connect to
// is this optional or mandatory?
RosenpassConfig *RosenpassConfig `protobuf:"bytes,7,opt,name=rosenpassConfig,proto3" json:"rosenpassConfig,omitempty"`
Relay *Relay `protobuf:"bytes,8,opt,name=relay,proto3" json:"relay,omitempty"`
}
func (x *Body) Reset() {
@@ -313,18 +308,13 @@ func (x *Body) GetRosenpassConfig() *RosenpassConfig {
return nil
}
func (x *Body) GetRelay() *Relay {
if x != nil {
return x.Relay
}
return nil
}
// Mode indicates a connection mode
type Mode struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Direct *bool `protobuf:"varint,1,opt,name=direct,proto3,oneof" json:"direct,omitempty"`
}
func (x *Mode) Reset() {
@@ -359,59 +349,11 @@ func (*Mode) Descriptor() ([]byte, []int) {
return file_signalexchange_proto_rawDescGZIP(), []int{3}
}
type Relay struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
RelayedAddress string `protobuf:"bytes,1,opt,name=relayedAddress,proto3" json:"relayedAddress,omitempty"`
SrvRefAddress string `protobuf:"bytes,2,opt,name=srvRefAddress,proto3" json:"srvRefAddress,omitempty"`
}
func (x *Relay) Reset() {
*x = Relay{}
if protoimpl.UnsafeEnabled {
mi := &file_signalexchange_proto_msgTypes[4]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
func (x *Mode) GetDirect() bool {
if x != nil && x.Direct != nil {
return *x.Direct
}
}
func (x *Relay) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*Relay) ProtoMessage() {}
func (x *Relay) ProtoReflect() protoreflect.Message {
mi := &file_signalexchange_proto_msgTypes[4]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use Relay.ProtoReflect.Descriptor instead.
func (*Relay) Descriptor() ([]byte, []int) {
return file_signalexchange_proto_rawDescGZIP(), []int{4}
}
func (x *Relay) GetRelayedAddress() string {
if x != nil {
return x.RelayedAddress
}
return ""
}
func (x *Relay) GetSrvRefAddress() string {
if x != nil {
return x.SrvRefAddress
}
return ""
return false
}
type RosenpassConfig struct {
@@ -427,7 +369,7 @@ type RosenpassConfig struct {
func (x *RosenpassConfig) Reset() {
*x = RosenpassConfig{}
if protoimpl.UnsafeEnabled {
mi := &file_signalexchange_proto_msgTypes[5]
mi := &file_signalexchange_proto_msgTypes[4]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -440,7 +382,7 @@ func (x *RosenpassConfig) String() string {
func (*RosenpassConfig) ProtoMessage() {}
func (x *RosenpassConfig) ProtoReflect() protoreflect.Message {
mi := &file_signalexchange_proto_msgTypes[5]
mi := &file_signalexchange_proto_msgTypes[4]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -453,7 +395,7 @@ func (x *RosenpassConfig) ProtoReflect() protoreflect.Message {
// Deprecated: Use RosenpassConfig.ProtoReflect.Descriptor instead.
func (*RosenpassConfig) Descriptor() ([]byte, []int) {
return file_signalexchange_proto_rawDescGZIP(), []int{5}
return file_signalexchange_proto_rawDescGZIP(), []int{4}
}
func (x *RosenpassConfig) GetRosenpassPubKey() []byte {
@@ -489,7 +431,7 @@ var file_signalexchange_proto_rawDesc = []byte{
0x52, 0x09, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x4b, 0x65, 0x79, 0x12, 0x28, 0x0a, 0x04, 0x62,
0x6f, 0x64, 0x79, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x73, 0x69, 0x67, 0x6e,
0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x42, 0x6f, 0x64, 0x79, 0x52,
0x04, 0x62, 0x6f, 0x64, 0x79, 0x22, 0xa3, 0x03, 0x0a, 0x04, 0x42, 0x6f, 0x64, 0x79, 0x12, 0x2d,
0x04, 0x62, 0x6f, 0x64, 0x79, 0x22, 0xf6, 0x02, 0x0a, 0x04, 0x42, 0x6f, 0x64, 0x79, 0x12, 0x2d,
0x0a, 0x04, 0x74, 0x79, 0x70, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x19, 0x2e, 0x73,
0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x42, 0x6f,
0x64, 0x79, 0x2e, 0x54, 0x79, 0x70, 0x65, 0x52, 0x04, 0x74, 0x79, 0x70, 0x65, 0x12, 0x18, 0x0a,
@@ -509,39 +451,33 @@ var file_signalexchange_proto_rawDesc = []byte{
0x20, 0x01, 0x28, 0x0b, 0x32, 0x1f, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63,
0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x52, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x43,
0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73,
0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x2b, 0x0a, 0x05, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x18,
0x08, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78,
0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x52, 0x65, 0x6c, 0x61, 0x79, 0x52, 0x05, 0x72, 0x65,
0x6c, 0x61, 0x79, 0x22, 0x36, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x09, 0x0a, 0x05, 0x4f,
0x46, 0x46, 0x45, 0x52, 0x10, 0x00, 0x12, 0x0a, 0x0a, 0x06, 0x41, 0x4e, 0x53, 0x57, 0x45, 0x52,
0x10, 0x01, 0x12, 0x0d, 0x0a, 0x09, 0x43, 0x41, 0x4e, 0x44, 0x49, 0x44, 0x41, 0x54, 0x45, 0x10,
0x02, 0x12, 0x08, 0x0a, 0x04, 0x4d, 0x4f, 0x44, 0x45, 0x10, 0x04, 0x22, 0x06, 0x0a, 0x04, 0x4d,
0x6f, 0x64, 0x65, 0x22, 0x55, 0x0a, 0x05, 0x52, 0x65, 0x6c, 0x61, 0x79, 0x12, 0x26, 0x0a, 0x0e,
0x72, 0x65, 0x6c, 0x61, 0x79, 0x65, 0x64, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x01,
0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x65, 0x64, 0x41, 0x64, 0x64,
0x72, 0x65, 0x73, 0x73, 0x12, 0x24, 0x0a, 0x0d, 0x73, 0x72, 0x76, 0x52, 0x65, 0x66, 0x41, 0x64,
0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x73, 0x72, 0x76,
0x52, 0x65, 0x66, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x22, 0x6d, 0x0a, 0x0f, 0x52, 0x6f,
0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x28, 0x0a,
0x0f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79,
0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73,
0x73, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x30, 0x0a, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e,
0x70, 0x61, 0x73, 0x73, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64, 0x64, 0x72, 0x18, 0x02,
0x20, 0x01, 0x28, 0x09, 0x52, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x53,
0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64, 0x64, 0x72, 0x32, 0xb9, 0x01, 0x0a, 0x0e, 0x53, 0x69,
0x67, 0x6e, 0x61, 0x6c, 0x45, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x12, 0x4c, 0x0a, 0x04,
0x53, 0x65, 0x6e, 0x64, 0x12, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63,
0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d,
0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65,
0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65,
0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x59, 0x0a, 0x0d, 0x43, 0x6f,
0x6e, 0x6e, 0x65, 0x63, 0x74, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x12, 0x20, 0x2e, 0x73, 0x69,
0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63,
0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x20, 0x2e,
0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45,
0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22,
0x00, 0x28, 0x01, 0x30, 0x01, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62,
0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0x36, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x09,
0x0a, 0x05, 0x4f, 0x46, 0x46, 0x45, 0x52, 0x10, 0x00, 0x12, 0x0a, 0x0a, 0x06, 0x41, 0x4e, 0x53,
0x57, 0x45, 0x52, 0x10, 0x01, 0x12, 0x0d, 0x0a, 0x09, 0x43, 0x41, 0x4e, 0x44, 0x49, 0x44, 0x41,
0x54, 0x45, 0x10, 0x02, 0x12, 0x08, 0x0a, 0x04, 0x4d, 0x4f, 0x44, 0x45, 0x10, 0x04, 0x22, 0x2e,
0x0a, 0x04, 0x4d, 0x6f, 0x64, 0x65, 0x12, 0x1b, 0x0a, 0x06, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74,
0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x48, 0x00, 0x52, 0x06, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74,
0x88, 0x01, 0x01, 0x42, 0x09, 0x0a, 0x07, 0x5f, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x22, 0x6d,
0x0a, 0x0f, 0x52, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x43, 0x6f, 0x6e, 0x66, 0x69,
0x67, 0x12, 0x28, 0x0a, 0x0f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x75,
0x62, 0x4b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0f, 0x72, 0x6f, 0x73, 0x65,
0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x30, 0x0a, 0x13, 0x72,
0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64,
0x64, 0x72, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70,
0x61, 0x73, 0x73, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64, 0x64, 0x72, 0x32, 0xb9, 0x01,
0x0a, 0x0e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x45, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65,
0x12, 0x4c, 0x0a, 0x04, 0x53, 0x65, 0x6e, 0x64, 0x12, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61,
0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70,
0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x20, 0x2e, 0x73, 0x69, 0x67,
0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72,
0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x59,
0x0a, 0x0d, 0x43, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x12,
0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65,
0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67,
0x65, 0x1a, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e,
0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73,
0x61, 0x67, 0x65, 0x22, 0x00, 0x28, 0x01, 0x30, 0x01, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72,
0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}
var (
@@ -557,31 +493,29 @@ func file_signalexchange_proto_rawDescGZIP() []byte {
}
var file_signalexchange_proto_enumTypes = make([]protoimpl.EnumInfo, 1)
var file_signalexchange_proto_msgTypes = make([]protoimpl.MessageInfo, 6)
var file_signalexchange_proto_msgTypes = make([]protoimpl.MessageInfo, 5)
var file_signalexchange_proto_goTypes = []interface{}{
(Body_Type)(0), // 0: signalexchange.Body.Type
(*EncryptedMessage)(nil), // 1: signalexchange.EncryptedMessage
(*Message)(nil), // 2: signalexchange.Message
(*Body)(nil), // 3: signalexchange.Body
(*Mode)(nil), // 4: signalexchange.Mode
(*Relay)(nil), // 5: signalexchange.Relay
(*RosenpassConfig)(nil), // 6: signalexchange.RosenpassConfig
(*RosenpassConfig)(nil), // 5: signalexchange.RosenpassConfig
}
var file_signalexchange_proto_depIdxs = []int32{
3, // 0: signalexchange.Message.body:type_name -> signalexchange.Body
0, // 1: signalexchange.Body.type:type_name -> signalexchange.Body.Type
4, // 2: signalexchange.Body.mode:type_name -> signalexchange.Mode
6, // 3: signalexchange.Body.rosenpassConfig:type_name -> signalexchange.RosenpassConfig
5, // 4: signalexchange.Body.relay:type_name -> signalexchange.Relay
1, // 5: signalexchange.SignalExchange.Send:input_type -> signalexchange.EncryptedMessage
1, // 6: signalexchange.SignalExchange.ConnectStream:input_type -> signalexchange.EncryptedMessage
1, // 7: signalexchange.SignalExchange.Send:output_type -> signalexchange.EncryptedMessage
1, // 8: signalexchange.SignalExchange.ConnectStream:output_type -> signalexchange.EncryptedMessage
7, // [7:9] is the sub-list for method output_type
5, // [5:7] is the sub-list for method input_type
5, // [5:5] is the sub-list for extension type_name
5, // [5:5] is the sub-list for extension extendee
0, // [0:5] is the sub-list for field type_name
5, // 3: signalexchange.Body.rosenpassConfig:type_name -> signalexchange.RosenpassConfig
1, // 4: signalexchange.SignalExchange.Send:input_type -> signalexchange.EncryptedMessage
1, // 5: signalexchange.SignalExchange.ConnectStream:input_type -> signalexchange.EncryptedMessage
1, // 6: signalexchange.SignalExchange.Send:output_type -> signalexchange.EncryptedMessage
1, // 7: signalexchange.SignalExchange.ConnectStream:output_type -> signalexchange.EncryptedMessage
6, // [6:8] is the sub-list for method output_type
4, // [4:6] is the sub-list for method input_type
4, // [4:4] is the sub-list for extension type_name
4, // [4:4] is the sub-list for extension extendee
0, // [0:4] is the sub-list for field type_name
}
func init() { file_signalexchange_proto_init() }
@@ -639,18 +573,6 @@ func file_signalexchange_proto_init() {
}
}
file_signalexchange_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*Relay); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_signalexchange_proto_msgTypes[5].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*RosenpassConfig); i {
case 0:
return &v.state
@@ -663,13 +585,14 @@ func file_signalexchange_proto_init() {
}
}
}
file_signalexchange_proto_msgTypes[3].OneofWrappers = []interface{}{}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_signalexchange_proto_rawDesc,
NumEnums: 1,
NumMessages: 6,
NumMessages: 5,
NumExtensions: 0,
NumServices: 1,
},

View File

@@ -49,33 +49,22 @@ message Body {
MODE = 4;
}
Type type = 1;
// these will be set in OFFER, ANSWER, CANDIDATE only
string payload = 2;
// wgListenPort is an actual WireGuard listen port
// these will be set in OFFER, ANSWER, CANDIDATE only
uint32 wgListenPort = 3;
// these will be set in OFFER, ANSWER, CANDIDATE only
string netBirdVersion = 4;
Mode mode = 5;
// featuresSupported list of supported features by the client of this protocol
repeated uint32 featuresSupported = 6;
// RosenpassConfig is a Rosenpass config of the remote peer our peer tries to connect to
// is this optional or mandatory?
RosenpassConfig rosenpassConfig = 7;
Relay relay = 8;
}
// Mode indicates a connection mode
message Mode {
}
message Relay {
string relayedAddress = 1;
string srvRefAddress = 2;
optional bool direct = 1;
}
message RosenpassConfig {

View File

@@ -49,6 +49,10 @@ func RemoveDialerHooks() {
// DialContext wraps the net.Dialer's DialContext method to use the custom connection
func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
if CustomRoutingDisabled() {
return d.Dialer.DialContext(ctx, network, address)
}
var resolver *net.Resolver
if d.Resolver != nil {
resolver = d.Resolver
@@ -123,6 +127,10 @@ func callDialerHooks(ctx context.Context, connID ConnectionID, address string, r
}
func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) {
if CustomRoutingDisabled() {
return net.DialUDP(network, laddr, raddr)
}
dialer := NewDialer()
dialer.LocalAddr = laddr
@@ -143,6 +151,10 @@ func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) {
}
func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) {
if CustomRoutingDisabled() {
return net.DialTCP(network, laddr, raddr)
}
dialer := NewDialer()
dialer.LocalAddr = laddr

View File

@@ -8,6 +8,7 @@ import (
"net"
"sync"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
)
@@ -52,6 +53,10 @@ func RemoveListenerHooks() {
// ListenPacket listens on the network address and returns a PacketConn
// which includes support for write hooks.
func (l *ListenerConfig) ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) {
if CustomRoutingDisabled() {
return l.ListenConfig.ListenPacket(ctx, network, address)
}
pc, err := l.ListenConfig.ListenPacket(ctx, network, address)
if err != nil {
return nil, fmt.Errorf("listen packet: %w", err)
@@ -144,7 +149,11 @@ func closeConn(id ConnectionID, conn net.PacketConn) error {
// ListenUDP listens on the network address and returns a transport.UDPConn
// which includes support for write and close hooks.
func ListenUDP(network string, laddr *net.UDPAddr) (*UDPConn, error) {
func ListenUDP(network string, laddr *net.UDPAddr) (transport.UDPConn, error) {
if CustomRoutingDisabled() {
return net.ListenUDP(network, laddr)
}
conn, err := NewListener().ListenPacket(context.Background(), network, laddr.String())
if err != nil {
return nil, fmt.Errorf("listen UDP: %w", err)

View File

@@ -1,10 +1,16 @@
package net
import "github.com/google/uuid"
import (
"os"
"github.com/google/uuid"
)
const (
// NetbirdFwmark is the fwmark value used by Netbird via wireguard
NetbirdFwmark = 0x1BD00
envDisableCustomRouting = "NB_DISABLE_CUSTOM_ROUTING"
)
// ConnectionID provides a globally unique identifier for network connections.
@@ -15,3 +21,7 @@ type ConnectionID string
func GenerateConnID() ConnectionID {
return ConnectionID(uuid.NewString())
}
func CustomRoutingDisabled() bool {
return os.Getenv(envDisableCustomRouting) == "true"
}

View File

@@ -21,7 +21,7 @@ func SetRawSocketMark(conn syscall.RawConn) error {
var setErr error
err := conn.Control(func(fd uintptr) {
setErr = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_MARK, NetbirdFwmark)
setErr = SetSocketOpt(int(fd))
})
if err != nil {
return fmt.Errorf("control: %w", err)
@@ -33,3 +33,11 @@ func SetRawSocketMark(conn syscall.RawConn) error {
return nil
}
func SetSocketOpt(fd int) error {
if CustomRoutingDisabled() {
return nil
}
return syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_MARK, NetbirdFwmark)
}