mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-02 23:53:55 -04:00
Compare commits
29 Commits
feature/pe
...
feature/op
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
515ce9e3af | ||
|
|
89383b7f01 | ||
|
|
db34162733 | ||
|
|
bd761e2177 | ||
|
|
4e1b95a4c6 | ||
|
|
05993af7bf | ||
|
|
9d1cb00570 | ||
|
|
543731df45 | ||
|
|
e6628ec231 | ||
|
|
41d4dd2aff | ||
|
|
30bed57711 | ||
|
|
6960b68322 | ||
|
|
3b3aa18148 | ||
|
|
93045f3e3a | ||
|
|
fd3c1dea8e | ||
|
|
48aff7a26e | ||
|
|
83dfe8e3a3 | ||
|
|
38e10af2d9 | ||
|
|
99854a126a | ||
|
|
a75f982fcd | ||
|
|
e7a6483912 | ||
|
|
30ede299b8 | ||
|
|
e3b76448f3 | ||
|
|
e0de86d6c9 | ||
|
|
5204d07811 | ||
|
|
5ea24ba56e | ||
|
|
d30cf8706a | ||
|
|
15a2feb723 | ||
|
|
91b2f9fc51 |
4
.github/workflows/golangci-lint.yml
vendored
4
.github/workflows/golangci-lint.yml
vendored
@@ -33,6 +33,10 @@ jobs:
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
- name: Check for duplicate constants
|
||||
if: matrix.os == 'ubuntu-latest'
|
||||
run: |
|
||||
! awk '/const \(/,/)/{print $0}' management/server/activity/codes.go | grep -o '= [0-9]*' | sort | uniq -d | grep .
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
|
||||
@@ -143,7 +143,7 @@ func (m *Manager) AllowNetbird() error {
|
||||
}
|
||||
|
||||
if rule := m.detectAllowNetbirdRule(rules); rule != nil {
|
||||
log.Debugf("allow netbird rule already exists: %#v", rule)
|
||||
log.Debugf("allow netbird rule already exists: %v", rule)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -138,7 +138,6 @@ type Engine struct {
|
||||
signalProbe *Probe
|
||||
relayProbe *Probe
|
||||
wgProbe *Probe
|
||||
turnRelay *relay.PermanentTurn
|
||||
}
|
||||
|
||||
// Peer is an instance of the Connection Peer
|
||||
@@ -200,7 +199,7 @@ func NewEngineWithProbes(
|
||||
networkSerial: 0,
|
||||
sshServerFunc: nbssh.DefaultSSHServer,
|
||||
statusRecorder: statusRecorder,
|
||||
wgProxyFactory: &wgproxy.Factory{},
|
||||
wgProxyFactory: wgproxy.NewFactory(config.WgPort),
|
||||
mgmProbe: mgmProbe,
|
||||
signalProbe: signalProbe,
|
||||
relayProbe: relayProbe,
|
||||
@@ -453,19 +452,10 @@ func SignalOfferAnswer(offerAnswer peer.OfferAnswer, myKey wgtypes.Key, remoteKe
|
||||
t = sProto.Body_OFFER
|
||||
}
|
||||
|
||||
msg, err := signal.MarshalCredential(
|
||||
myKey,
|
||||
offerAnswer.WgListenPort,
|
||||
remoteKey, &signal.Credential{
|
||||
UFrag: offerAnswer.IceCredentials.UFrag,
|
||||
Pwd: offerAnswer.IceCredentials.Pwd,
|
||||
},
|
||||
t,
|
||||
offerAnswer.RosenpassPubKey,
|
||||
offerAnswer.RosenpassAddr,
|
||||
offerAnswer.RelayedAddr.String(),
|
||||
offerAnswer.RemoteAddr.String(),
|
||||
)
|
||||
msg, err := signal.MarshalCredential(myKey, offerAnswer.WgListenPort, remoteKey, &signal.Credential{
|
||||
UFrag: offerAnswer.IceCredentials.UFrag,
|
||||
Pwd: offerAnswer.IceCredentials.Pwd,
|
||||
}, t, offerAnswer.RosenpassPubKey, offerAnswer.RosenpassAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -493,14 +483,6 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
return err
|
||||
}
|
||||
|
||||
turnRelay := relay.NewPermanentTurn(e.STUNs[0], e.TURNs[0])
|
||||
err = turnRelay.Open()
|
||||
if err != nil {
|
||||
return fmt.Errorf("faile to open turn relay: %w", err)
|
||||
}
|
||||
e.turnRelay = turnRelay
|
||||
e.wgInterface.SetRelayConn(e.turnRelay.RelayConn())
|
||||
|
||||
// todo update signal
|
||||
}
|
||||
|
||||
@@ -639,7 +621,6 @@ func (e *Engine) updateTURNs(turns []*mgmProto.ProtectedHostConfig) error {
|
||||
var newTURNs []*stun.URI
|
||||
log.Debugf("got TURNs update from Management Service, updating")
|
||||
for _, turn := range turns {
|
||||
log.Debugf("-----updated Turn %v, %s, %s", turn.HostConfig.Uri, turn.User, turn.Password)
|
||||
url, err := stun.ParseURI(turn.HostConfig.Uri)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -649,6 +630,7 @@ func (e *Engine) updateTURNs(turns []*mgmProto.ProtectedHostConfig) error {
|
||||
newTURNs = append(newTURNs, url)
|
||||
}
|
||||
e.TURNs = newTURNs
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -952,7 +934,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, e
|
||||
RosenpassAddr: e.getRosenpassAddr(),
|
||||
}
|
||||
|
||||
peerConn, err := peer.NewConn(config, e.statusRecorder, e.wgProxyFactory, e.mobileDep.TunAdapter, e.mobileDep.IFaceDiscover, e.turnRelay)
|
||||
peerConn, err := peer.NewConn(config, e.statusRecorder, e.wgProxyFactory, e.mobileDep.TunAdapter, e.mobileDep.IFaceDiscover)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1018,17 +1000,6 @@ func (e *Engine) receiveSignalEvents() {
|
||||
rosenpassPubKey = msg.GetBody().GetRosenpassConfig().GetRosenpassPubKey()
|
||||
rosenpassAddr = msg.GetBody().GetRosenpassConfig().GetRosenpassServerAddr()
|
||||
}
|
||||
|
||||
relayedAddr, err := net.ResolveUDPAddr("udp", msg.GetBody().GetRelay().GetRelayedAddress())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
remoteAddr, err := net.ResolveUDPAddr("udp", msg.GetBody().GetRelay().GetSrvRefAddress())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
conn.OnRemoteOffer(peer.OfferAnswer{
|
||||
IceCredentials: peer.IceCredentials{
|
||||
UFrag: remoteCred.UFrag,
|
||||
@@ -1038,8 +1009,6 @@ func (e *Engine) receiveSignalEvents() {
|
||||
Version: msg.GetBody().GetNetBirdVersion(),
|
||||
RosenpassPubKey: rosenpassPubKey,
|
||||
RosenpassAddr: rosenpassAddr,
|
||||
RelayedAddr: relayedAddr,
|
||||
RemoteAddr: remoteAddr,
|
||||
})
|
||||
case sProto.Body_ANSWER:
|
||||
remoteCred, err := signal.UnMarshalCredential(msg)
|
||||
@@ -1055,17 +1024,6 @@ func (e *Engine) receiveSignalEvents() {
|
||||
rosenpassPubKey = msg.GetBody().GetRosenpassConfig().GetRosenpassPubKey()
|
||||
rosenpassAddr = msg.GetBody().GetRosenpassConfig().GetRosenpassServerAddr()
|
||||
}
|
||||
|
||||
relayedAddr, err := net.ResolveUDPAddr("udp", msg.GetBody().GetRelay().GetRelayedAddress())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
remoteAddr, err := net.ResolveUDPAddr("udp", msg.GetBody().GetRelay().GetSrvRefAddress())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
conn.OnRemoteAnswer(peer.OfferAnswer{
|
||||
IceCredentials: peer.IceCredentials{
|
||||
UFrag: remoteCred.UFrag,
|
||||
@@ -1075,8 +1033,6 @@ func (e *Engine) receiveSignalEvents() {
|
||||
Version: msg.GetBody().GetNetBirdVersion(),
|
||||
RosenpassPubKey: rosenpassPubKey,
|
||||
RosenpassAddr: rosenpassAddr,
|
||||
RelayedAddr: relayedAddr,
|
||||
RemoteAddr: remoteAddr,
|
||||
})
|
||||
case sProto.Body_CANDIDATE:
|
||||
candidate, err := ice.UnmarshalCandidate(msg.GetBody().Payload)
|
||||
@@ -1087,6 +1043,7 @@ func (e *Engine) receiveSignalEvents() {
|
||||
conn.OnRemoteCandidate(candidate)
|
||||
case sProto.Body_MODE:
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
@@ -1158,8 +1115,6 @@ func (e *Engine) close() {
|
||||
log.Errorf("failed closing ebpf proxy: %s", err)
|
||||
}
|
||||
|
||||
e.turnRelay.Close()
|
||||
|
||||
// stop/restore DNS first so dbus and friends don't complain because of a missing interface
|
||||
if e.dnsServer != nil {
|
||||
e.dnsServer.Stop()
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -13,7 +14,6 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/relay"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/client/internal/wgproxy"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
@@ -93,10 +93,6 @@ type OfferAnswer struct {
|
||||
// RosenpassAddr is the Rosenpass server address (IP:port) of the remote peer when receiving this message
|
||||
// This value is the local Rosenpass server address when sending the message
|
||||
RosenpassAddr string
|
||||
|
||||
// Turn Relay
|
||||
RelayedAddr net.Addr
|
||||
RemoteAddr net.Addr
|
||||
}
|
||||
|
||||
// IceCredentials ICE protocol credentials struct
|
||||
@@ -145,11 +141,11 @@ type Conn struct {
|
||||
sentExtraSrflx bool
|
||||
|
||||
remoteEndpoint *net.UDPAddr
|
||||
remoteConn *ice.Conn
|
||||
|
||||
connID nbnet.ConnectionID
|
||||
beforeAddPeerHooks []BeforeAddPeerHookFunc
|
||||
afterRemovePeerHooks []AfterRemovePeerHookFunc
|
||||
turnRelay *relay.PermanentTurn
|
||||
}
|
||||
|
||||
// meta holds meta information about a connection
|
||||
@@ -180,7 +176,7 @@ func (conn *Conn) UpdateStunTurn(turnStun []*stun.URI) {
|
||||
|
||||
// NewConn creates a new not opened Conn to the remote peer.
|
||||
// To establish a connection run Conn.Open
|
||||
func NewConn(config ConnConfig, statusRecorder *Status, wgProxyFactory *wgproxy.Factory, adapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover, turnRelay *relay.PermanentTurn) (*Conn, error) {
|
||||
func NewConn(config ConnConfig, statusRecorder *Status, wgProxyFactory *wgproxy.Factory, adapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover) (*Conn, error) {
|
||||
return &Conn{
|
||||
config: config,
|
||||
mu: sync.Mutex{},
|
||||
@@ -193,7 +189,6 @@ func NewConn(config ConnConfig, statusRecorder *Status, wgProxyFactory *wgproxy.
|
||||
wgProxyFactory: wgProxyFactory,
|
||||
adapter: adapter,
|
||||
iFaceDiscover: iFaceDiscover,
|
||||
turnRelay: turnRelay,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -217,7 +212,7 @@ func (conn *Conn) reCreateAgent() error {
|
||||
MulticastDNSMode: ice.MulticastDNSModeDisabled,
|
||||
NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6},
|
||||
Urls: conn.config.StunTurn,
|
||||
CandidateTypes: []ice.CandidateType{},
|
||||
CandidateTypes: conn.candidateTypes(),
|
||||
FailedTimeout: &failedTimeout,
|
||||
InterfaceFilter: stdnet.InterfaceFilter(conn.config.InterfaceBlackList),
|
||||
UDPMux: conn.config.UDPMux,
|
||||
@@ -267,6 +262,17 @@ func (conn *Conn) reCreateAgent() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (conn *Conn) candidateTypes() []ice.CandidateType {
|
||||
if hasICEForceRelayConn() {
|
||||
return []ice.CandidateType{ice.CandidateTypeRelay}
|
||||
}
|
||||
// TODO: remove this once we have refactored userspace proxy into the bind package
|
||||
if runtime.GOOS == "ios" {
|
||||
return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive}
|
||||
}
|
||||
return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive, ice.CandidateTypeRelay}
|
||||
}
|
||||
|
||||
// Open opens connection to the remote peer starting ICE candidate gathering process.
|
||||
// Blocks until connection has been closed or connection timeout.
|
||||
// ConnStatus will be set accordingly
|
||||
@@ -345,53 +351,42 @@ func (conn *Conn) Open() error {
|
||||
log.Warnf("error while updating the state of peer %s,err: %v", conn.config.Key, err)
|
||||
}
|
||||
|
||||
isControlling := conn.config.LocalKey < conn.config.Key
|
||||
if isControlling {
|
||||
log.Debugf("---- use this peer's tunr connection")
|
||||
err = conn.turnRelay.PunchHole(remoteOfferAnswer.RemoteAddr)
|
||||
if err != nil {
|
||||
log.Errorf("failed to punch hole: %v", err)
|
||||
}
|
||||
addr, ok := remoteOfferAnswer.RemoteAddr.(*net.UDPAddr)
|
||||
if !ok {
|
||||
return fmt.Errorf("failed to cast addr to udp addr")
|
||||
}
|
||||
addr.Port = remoteOfferAnswer.WgListenPort
|
||||
err := conn.config.WgConfig.WgInterface.UpdatePeer(conn.config.WgConfig.RemoteKey, conn.config.WgConfig.AllowedIps, defaultWgKeepAlive, addr, conn.config.WgConfig.PreSharedKey)
|
||||
if err != nil {
|
||||
if conn.wgProxy != nil {
|
||||
_ = conn.wgProxy.CloseConn()
|
||||
}
|
||||
// todo close
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
log.Debugf("---- use remote peer tunr connection")
|
||||
addr, ok := remoteOfferAnswer.RelayedAddr.(*net.UDPAddr)
|
||||
if !ok {
|
||||
return fmt.Errorf("failed to cast addr to udp addr")
|
||||
}
|
||||
err := conn.config.WgConfig.WgInterface.UpdatePeer(conn.config.WgConfig.RemoteKey, conn.config.WgConfig.AllowedIps, defaultWgKeepAlive, addr, conn.config.WgConfig.PreSharedKey)
|
||||
if err != nil {
|
||||
if conn.wgProxy != nil {
|
||||
_ = conn.wgProxy.CloseConn()
|
||||
}
|
||||
// todo close
|
||||
return err
|
||||
}
|
||||
|
||||
// the ice connection has been established successfully so we are ready to start the proxy
|
||||
/*
|
||||
remoteAddr, err := conn.configureConnection(remoteOfferAnswer.RelayedAddr, remoteWgPort, remoteOfferAnswer.RosenpassPubKey,
|
||||
remoteOfferAnswer.RosenpassAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*/
|
||||
log.Infof("connected to peer %s, endpoint address: %s", conn.config.Key, addr.String())
|
||||
err = conn.agent.GatherCandidates()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// will block until connection succeeded
|
||||
// but it won't release if ICE Agent went into Disconnected or Failed state,
|
||||
// so we have to cancel it with the provided context once agent detected a broken connection
|
||||
isControlling := conn.config.LocalKey > conn.config.Key
|
||||
var remoteConn *ice.Conn
|
||||
if isControlling {
|
||||
remoteConn, err = conn.agent.Dial(conn.ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
|
||||
} else {
|
||||
remoteConn, err = conn.agent.Accept(conn.ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// dynamically set remote WireGuard port is other side specified a different one from the default one
|
||||
remoteWgPort := iface.DefaultWgPort
|
||||
if remoteOfferAnswer.WgListenPort != 0 {
|
||||
remoteWgPort = remoteOfferAnswer.WgListenPort
|
||||
}
|
||||
|
||||
conn.remoteConn = remoteConn
|
||||
|
||||
// the ice connection has been established successfully so we are ready to start the proxy
|
||||
remoteAddr, err := conn.configureConnection(remoteConn, remoteWgPort, remoteOfferAnswer.RosenpassPubKey,
|
||||
remoteOfferAnswer.RosenpassAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Infof("connected to peer %s, endpoint address: %s", conn.config.Key, remoteAddr.String())
|
||||
|
||||
// wait until connection disconnected or has been closed externally (upper layer, e.g. engine)
|
||||
select {
|
||||
case <-conn.closeCh:
|
||||
@@ -420,8 +415,25 @@ func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, rem
|
||||
conn.mu.Lock()
|
||||
defer conn.mu.Unlock()
|
||||
|
||||
pair, err := conn.agent.GetSelectedCandidatePair()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var endpoint net.Addr
|
||||
endpoint = remoteConn.RemoteAddr()
|
||||
if isRelayCandidate(pair.Local) {
|
||||
log.Debugf("setup relay connection")
|
||||
conn.wgProxy = conn.wgProxyFactory.GetProxy()
|
||||
endpoint, err = conn.wgProxy.AddTurnConn(remoteConn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
// To support old version's with direct mode we attempt to punch an additional role with the remote WireGuard port
|
||||
go conn.punchRemoteWGPort(pair, remoteWgPort)
|
||||
endpoint = remoteConn.RemoteAddr()
|
||||
}
|
||||
|
||||
endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String())
|
||||
conn.remoteEndpoint = endpointUdpAddr
|
||||
log.Debugf("Conn resolved IP for %s: %s", endpoint, endpointUdpAddr.IP)
|
||||
@@ -433,7 +445,7 @@ func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, rem
|
||||
}
|
||||
}
|
||||
|
||||
err := conn.config.WgConfig.WgInterface.UpdatePeer(conn.config.WgConfig.RemoteKey, conn.config.WgConfig.AllowedIps, defaultWgKeepAlive, endpointUdpAddr, conn.config.WgConfig.PreSharedKey)
|
||||
err = conn.config.WgConfig.WgInterface.UpdatePeer(conn.config.WgConfig.RemoteKey, conn.config.WgConfig.AllowedIps, defaultWgKeepAlive, endpointUdpAddr, conn.config.WgConfig.PreSharedKey)
|
||||
if err != nil {
|
||||
if conn.wgProxy != nil {
|
||||
_ = conn.wgProxy.CloseConn()
|
||||
@@ -442,33 +454,31 @@ func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, rem
|
||||
}
|
||||
|
||||
conn.status = StatusConnected
|
||||
/*
|
||||
rosenpassEnabled := false
|
||||
if remoteRosenpassPubKey != nil {
|
||||
rosenpassEnabled = true
|
||||
}
|
||||
rosenpassEnabled := false
|
||||
if remoteRosenpassPubKey != nil {
|
||||
rosenpassEnabled = true
|
||||
}
|
||||
|
||||
peerState := State{
|
||||
PubKey: conn.config.Key,
|
||||
ConnStatus: conn.status,
|
||||
ConnStatusUpdate: time.Now(),
|
||||
LocalIceCandidateType: pair.Local.Type().String(),
|
||||
RemoteIceCandidateType: pair.Remote.Type().String(),
|
||||
LocalIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Local.Address(), pair.Local.Port()),
|
||||
RemoteIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Remote.Address(), pair.Local.Port()),
|
||||
Direct: !isRelayCandidate(pair.Local),
|
||||
RosenpassEnabled: rosenpassEnabled,
|
||||
Mux: new(sync.RWMutex),
|
||||
}
|
||||
if pair.Local.Type() == ice.CandidateTypeRelay || pair.Remote.Type() == ice.CandidateTypeRelay {
|
||||
peerState.Relayed = true
|
||||
}
|
||||
peerState := State{
|
||||
PubKey: conn.config.Key,
|
||||
ConnStatus: conn.status,
|
||||
ConnStatusUpdate: time.Now(),
|
||||
LocalIceCandidateType: pair.Local.Type().String(),
|
||||
RemoteIceCandidateType: pair.Remote.Type().String(),
|
||||
LocalIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Local.Address(), pair.Local.Port()),
|
||||
RemoteIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Remote.Address(), pair.Remote.Port()),
|
||||
Direct: !isRelayCandidate(pair.Local),
|
||||
RosenpassEnabled: rosenpassEnabled,
|
||||
Mux: new(sync.RWMutex),
|
||||
}
|
||||
if pair.Local.Type() == ice.CandidateTypeRelay || pair.Remote.Type() == ice.CandidateTypeRelay {
|
||||
peerState.Relayed = true
|
||||
}
|
||||
|
||||
err = conn.statusRecorder.UpdatePeerState(peerState)
|
||||
if err != nil {
|
||||
log.Warnf("unable to save peer's state, got error: %v", err)
|
||||
}
|
||||
*/
|
||||
err = conn.statusRecorder.UpdatePeerState(peerState)
|
||||
if err != nil {
|
||||
log.Warnf("unable to save peer's state, got error: %v", err)
|
||||
}
|
||||
|
||||
_, ipNet, err := net.ParseCIDR(conn.config.WgConfig.AllowedIps)
|
||||
if err != nil {
|
||||
@@ -670,8 +680,6 @@ func (conn *Conn) sendAnswer() error {
|
||||
Version: version.NetbirdVersion(),
|
||||
RosenpassPubKey: conn.config.RosenpassPubKey,
|
||||
RosenpassAddr: conn.config.RosenpassAddr,
|
||||
RelayedAddr: conn.turnRelay.RelayedAddress(),
|
||||
RemoteAddr: conn.turnRelay.SrvRefAddr(),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -695,8 +703,6 @@ func (conn *Conn) sendOffer() error {
|
||||
Version: version.NetbirdVersion(),
|
||||
RosenpassPubKey: conn.config.RosenpassPubKey,
|
||||
RosenpassAddr: conn.config.RosenpassAddr,
|
||||
RelayedAddr: conn.turnRelay.RelayedAddress(),
|
||||
RemoteAddr: conn.turnRelay.SrvRefAddr(),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -736,10 +742,6 @@ func (conn *Conn) Status() ConnStatus {
|
||||
return conn.status
|
||||
}
|
||||
|
||||
func (conn *Conn) OnRemoteRelayRequest(relayedAddr string, remoteIP string) {
|
||||
|
||||
}
|
||||
|
||||
// OnRemoteOffer handles an offer from the remote peer and returns true if the message was accepted, false otherwise
|
||||
// doesn't block, discards the message if connection wasn't ready
|
||||
func (conn *Conn) OnRemoteOffer(offer OfferAnswer) bool {
|
||||
|
||||
@@ -1,143 +0,0 @@
|
||||
package relay
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
bufferSize = 800
|
||||
testFile = "/tmp/1MB"
|
||||
)
|
||||
|
||||
type Speed struct {
|
||||
}
|
||||
|
||||
func NewSpeed() *Speed {
|
||||
return &Speed{}
|
||||
}
|
||||
|
||||
func (s *Speed) ReceiveFileFromAddr(remoteAddr net.Addr) error {
|
||||
pc, err := net.ListenPacket("udp4", "0.0.0.0:0")
|
||||
if err != nil {
|
||||
log.Errorf("failed to lisen: %s", err.Error())
|
||||
return err
|
||||
}
|
||||
defer pc.Close()
|
||||
|
||||
log.Debugf("--- sending initial message to: %s", remoteAddr.String())
|
||||
_, err = pc.WriteTo([]byte("hey, I am the receiver"), remoteAddr)
|
||||
if err != nil {
|
||||
log.Errorf("failed to send initial msg: %s", err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
return s.receiveFile(pc)
|
||||
}
|
||||
|
||||
func (s *Speed) ReceiveFileFromPC(pc net.PacketConn) error {
|
||||
return s.receiveFile(pc)
|
||||
}
|
||||
|
||||
func (s *Speed) receiveFile(pc net.PacketConn) error {
|
||||
log.Debugf("--- start to receive file...")
|
||||
file, err := os.OpenFile(fmt.Sprintf("%s.cp", testFile), os.O_CREATE|os.O_WRONLY, 0644)
|
||||
if err != nil {
|
||||
log.Errorf("failed to open file: %s", err.Error())
|
||||
return err
|
||||
}
|
||||
_ = file.Truncate(0)
|
||||
defer file.Close()
|
||||
|
||||
buffer := make([]byte, bufferSize)
|
||||
for {
|
||||
n, addr, err := pc.ReadFrom(buffer)
|
||||
if err != nil {
|
||||
log.Errorf("failed to read from connection: %s", err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
n, err = file.Write(buffer[:n])
|
||||
if err != nil {
|
||||
log.Errorf("failed to write to file: %s", err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = pc.WriteTo([]byte("ack"), addr)
|
||||
if err != nil {
|
||||
log.Errorf("failed to send ack: %s", err.Error())
|
||||
}
|
||||
|
||||
log.Debugf("received %d bytes from %s", n, addr)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Speed) SendFileToPC(relayConn net.PacketConn) error {
|
||||
buf := make([]byte, bufferSize)
|
||||
log.Debugf("--- wait for initial message")
|
||||
n, rAddr, err := relayConn.ReadFrom(buf)
|
||||
if err != nil {
|
||||
log.Errorf("failed to read from connection: %s", err.Error())
|
||||
return err
|
||||
}
|
||||
log.Errorf("received initial msg %d bytes (%s), addr %s", n, string(buf[:n]), rAddr.String())
|
||||
return s.sendFile(relayConn, rAddr)
|
||||
}
|
||||
|
||||
func (s *Speed) SendFileToAddr(addr net.Addr) error {
|
||||
pc, err := net.ListenPacket("udp4", "0.0.0.0:0")
|
||||
if err != nil {
|
||||
log.Errorf("failed to lisen: %s", err.Error())
|
||||
return err
|
||||
}
|
||||
defer pc.Close()
|
||||
|
||||
return s.sendFile(pc, addr)
|
||||
}
|
||||
|
||||
func (s *Speed) sendFile(conn net.PacketConn, rAddr net.Addr) error {
|
||||
log.Debugf("--- start to send file...")
|
||||
file, err := os.Open(testFile)
|
||||
if err != nil {
|
||||
// Handle error
|
||||
return nil
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
buf := make([]byte, bufferSize)
|
||||
start := time.Now()
|
||||
sent := 0
|
||||
|
||||
for {
|
||||
n, err := file.Read(buf)
|
||||
if err != nil && err != io.EOF {
|
||||
return err
|
||||
}
|
||||
if n == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
n, err = conn.WriteTo(buf[:n], rAddr)
|
||||
if err != nil {
|
||||
log.Errorf("failed to write to connection: %s", err.Error())
|
||||
return err
|
||||
}
|
||||
sent += n
|
||||
log.Debugf("sent %d bytes, (%d) to %s", n, sent, rAddr.String())
|
||||
|
||||
// wait for ack
|
||||
_, _, err = conn.ReadFrom(make([]byte, bufferSize))
|
||||
if err != nil {
|
||||
log.Errorf("failed to read from connection: %s", err.Error())
|
||||
return err
|
||||
}
|
||||
}
|
||||
elapsed := time.Since(start)
|
||||
log.Infof("sent %d bytes, troughtput: %f MB/s", sent, float64(sent)/1024/1024/elapsed.Seconds())
|
||||
return nil
|
||||
}
|
||||
@@ -1,198 +0,0 @@
|
||||
package relay
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/pion/logging"
|
||||
"github.com/pion/stun/v2"
|
||||
"github.com/pion/turn/v3"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type PermanentTurn struct {
|
||||
stunURI *stun.URI
|
||||
turnURI *stun.URI
|
||||
|
||||
stunConn net.PacketConn
|
||||
turnClient *turn.Client
|
||||
turnClientListenLock sync.Mutex
|
||||
relayConn net.PacketConn // represents the remote socket.
|
||||
srvReflexiveAddress *net.UDPAddr
|
||||
}
|
||||
|
||||
func NewPermanentTurn(stunURL, turnURL *stun.URI) *PermanentTurn {
|
||||
return &PermanentTurn{
|
||||
stunURI: stunURL,
|
||||
turnURI: turnURL,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *PermanentTurn) Open() error {
|
||||
stunConn, err := net.ListenPacket("udp4", "0.0.0.0:0")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
r.stunConn = stunConn
|
||||
|
||||
cfg := &turn.ClientConfig{
|
||||
STUNServerAddr: toURL(r.stunURI),
|
||||
TURNServerAddr: toURL(r.turnURI),
|
||||
Conn: stunConn,
|
||||
Username: r.turnURI.Username,
|
||||
Password: r.turnURI.Password,
|
||||
LoggerFactory: logging.NewDefaultLoggerFactory(),
|
||||
}
|
||||
|
||||
client, err := turn.NewClient(cfg)
|
||||
if err != nil {
|
||||
log.Errorf("failed to create turn client: %v", err)
|
||||
return err
|
||||
}
|
||||
r.turnClient = client
|
||||
err = r.turnClient.Listen()
|
||||
if err != nil {
|
||||
log.Errorf("failed to listen turn client: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
relayConn, err := client.Allocate()
|
||||
if err != nil {
|
||||
log.Errorf("failed to allocate relay connection: %v", err)
|
||||
return err
|
||||
}
|
||||
r.relayConn = relayConn
|
||||
|
||||
srvReflexiveAddress, err := r.discoverPublicIPByStun()
|
||||
if err != nil {
|
||||
log.Errorf("failed to discover public IP: %v", err)
|
||||
return err
|
||||
}
|
||||
r.srvReflexiveAddress = srvReflexiveAddress
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *PermanentTurn) RelayedAddress() net.Addr {
|
||||
return r.relayConn.LocalAddr()
|
||||
}
|
||||
|
||||
func (r *PermanentTurn) SrvRefAddr() net.Addr {
|
||||
return r.srvReflexiveAddress
|
||||
}
|
||||
|
||||
func (r *PermanentTurn) PunchHole(mappedAddr net.Addr) error {
|
||||
/*
|
||||
err := r.turnClient.CreatePermission(mappedAddr)
|
||||
if err != nil {
|
||||
log.Errorf("---- failed to create permission: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
msg, err := stun.Build(stun.BindingRequest, stun.TransactionID,
|
||||
stun.Fingerprint,
|
||||
)
|
||||
if err != nil {
|
||||
log.Errorf("--- failed to build stun message: %v", err)
|
||||
return nil
|
||||
}
|
||||
_, err = r.relayConn.WriteTo(msg.Raw, mappedAddr)
|
||||
if err != nil {
|
||||
log.Errorf("failed to write to relay conn: %v", err)
|
||||
return err
|
||||
}
|
||||
*/
|
||||
_, err := r.relayConn.WriteTo([]byte("Hello"), mappedAddr)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *PermanentTurn) RelayConn() net.PacketConn {
|
||||
return r.relayConn
|
||||
}
|
||||
|
||||
func (r *PermanentTurn) Close() {
|
||||
r.turnClient.Close()
|
||||
|
||||
err := r.relayConn.Close()
|
||||
if err != nil {
|
||||
log.Errorf("failed to close relayConn: %s", err.Error())
|
||||
}
|
||||
|
||||
err = r.stunConn.Close()
|
||||
if err != nil {
|
||||
log.Errorf("failed to close stunConn: %s", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func (r *PermanentTurn) discoverPublicIP() (*net.UDPAddr, error) {
|
||||
addr, err := r.turnClient.SendBindingRequest()
|
||||
if err != nil {
|
||||
log.Errorf("failed to send binding request: %v", err)
|
||||
return nil, err
|
||||
|
||||
}
|
||||
|
||||
udpAddr, ok := addr.(*net.UDPAddr)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to cast addr to udp addr")
|
||||
}
|
||||
|
||||
return udpAddr, nil
|
||||
}
|
||||
|
||||
func (r *PermanentTurn) discoverPublicIPByStun() (*net.UDPAddr, error) {
|
||||
c, err := stun.DialURI(r.stunURI, &stun.DialConfig{})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
message := stun.MustBuild(stun.TransactionID, stun.BindingRequest)
|
||||
var addr *net.UDPAddr
|
||||
err = c.Do(message, func(res stun.Event) {
|
||||
if res.Error != nil {
|
||||
panic(res.Error)
|
||||
}
|
||||
var xorAddr stun.XORMappedAddress
|
||||
if err := xorAddr.GetFrom(res.Message); err != nil {
|
||||
log.Errorf("failed to get xor address: %v", err)
|
||||
return
|
||||
}
|
||||
addr = &net.UDPAddr{
|
||||
IP: xorAddr.IP,
|
||||
Port: xorAddr.Port,
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return addr, nil
|
||||
}
|
||||
|
||||
func (r *PermanentTurn) listen() {
|
||||
if !r.turnClientListenLock.TryLock() {
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer r.turnClientListenLock.Unlock()
|
||||
|
||||
buf := make([]byte, math.MaxUint16)
|
||||
for {
|
||||
n, from, err := r.stunConn.ReadFrom(buf)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to read from stun conn. Exiting loop %v", err)
|
||||
break
|
||||
}
|
||||
|
||||
_, err = r.turnClient.HandleInbound(buf[:n], from)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to handle inbound turn message: %s. Exiting loop", err)
|
||||
break
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func toURL(uri *stun.URI) string {
|
||||
return fmt.Sprintf("%s:%d", uri.Host, uri.Port)
|
||||
}
|
||||
@@ -1,137 +0,0 @@
|
||||
package relay
|
||||
|
||||
import (
|
||||
"os"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/pion/stun/v2"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
const (
|
||||
userName = "1714092678"
|
||||
password = "8PEprGKo+UARpYpQOulNz3H24dI="
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
_ = util.InitLog("trace", "console")
|
||||
code := m.Run()
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
func TestMyTurnUpload(t *testing.T) {
|
||||
turnURI, err := stun.ParseURI("turn:api.stage.netbird.io:3478?transport=udp")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse stun url: %v", err)
|
||||
}
|
||||
turnURI.Username = userName
|
||||
turnURI.Password = password
|
||||
|
||||
stunURI, err := stun.ParseURI("stun:api.stage.netbird.io:3478")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse stun url: %v", err)
|
||||
}
|
||||
turnRelayA := NewPermanentTurn(stunURI, turnURI)
|
||||
err = turnRelayA.Open()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to open turn relay: %v", err)
|
||||
}
|
||||
defer turnRelayA.Close()
|
||||
|
||||
turnRelayB := NewPermanentTurn(stunURI, turnURI)
|
||||
peerBAddr, err := turnRelayB.discoverPublicIPByStun()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to discover public ip: %v", err)
|
||||
}
|
||||
|
||||
err = turnRelayA.PunchHole(peerBAddr)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to punch hole: %v", err)
|
||||
}
|
||||
|
||||
// at this point, the relayed side should be established
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(2)
|
||||
|
||||
speedB := NewSpeed()
|
||||
go func() {
|
||||
err := speedB.ReceiveFileFromAddr(turnRelayA.relayConn.LocalAddr())
|
||||
if err != nil {
|
||||
log.Errorf("failed to receive file: %v", err)
|
||||
}
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
speedA := NewSpeed()
|
||||
go func() {
|
||||
err := speedA.SendFileToPC(turnRelayA.relayConn)
|
||||
if err != nil {
|
||||
log.Errorf("failed to send file: %v", err)
|
||||
}
|
||||
log.Debugf("file sent")
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestMyTurnDownload(t *testing.T) {
|
||||
turnURI, err := stun.ParseURI("turn:api.stage.netbird.io:3478?transport=udp")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse stun url: %v", err)
|
||||
}
|
||||
turnURI.Username = "1714016034"
|
||||
turnURI.Password = "oDpL6tDu0d+xcO3rQnHoEvbcS/Q="
|
||||
|
||||
stunURI, err := stun.ParseURI("stun:api.stage.netbird.io:3478")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse stun url: %v", err)
|
||||
}
|
||||
turnRelayA := NewPermanentTurn(stunURI, turnURI)
|
||||
err = turnRelayA.Open()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to open turn relay: %v", err)
|
||||
}
|
||||
defer turnRelayA.Close()
|
||||
|
||||
turnRelayB := NewPermanentTurn(stunURI, turnURI)
|
||||
peerBAddr, err := turnRelayB.discoverPublicIPByStun()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to discover public ip: %v", err)
|
||||
}
|
||||
|
||||
err = turnRelayA.PunchHole(peerBAddr)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to punch hole: %v", err)
|
||||
}
|
||||
|
||||
// at this point, the relayed side should be established
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(2)
|
||||
|
||||
speedB := NewSpeed()
|
||||
go func() {
|
||||
err := speedB.SendFileToAddr(turnRelayA.relayConn.LocalAddr())
|
||||
if err != nil {
|
||||
log.Errorf("failed to receive file: %v", err)
|
||||
}
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
speedA := NewSpeed()
|
||||
go func() {
|
||||
err := speedA.ReceiveFileFromPC(turnRelayA.relayConn)
|
||||
if err != nil {
|
||||
log.Errorf("failed to send file: %v", err)
|
||||
}
|
||||
log.Debugf("file sent")
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
@@ -68,6 +69,10 @@ func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface,
|
||||
|
||||
// Init sets up the routing
|
||||
func (m *DefaultManager) Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
|
||||
if nbnet.CustomRoutingDisabled() {
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
if err := cleanupRouting(); err != nil {
|
||||
log.Warnf("Failed cleaning up routing: %v", err)
|
||||
}
|
||||
@@ -99,11 +104,15 @@ func (m *DefaultManager) Stop() {
|
||||
if m.serverRouter != nil {
|
||||
m.serverRouter.cleanUp()
|
||||
}
|
||||
if err := cleanupRouting(); err != nil {
|
||||
log.Errorf("Error cleaning up routing: %v", err)
|
||||
} else {
|
||||
log.Info("Routing cleanup complete")
|
||||
|
||||
if !nbnet.CustomRoutingDisabled() {
|
||||
if err := cleanupRouting(); err != nil {
|
||||
log.Errorf("Error cleaning up routing: %v", err)
|
||||
} else {
|
||||
log.Info("Routing cleanup complete")
|
||||
}
|
||||
}
|
||||
|
||||
m.ctx = nil
|
||||
}
|
||||
|
||||
@@ -210,9 +219,11 @@ func (m *DefaultManager) clientRoutes(initialRoutes []*route.Route) []*route.Rou
|
||||
}
|
||||
|
||||
func isPrefixSupported(prefix netip.Prefix) bool {
|
||||
switch runtime.GOOS {
|
||||
case "linux", "windows", "darwin":
|
||||
return true
|
||||
if !nbnet.CustomRoutingDisabled() {
|
||||
switch runtime.GOOS {
|
||||
case "linux", "windows", "darwin":
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// If prefix is too small, lets assume it is a possible default prefix which is not yet supported
|
||||
|
||||
@@ -4,14 +4,14 @@ package routemanager
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -32,19 +32,31 @@ const (
|
||||
rtTablesPath = "/etc/iproute2/rt_tables"
|
||||
|
||||
// ipv4ForwardingPath is the path to the file containing the IP forwarding setting.
|
||||
ipv4ForwardingPath = "/proc/sys/net/ipv4/ip_forward"
|
||||
ipv4ForwardingPath = "net.ipv4.ip_forward"
|
||||
|
||||
rpFilterPath = "net.ipv4.conf.all.rp_filter"
|
||||
rpFilterInterfacePath = "net.ipv4.conf.%s.rp_filter"
|
||||
srcValidMarkPath = "net.ipv4.conf.all.src_valid_mark"
|
||||
)
|
||||
|
||||
var ErrTableIDExists = errors.New("ID exists with different name")
|
||||
|
||||
var routeManager = &RouteManager{}
|
||||
var isLegacy = os.Getenv("NB_USE_LEGACY_ROUTING") == "true"
|
||||
|
||||
// originalSysctl stores the original sysctl values before they are modified
|
||||
var originalSysctl map[string]int
|
||||
|
||||
// determines whether to use the legacy routing setup
|
||||
var isLegacy = os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled()
|
||||
|
||||
// sysctlFailed is used as an indicator to emit a warning when default routes are configured
|
||||
var sysctlFailed bool
|
||||
|
||||
type ruleParams struct {
|
||||
priority int
|
||||
fwmark int
|
||||
tableID int
|
||||
family int
|
||||
priority int
|
||||
invert bool
|
||||
suppressPrefix int
|
||||
description string
|
||||
@@ -52,10 +64,10 @@ type ruleParams struct {
|
||||
|
||||
func getSetupRules() []ruleParams {
|
||||
return []ruleParams{
|
||||
{nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V4, -1, true, -1, "rule v4 netbird"},
|
||||
{nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V6, -1, true, -1, "rule v6 netbird"},
|
||||
{-1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V4, -1, false, 0, "rule with suppress prefixlen v4"},
|
||||
{-1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V6, -1, false, 0, "rule with suppress prefixlen v6"},
|
||||
{100, -1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V4, false, 0, "rule with suppress prefixlen v4"},
|
||||
{100, -1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V6, false, 0, "rule with suppress prefixlen v6"},
|
||||
{110, nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V4, true, -1, "rule v4 netbird"},
|
||||
{110, nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V6, true, -1, "rule v6 netbird"},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -69,8 +81,6 @@ func getSetupRules() []ruleParams {
|
||||
// Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table.
|
||||
// This table is where a default route or other specific routes received from the management server are configured,
|
||||
// enabling VPN connectivity.
|
||||
//
|
||||
// The rules are inserted in reverse order, as rules are added from the bottom up in the rule list.
|
||||
func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.BeforeAddPeerHookFunc, _ peer.AfterRemovePeerHookFunc, err error) {
|
||||
if isLegacy {
|
||||
log.Infof("Using legacy routing setup")
|
||||
@@ -81,6 +91,13 @@ func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.Before
|
||||
log.Errorf("Error adding routing table name: %v", err)
|
||||
}
|
||||
|
||||
originalValues, err := setupSysctl(wgIface)
|
||||
if err != nil {
|
||||
log.Errorf("Error setting up sysctl: %v", err)
|
||||
sysctlFailed = true
|
||||
}
|
||||
originalSysctl = originalValues
|
||||
|
||||
defer func() {
|
||||
if err != nil {
|
||||
if cleanErr := cleanupRouting(); cleanErr != nil {
|
||||
@@ -123,11 +140,17 @@ func cleanupRouting() error {
|
||||
|
||||
rules := getSetupRules()
|
||||
for _, rule := range rules {
|
||||
if err := removeAllRules(rule); err != nil && !errors.Is(err, syscall.EOPNOTSUPP) {
|
||||
if err := removeRule(rule); err != nil {
|
||||
result = multierror.Append(result, fmt.Errorf("%s: %w", rule.description, err))
|
||||
}
|
||||
}
|
||||
|
||||
if err := cleanupSysctl(originalSysctl); err != nil {
|
||||
result = multierror.Append(result, fmt.Errorf("cleanup sysctl: %w", err))
|
||||
}
|
||||
originalSysctl = nil
|
||||
sysctlFailed = false
|
||||
|
||||
return result.ErrorOrNil()
|
||||
}
|
||||
|
||||
@@ -144,6 +167,10 @@ func addVPNRoute(prefix netip.Prefix, intf string) error {
|
||||
return genericAddVPNRoute(prefix, intf)
|
||||
}
|
||||
|
||||
if sysctlFailed && (prefix == defaultv4 || prefix == defaultv6) {
|
||||
log.Warnf("Default route is configured but sysctl operations failed, VPN traffic may not be routed correctly, consider using NB_USE_LEGACY_ROUTING=true or setting net.ipv4.conf.*.rp_filter to 2 (loose) or 0 (off)")
|
||||
}
|
||||
|
||||
// No need to check if routes exist as main table takes precedence over the VPN table via Rule 1
|
||||
|
||||
// TODO remove this once we have ipv6 support
|
||||
@@ -336,22 +363,8 @@ func flushRoutes(tableID, family int) error {
|
||||
}
|
||||
|
||||
func enableIPForwarding() error {
|
||||
bytes, err := os.ReadFile(ipv4ForwardingPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read file %s: %w", ipv4ForwardingPath, err)
|
||||
}
|
||||
|
||||
// check if it is already enabled
|
||||
// see more: https://github.com/netbirdio/netbird/issues/872
|
||||
if len(bytes) > 0 && bytes[0] == 49 {
|
||||
return nil
|
||||
}
|
||||
|
||||
//nolint:gosec
|
||||
if err := os.WriteFile(ipv4ForwardingPath, []byte("1"), 0644); err != nil {
|
||||
return fmt.Errorf("write file %s: %w", ipv4ForwardingPath, err)
|
||||
}
|
||||
return nil
|
||||
_, err := setSysctl(ipv4ForwardingPath, 1, false)
|
||||
return err
|
||||
}
|
||||
|
||||
// entryExists checks if the specified ID or name already exists in the rt_tables file
|
||||
@@ -429,7 +442,7 @@ func addRule(params ruleParams) error {
|
||||
rule.Invert = params.invert
|
||||
rule.SuppressPrefixlen = params.suppressPrefix
|
||||
|
||||
if err := netlink.RuleAdd(rule); err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
||||
if err := netlink.RuleAdd(rule); err != nil && !errors.Is(err, syscall.EEXIST) && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
||||
return fmt.Errorf("add routing rule: %w", err)
|
||||
}
|
||||
|
||||
@@ -446,43 +459,13 @@ func removeRule(params ruleParams) error {
|
||||
rule.Priority = params.priority
|
||||
rule.SuppressPrefixlen = params.suppressPrefix
|
||||
|
||||
if err := netlink.RuleDel(rule); err != nil {
|
||||
if err := netlink.RuleDel(rule); err != nil && !errors.Is(err, syscall.ENOENT) && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
||||
return fmt.Errorf("remove routing rule: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func removeAllRules(params ruleParams) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
for {
|
||||
if ctx.Err() != nil {
|
||||
done <- ctx.Err()
|
||||
return
|
||||
}
|
||||
if err := removeRule(params); err != nil {
|
||||
if errors.Is(err, syscall.ENOENT) || errors.Is(err, syscall.EAFNOSUPPORT) {
|
||||
done <- nil
|
||||
return
|
||||
}
|
||||
done <- err
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case err := <-done:
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// addNextHop adds the gateway and device to the route.
|
||||
func addNextHop(addr netip.Addr, intf string, route *netlink.Route) error {
|
||||
if addr.IsValid() {
|
||||
@@ -509,3 +492,83 @@ func getAddressFamily(prefix netip.Prefix) int {
|
||||
}
|
||||
return netlink.FAMILY_V6
|
||||
}
|
||||
|
||||
// setupSysctl configures sysctl settings for RP filtering and source validation.
|
||||
func setupSysctl(wgIface *iface.WGIface) (map[string]int, error) {
|
||||
keys := map[string]int{}
|
||||
var result *multierror.Error
|
||||
|
||||
oldVal, err := setSysctl(srcValidMarkPath, 1, false)
|
||||
if err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
} else {
|
||||
keys[srcValidMarkPath] = oldVal
|
||||
}
|
||||
|
||||
oldVal, err = setSysctl(rpFilterPath, 2, true)
|
||||
if err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
} else {
|
||||
keys[rpFilterPath] = oldVal
|
||||
}
|
||||
|
||||
interfaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
result = multierror.Append(result, fmt.Errorf("list interfaces: %w", err))
|
||||
}
|
||||
|
||||
for _, intf := range interfaces {
|
||||
if intf.Name == "lo" || wgIface != nil && intf.Name == wgIface.Name() {
|
||||
continue
|
||||
}
|
||||
|
||||
i := fmt.Sprintf(rpFilterInterfacePath, intf.Name)
|
||||
oldVal, err := setSysctl(i, 2, true)
|
||||
if err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
} else {
|
||||
keys[i] = oldVal
|
||||
}
|
||||
}
|
||||
|
||||
return keys, result.ErrorOrNil()
|
||||
}
|
||||
|
||||
// setSysctl sets a sysctl configuration, if onlyIfOne is true it will only set the new value if it's set to 1
|
||||
func setSysctl(key string, desiredValue int, onlyIfOne bool) (int, error) {
|
||||
path := fmt.Sprintf("/proc/sys/%s", strings.ReplaceAll(key, ".", "/"))
|
||||
currentValue, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return -1, fmt.Errorf("read sysctl %s: %w", key, err)
|
||||
}
|
||||
|
||||
currentV, err := strconv.Atoi(strings.TrimSpace(string(currentValue)))
|
||||
if err != nil && len(currentValue) > 0 {
|
||||
return -1, fmt.Errorf("convert current desiredValue to int: %w", err)
|
||||
}
|
||||
|
||||
if currentV == desiredValue || onlyIfOne && currentV != 1 {
|
||||
return currentV, nil
|
||||
}
|
||||
|
||||
//nolint:gosec
|
||||
if err := os.WriteFile(path, []byte(strconv.Itoa(desiredValue)), 0644); err != nil {
|
||||
return currentV, fmt.Errorf("write sysctl %s: %w", key, err)
|
||||
}
|
||||
log.Debugf("Set sysctl %s from %d to %d", key, currentV, desiredValue)
|
||||
|
||||
return currentV, nil
|
||||
}
|
||||
|
||||
func cleanupSysctl(originalSettings map[string]int) error {
|
||||
var result *multierror.Error
|
||||
|
||||
for key, value := range originalSettings {
|
||||
_, err := setSysctl(key, value, false)
|
||||
if err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
}
|
||||
|
||||
return result.ErrorOrNil()
|
||||
}
|
||||
|
||||
@@ -61,7 +61,7 @@ func TestAddRemoveRoutes(t *testing.T) {
|
||||
|
||||
err = wgInterface.Create()
|
||||
require.NoError(t, err, "should create testing wireguard interface")
|
||||
_, _, err = setupRouting(nil, nil)
|
||||
_, _, err = setupRouting(nil, wgInterface)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, cleanupRouting())
|
||||
|
||||
@@ -73,7 +73,7 @@ func addRoutePowershell(prefix netip.Prefix, nexthop netip.Addr, intf, intfIdx s
|
||||
}
|
||||
|
||||
script := fmt.Sprintf(
|
||||
`%s -AddressFamily "%s" -DestinationPrefix "%s" -Confirm:$False -ErrorAction Stop`,
|
||||
`%s -AddressFamily "%s" -DestinationPrefix "%s" -Confirm:$False -ErrorAction Stop -PolicyStore ActiveStore`,
|
||||
psCmd, addressFamily, destinationPrefix,
|
||||
)
|
||||
|
||||
|
||||
@@ -230,7 +230,7 @@ func (p *WGEBPFProxy) prepareSenderRawSocket() (net.PacketConn, error) {
|
||||
}
|
||||
|
||||
// Set the fwmark on the socket.
|
||||
err = syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_MARK, nbnet.NetbirdFwmark)
|
||||
err = nbnet.SetSocketOpt(fd)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("setting fwmark failed: %w", err)
|
||||
}
|
||||
|
||||
4
go.mod
4
go.mod
@@ -60,7 +60,7 @@ require (
|
||||
github.com/miekg/dns v1.1.43
|
||||
github.com/mitchellh/hashstructure/v2 v2.0.2
|
||||
github.com/nadoo/ipset v0.5.0
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20240326083846-3682438fca98
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20240415094251-369eb33c9b01
|
||||
github.com/okta/okta-sdk-golang/v2 v2.18.0
|
||||
github.com/oschwald/maxminddb-golang v1.12.0
|
||||
github.com/patrickmn/go-cache v2.1.0+incompatible
|
||||
@@ -172,7 +172,7 @@ replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-2023
|
||||
|
||||
replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949
|
||||
|
||||
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20240422165616-c6832bb477d5
|
||||
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed
|
||||
|
||||
replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6
|
||||
|
||||
|
||||
8
go.sum
8
go.sum
@@ -383,14 +383,14 @@ github.com/nadoo/ipset v0.5.0 h1:5GJUAuZ7ITQQQGne5J96AmFjRtI8Avlbk6CabzYWVUc=
|
||||
github.com/nadoo/ipset v0.5.0/go.mod h1:rYF5DQLRGGoQ8ZSWeK+6eX5amAuPqwFkWjhQlEITGJQ=
|
||||
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c=
|
||||
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20240326083846-3682438fca98 h1:i6AtenTLu/CqhTmj0g1K/GWkkpMJMhQM6Vjs46x25nA=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20240326083846-3682438fca98/go.mod h1:kxks50DrZnhW+oRTdHOkVOJbcTcyo766am8RBugo+Yc=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20240415094251-369eb33c9b01 h1:Fu9fq0ndfKVuFTEwbc8Etqui10BOkcMTv0UqcMy0RuY=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20240415094251-369eb33c9b01/go.mod h1:kxks50DrZnhW+oRTdHOkVOJbcTcyo766am8RBugo+Yc=
|
||||
github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0 h1:hirFRfx3grVA/9eEyjME5/z3nxdJlN9kfQpvWWPk32g=
|
||||
github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
||||
github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949 h1:xbWM9BU6mwZZLHxEjxIX/V8Hv3HurQt4mReIE4mY4DM=
|
||||
github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949/go.mod h1:AecygODWIsBquJCJFop8MEQcJbWFfw/1yWbVabNgpCM=
|
||||
github.com/netbirdio/wireguard-go v0.0.0-20240422165616-c6832bb477d5 h1:m48qfB2ILlFx3oZlw7aEeD+V6vXnMb0hNwmDCtdcgv0=
|
||||
github.com/netbirdio/wireguard-go v0.0.0-20240422165616-c6832bb477d5/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
|
||||
github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed h1:t0UADZUJDaaZgfKrt8JUPrOLL9Mg/ryjP85RAH53qgs=
|
||||
github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
|
||||
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8=
|
||||
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno=
|
||||
github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A=
|
||||
|
||||
@@ -13,6 +13,14 @@ import (
|
||||
wgConn "golang.zx2c4.com/wireguard/conn"
|
||||
)
|
||||
|
||||
type receiverCreator struct {
|
||||
iceBind *ICEBind
|
||||
}
|
||||
|
||||
func (rc receiverCreator) CreateIPv4ReceiverFn(msgPool *sync.Pool, pc *ipv4.PacketConn, conn *net.UDPConn) wgConn.ReceiveFunc {
|
||||
return rc.iceBind.createIPv4ReceiverFn(msgPool, pc, conn)
|
||||
}
|
||||
|
||||
type ICEBind struct {
|
||||
*wgConn.StdNetBind
|
||||
|
||||
@@ -20,8 +28,6 @@ type ICEBind struct {
|
||||
|
||||
transportNet transport.Net
|
||||
udpMux *UniversalUDPMuxDefault
|
||||
|
||||
receiverCreator *receiverCreator
|
||||
}
|
||||
|
||||
func NewICEBind(transportNet transport.Net) *ICEBind {
|
||||
@@ -29,9 +35,9 @@ func NewICEBind(transportNet transport.Net) *ICEBind {
|
||||
transportNet: transportNet,
|
||||
}
|
||||
|
||||
rc := newReceiverCreator(ib)
|
||||
ib.receiverCreator = rc
|
||||
|
||||
rc := receiverCreator{
|
||||
ib,
|
||||
}
|
||||
ib.StdNetBind = wgConn.NewStdNetBindWithReceiverCreator(rc)
|
||||
return ib
|
||||
}
|
||||
@@ -47,22 +53,16 @@ func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
|
||||
return s.udpMux, nil
|
||||
}
|
||||
|
||||
func (s *ICEBind) SetTurnConn(conn interface{}) {
|
||||
s.receiverCreator.setTurnConn(conn)
|
||||
}
|
||||
|
||||
func (s *ICEBind) createIPv4ReceiverFn(ipv4MsgsPool *sync.Pool, pc *ipv4.PacketConn, conn *net.UDPConn, netConn net.PacketConn) wgConn.ReceiveFunc {
|
||||
func (s *ICEBind) createIPv4ReceiverFn(ipv4MsgsPool *sync.Pool, pc *ipv4.PacketConn, conn *net.UDPConn) wgConn.ReceiveFunc {
|
||||
s.muUDPMux.Lock()
|
||||
defer s.muUDPMux.Unlock()
|
||||
|
||||
if conn != nil {
|
||||
s.udpMux = NewUniversalUDPMuxDefault(
|
||||
UniversalUDPMuxParams{
|
||||
UDPConn: conn,
|
||||
Net: s.transportNet,
|
||||
},
|
||||
)
|
||||
}
|
||||
s.udpMux = NewUniversalUDPMuxDefault(
|
||||
UniversalUDPMuxParams{
|
||||
UDPConn: conn,
|
||||
Net: s.transportNet,
|
||||
},
|
||||
)
|
||||
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
|
||||
msgs := ipv4MsgsPool.Get().(*[]ipv4.Message)
|
||||
defer ipv4MsgsPool.Put(msgs)
|
||||
@@ -71,40 +71,17 @@ func (s *ICEBind) createIPv4ReceiverFn(ipv4MsgsPool *sync.Pool, pc *ipv4.PacketC
|
||||
}
|
||||
var numMsgs int
|
||||
if runtime.GOOS == "linux" {
|
||||
if netConn != nil {
|
||||
log.Debugf("----read from turn conn...")
|
||||
msg := &(*msgs)[0]
|
||||
msg.N, msg.Addr, err = netConn.ReadFrom(msg.Buffers[0])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
log.Debugf("----msg address is: %s, size: %d", msg.Addr.String(), msg.N)
|
||||
numMsgs = 1
|
||||
} else {
|
||||
log.Debugf("----read from pc...")
|
||||
numMsgs, err = pc.ReadBatch(*msgs, 0)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
numMsgs, err = pc.ReadBatch(*msgs, 0)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
} else {
|
||||
if netConn != nil {
|
||||
log.Debugf("----read from turn conn...")
|
||||
msg := &(*msgs)[0]
|
||||
msg.N, msg.Addr, err = netConn.ReadFrom(msg.Buffers[0])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
log.Debugf("----msg address is: %s, size: %d", msg.Addr.String(), msg.N)
|
||||
numMsgs = 1
|
||||
} else {
|
||||
msg := &(*msgs)[0]
|
||||
msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
numMsgs = 1
|
||||
msg := &(*msgs)[0]
|
||||
msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
numMsgs = 1
|
||||
}
|
||||
for i := 0; i < numMsgs; i++ {
|
||||
msg := &(*msgs)[i]
|
||||
@@ -118,10 +95,7 @@ func (s *ICEBind) createIPv4ReceiverFn(ipv4MsgsPool *sync.Pool, pc *ipv4.PacketC
|
||||
}
|
||||
|
||||
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
|
||||
ep := &wgConn.StdNetEndpoint{
|
||||
AddrPort: addrPort,
|
||||
Conn: netConn,
|
||||
}
|
||||
ep := &wgConn.StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
|
||||
wgConn.GetSrcFromControl(msg.OOB[:msg.NN], ep)
|
||||
eps[i] = ep
|
||||
}
|
||||
|
||||
@@ -1,38 +0,0 @@
|
||||
package bind
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/net/ipv4"
|
||||
wgConn "golang.zx2c4.com/wireguard/conn"
|
||||
)
|
||||
|
||||
type receiverCreator struct {
|
||||
iceBind *ICEBind
|
||||
relayConn net.PacketConn
|
||||
}
|
||||
|
||||
func newReceiverCreator(iceBind *ICEBind) *receiverCreator {
|
||||
return &receiverCreator{
|
||||
iceBind: iceBind,
|
||||
}
|
||||
}
|
||||
|
||||
func (rc *receiverCreator) CreateIPv4ReceiverFn(msgPool *sync.Pool, pc *ipv4.PacketConn, conn *net.UDPConn) wgConn.ReceiveFunc {
|
||||
return rc.iceBind.createIPv4ReceiverFn(msgPool, pc, conn, nil)
|
||||
}
|
||||
|
||||
func (rc *receiverCreator) CreateRelayReceiverFn(msgPool *sync.Pool) wgConn.ReceiveFunc {
|
||||
if rc.relayConn == nil {
|
||||
log.Debugf("-------rc.conn is nil")
|
||||
return nil
|
||||
}
|
||||
return rc.iceBind.createIPv4ReceiverFn(msgPool, nil, nil, rc.relayConn)
|
||||
}
|
||||
|
||||
func (rc *receiverCreator) setTurnConn(relayConn interface{}) {
|
||||
log.Debug("------ SET TURN CONN")
|
||||
rc.relayConn = relayConn.(net.PacketConn)
|
||||
}
|
||||
@@ -150,10 +150,3 @@ func (w *WGIface) GetDevice() *DeviceWrapper {
|
||||
func (w *WGIface) GetStats(peerKey string) (WGStats, error) {
|
||||
return w.configurer.getStats(peerKey)
|
||||
}
|
||||
|
||||
func (w *WGIface) SetRelayConn(conn interface{}) {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
w.tun.SetTurnConn(conn)
|
||||
}
|
||||
|
||||
@@ -85,27 +85,23 @@ func tunModuleIsLoaded() bool {
|
||||
|
||||
// WireGuardModuleIsLoaded check if we can load WireGuard mod (linux only)
|
||||
func WireGuardModuleIsLoaded() bool {
|
||||
return false
|
||||
|
||||
/*
|
||||
if os.Getenv(envDisableWireGuardKernel) == "true" {
|
||||
log.Debugf("WireGuard kernel module disabled because the %s env is set to true", envDisableWireGuardKernel)
|
||||
return false
|
||||
}
|
||||
if os.Getenv(envDisableWireGuardKernel) == "true" {
|
||||
log.Debugf("WireGuard kernel module disabled because the %s env is set to true", envDisableWireGuardKernel)
|
||||
return false
|
||||
}
|
||||
|
||||
if canCreateFakeWireGuardInterface() {
|
||||
return true
|
||||
}
|
||||
if canCreateFakeWireGuardInterface() {
|
||||
return true
|
||||
}
|
||||
|
||||
loaded, err := tryToLoadModule("wireguard")
|
||||
if err != nil {
|
||||
log.Info(err)
|
||||
return false
|
||||
}
|
||||
loaded, err := tryToLoadModule("wireguard")
|
||||
if err != nil {
|
||||
log.Info(err)
|
||||
return false
|
||||
}
|
||||
|
||||
return loaded
|
||||
|
||||
*/
|
||||
return loaded
|
||||
}
|
||||
|
||||
func canCreateFakeWireGuardInterface() bool {
|
||||
|
||||
@@ -15,5 +15,4 @@ type wgTunDevice interface {
|
||||
DeviceName() string
|
||||
Close() error
|
||||
Wrapper() *DeviceWrapper // todo eliminate this function
|
||||
SetTurnConn(conn interface{})
|
||||
}
|
||||
|
||||
@@ -28,14 +28,6 @@ type tunDevice struct {
|
||||
configurer wgConfigurer
|
||||
}
|
||||
|
||||
func (t *tunDevice) SetTurnConn(conn interface{}) {
|
||||
t.iceBind.SetTurnConn(conn)
|
||||
err := t.device.BindUpdate()
|
||||
if err != nil {
|
||||
log.Errorf("failed to update bind: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func newTunDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net) wgTunDevice {
|
||||
return &tunDevice{
|
||||
name: name,
|
||||
|
||||
@@ -31,11 +31,6 @@ type tunKernelDevice struct {
|
||||
udpMux *bind.UniversalUDPMuxDefault
|
||||
}
|
||||
|
||||
func (t *tunKernelDevice) SetTurnConn(interface{}) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func newTunDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net) wgTunDevice {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &tunKernelDevice{
|
||||
|
||||
@@ -30,11 +30,6 @@ type tunNetstackDevice struct {
|
||||
configurer wgConfigurer
|
||||
}
|
||||
|
||||
func (t *tunNetstackDevice) SetTurnConn(interface{}) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func newTunNetstackDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net, listenAddress string) wgTunDevice {
|
||||
return &tunNetstackDevice{
|
||||
name: name,
|
||||
|
||||
@@ -54,7 +54,7 @@ func (t *tunUSPDevice) Create() (wgConfigurer, error) {
|
||||
t.device = device.NewDevice(
|
||||
t.wrapper,
|
||||
t.iceBind,
|
||||
device.NewLogger(device.LogLevelError, "[netbird] "),
|
||||
device.NewLogger(device.LogLevelSilent, "[netbird] "),
|
||||
)
|
||||
|
||||
err = t.assignAddr()
|
||||
@@ -70,7 +70,6 @@ func (t *tunUSPDevice) Create() (wgConfigurer, error) {
|
||||
t.configurer.close()
|
||||
return nil, err
|
||||
}
|
||||
log.Debugf("configuration done")
|
||||
return t.configurer, nil
|
||||
}
|
||||
|
||||
@@ -126,14 +125,6 @@ func (t *tunUSPDevice) Wrapper() *DeviceWrapper {
|
||||
return t.wrapper
|
||||
}
|
||||
|
||||
func (t *tunUSPDevice) SetTurnConn(conn interface{}) {
|
||||
t.iceBind.SetTurnConn(conn)
|
||||
err := t.device.BindUpdate()
|
||||
if err != nil {
|
||||
log.Errorf("failed to update bind: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// assignAddr Adds IP address to the tunnel interface
|
||||
func (t *tunUSPDevice) assignAddr() error {
|
||||
link := newWGLink(t.name)
|
||||
|
||||
@@ -10,8 +10,6 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
type wgKernelConfigurer struct {
|
||||
@@ -31,7 +29,7 @@ func (c *wgKernelConfigurer) configureInterface(privateKey string, port int) err
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fwmark := nbnet.NetbirdFwmark
|
||||
fwmark := getFwmark()
|
||||
config := wgtypes.Config{
|
||||
PrivateKey: &key,
|
||||
ReplacePeers: true,
|
||||
|
||||
@@ -349,7 +349,7 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string {
|
||||
}
|
||||
|
||||
func getFwmark() int {
|
||||
if runtime.GOOS == "linux" {
|
||||
if runtime.GOOS == "linux" && !nbnet.CustomRoutingDisabled() {
|
||||
return nbnet.NetbirdFwmark
|
||||
}
|
||||
return 0
|
||||
|
||||
@@ -251,7 +251,7 @@ var (
|
||||
|
||||
ctx, cancel := context.WithCancel(cmd.Context())
|
||||
defer cancel()
|
||||
httpAPIHandler, err := httpapi.APIHandler(ctx, accountManager, geo, *jwtValidator, appMetrics, httpAPIAuthCfg)
|
||||
httpAPIHandler, err := httpapi.APIHandler(ctx, accountManager, geo, *jwtValidator, appMetrics, httpAPIAuthCfg, integratedPeerValidator)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed creating HTTP API handler: %v", err)
|
||||
}
|
||||
|
||||
@@ -1473,7 +1473,7 @@ func (am *DefaultAccountManager) handleNewUserAccount(domainAcc *Account, claims
|
||||
// if domain already has a primary account, add regular user
|
||||
if domainAcc != nil {
|
||||
account = domainAcc
|
||||
account.Users[claims.UserId] = NewRegularUser(claims.UserId)
|
||||
account.Users[claims.UserId] = NewRegularUser(claims.UserId, account.Id)
|
||||
err = am.Store.SaveAccount(account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -1849,6 +1849,7 @@ func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(claims jwtclaims.Aut
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) onPeersInvalidated(accountID string) {
|
||||
log.Debugf("validated peers has been invalidated for account %s", accountID)
|
||||
updatedAccount, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
log.Errorf("failed to get account %s: %v", accountID, err)
|
||||
@@ -1861,9 +1862,10 @@ func (am *DefaultAccountManager) onPeersInvalidated(accountID string) {
|
||||
func addAllGroup(account *Account) error {
|
||||
if len(account.Groups) == 0 {
|
||||
allGroup := &nbgroup.Group{
|
||||
ID: xid.New().String(),
|
||||
Name: "All",
|
||||
Issued: nbgroup.GroupIssuedAPI,
|
||||
ID: xid.New().String(),
|
||||
Name: "All",
|
||||
Issued: nbgroup.GroupIssuedAPI,
|
||||
AccountID: account.Id,
|
||||
}
|
||||
for _, peer := range account.Peers {
|
||||
allGroup.Peers = append(allGroup.Peers, peer.ID)
|
||||
@@ -1907,7 +1909,7 @@ func newAccountWithId(accountID, userID, domain string) *Account {
|
||||
routes := make(map[string]*route.Route)
|
||||
setupKeys := map[string]*SetupKey{}
|
||||
nameServersGroups := make(map[string]*nbdns.NameServerGroup)
|
||||
users[userID] = NewOwnerUser(userID)
|
||||
users[userID] = NewOwnerUser(userID, accountID)
|
||||
dnsSettings := DNSSettings{
|
||||
DisabledManagementGroups: make([]string, 0),
|
||||
}
|
||||
|
||||
@@ -11,133 +11,134 @@ type Code struct {
|
||||
Code string
|
||||
}
|
||||
|
||||
// Existing consts must not be changed, as this will break the compatibility with the existing data
|
||||
const (
|
||||
// PeerAddedByUser indicates that a user added a new peer to the system
|
||||
PeerAddedByUser Activity = iota
|
||||
PeerAddedByUser Activity = 0
|
||||
// PeerAddedWithSetupKey indicates that a new peer joined the system using a setup key
|
||||
PeerAddedWithSetupKey
|
||||
PeerAddedWithSetupKey Activity = 1
|
||||
// UserJoined indicates that a new user joined the account
|
||||
UserJoined
|
||||
UserJoined Activity = 2
|
||||
// UserInvited indicates that a new user was invited to join the account
|
||||
UserInvited
|
||||
UserInvited Activity = 3
|
||||
// AccountCreated indicates that a new account has been created
|
||||
AccountCreated
|
||||
AccountCreated Activity = 4
|
||||
// PeerRemovedByUser indicates that a user removed a peer from the system
|
||||
PeerRemovedByUser
|
||||
PeerRemovedByUser Activity = 5
|
||||
// RuleAdded indicates that a user added a new rule
|
||||
RuleAdded
|
||||
RuleAdded Activity = 6
|
||||
// RuleUpdated indicates that a user updated a rule
|
||||
RuleUpdated
|
||||
RuleUpdated Activity = 7
|
||||
// RuleRemoved indicates that a user removed a rule
|
||||
RuleRemoved
|
||||
RuleRemoved Activity = 8
|
||||
// PolicyAdded indicates that a user added a new policy
|
||||
PolicyAdded
|
||||
PolicyAdded Activity = 9
|
||||
// PolicyUpdated indicates that a user updated a policy
|
||||
PolicyUpdated
|
||||
PolicyUpdated Activity = 10
|
||||
// PolicyRemoved indicates that a user removed a policy
|
||||
PolicyRemoved
|
||||
PolicyRemoved Activity = 11
|
||||
// SetupKeyCreated indicates that a user created a new setup key
|
||||
SetupKeyCreated
|
||||
SetupKeyCreated Activity = 12
|
||||
// SetupKeyUpdated indicates that a user updated a setup key
|
||||
SetupKeyUpdated
|
||||
SetupKeyUpdated Activity = 13
|
||||
// SetupKeyRevoked indicates that a user revoked a setup key
|
||||
SetupKeyRevoked
|
||||
SetupKeyRevoked Activity = 14
|
||||
// SetupKeyOverused indicates that setup key usage exhausted
|
||||
SetupKeyOverused
|
||||
SetupKeyOverused Activity = 15
|
||||
// GroupCreated indicates that a user created a group
|
||||
GroupCreated
|
||||
GroupCreated Activity = 16
|
||||
// GroupUpdated indicates that a user updated a group
|
||||
GroupUpdated
|
||||
GroupUpdated Activity = 17
|
||||
// GroupAddedToPeer indicates that a user added group to a peer
|
||||
GroupAddedToPeer
|
||||
GroupAddedToPeer Activity = 18
|
||||
// GroupRemovedFromPeer indicates that a user removed peer group
|
||||
GroupRemovedFromPeer
|
||||
GroupRemovedFromPeer Activity = 19
|
||||
// GroupAddedToUser indicates that a user added group to a user
|
||||
GroupAddedToUser
|
||||
GroupAddedToUser Activity = 20
|
||||
// GroupRemovedFromUser indicates that a user removed a group from a user
|
||||
GroupRemovedFromUser
|
||||
GroupRemovedFromUser Activity = 21
|
||||
// UserRoleUpdated indicates that a user changed the role of a user
|
||||
UserRoleUpdated
|
||||
UserRoleUpdated Activity = 22
|
||||
// GroupAddedToSetupKey indicates that a user added group to a setup key
|
||||
GroupAddedToSetupKey
|
||||
GroupAddedToSetupKey Activity = 23
|
||||
// GroupRemovedFromSetupKey indicates that a user removed a group from a setup key
|
||||
GroupRemovedFromSetupKey
|
||||
GroupRemovedFromSetupKey Activity = 24
|
||||
// GroupAddedToDisabledManagementGroups indicates that a user added a group to the DNS setting Disabled management groups
|
||||
GroupAddedToDisabledManagementGroups
|
||||
GroupAddedToDisabledManagementGroups Activity = 25
|
||||
// GroupRemovedFromDisabledManagementGroups indicates that a user removed a group from the DNS setting Disabled management groups
|
||||
GroupRemovedFromDisabledManagementGroups
|
||||
GroupRemovedFromDisabledManagementGroups Activity = 26
|
||||
// RouteCreated indicates that a user created a route
|
||||
RouteCreated
|
||||
RouteCreated Activity = 27
|
||||
// RouteRemoved indicates that a user deleted a route
|
||||
RouteRemoved
|
||||
RouteRemoved Activity = 28
|
||||
// RouteUpdated indicates that a user updated a route
|
||||
RouteUpdated
|
||||
RouteUpdated Activity = 29
|
||||
// PeerSSHEnabled indicates that a user enabled SSH server on a peer
|
||||
PeerSSHEnabled
|
||||
PeerSSHEnabled Activity = 30
|
||||
// PeerSSHDisabled indicates that a user disabled SSH server on a peer
|
||||
PeerSSHDisabled
|
||||
PeerSSHDisabled Activity = 31
|
||||
// PeerRenamed indicates that a user renamed a peer
|
||||
PeerRenamed
|
||||
PeerRenamed Activity = 32
|
||||
// PeerLoginExpirationEnabled indicates that a user enabled login expiration of a peer
|
||||
PeerLoginExpirationEnabled
|
||||
PeerLoginExpirationEnabled Activity = 33
|
||||
// PeerLoginExpirationDisabled indicates that a user disabled login expiration of a peer
|
||||
PeerLoginExpirationDisabled
|
||||
PeerLoginExpirationDisabled Activity = 34
|
||||
// NameserverGroupCreated indicates that a user created a nameservers group
|
||||
NameserverGroupCreated
|
||||
NameserverGroupCreated Activity = 35
|
||||
// NameserverGroupDeleted indicates that a user deleted a nameservers group
|
||||
NameserverGroupDeleted
|
||||
NameserverGroupDeleted Activity = 36
|
||||
// NameserverGroupUpdated indicates that a user updated a nameservers group
|
||||
NameserverGroupUpdated
|
||||
NameserverGroupUpdated Activity = 37
|
||||
// AccountPeerLoginExpirationEnabled indicates that a user enabled peer login expiration for the account
|
||||
AccountPeerLoginExpirationEnabled
|
||||
AccountPeerLoginExpirationEnabled Activity = 38
|
||||
// AccountPeerLoginExpirationDisabled indicates that a user disabled peer login expiration for the account
|
||||
AccountPeerLoginExpirationDisabled
|
||||
AccountPeerLoginExpirationDisabled Activity = 39
|
||||
// AccountPeerLoginExpirationDurationUpdated indicates that a user updated peer login expiration duration for the account
|
||||
AccountPeerLoginExpirationDurationUpdated
|
||||
AccountPeerLoginExpirationDurationUpdated Activity = 40
|
||||
// PersonalAccessTokenCreated indicates that a user created a personal access token
|
||||
PersonalAccessTokenCreated
|
||||
PersonalAccessTokenCreated Activity = 41
|
||||
// PersonalAccessTokenDeleted indicates that a user deleted a personal access token
|
||||
PersonalAccessTokenDeleted
|
||||
PersonalAccessTokenDeleted Activity = 42
|
||||
// ServiceUserCreated indicates that a user created a service user
|
||||
ServiceUserCreated
|
||||
ServiceUserCreated Activity = 43
|
||||
// ServiceUserDeleted indicates that a user deleted a service user
|
||||
ServiceUserDeleted
|
||||
ServiceUserDeleted Activity = 44
|
||||
// UserBlocked indicates that a user blocked another user
|
||||
UserBlocked
|
||||
UserBlocked Activity = 45
|
||||
// UserUnblocked indicates that a user unblocked another user
|
||||
UserUnblocked
|
||||
UserUnblocked Activity = 46
|
||||
// UserDeleted indicates that a user deleted another user
|
||||
UserDeleted
|
||||
UserDeleted Activity = 47
|
||||
// GroupDeleted indicates that a user deleted group
|
||||
GroupDeleted
|
||||
GroupDeleted Activity = 48
|
||||
// UserLoggedInPeer indicates that user logged in their peer with an interactive SSO login
|
||||
UserLoggedInPeer
|
||||
UserLoggedInPeer Activity = 49
|
||||
// PeerLoginExpired indicates that the user peer login has been expired and peer disconnected
|
||||
PeerLoginExpired
|
||||
PeerLoginExpired Activity = 50
|
||||
// DashboardLogin indicates that the user logged in to the dashboard
|
||||
DashboardLogin
|
||||
DashboardLogin Activity = 51
|
||||
// IntegrationCreated indicates that the user created an integration
|
||||
IntegrationCreated
|
||||
IntegrationCreated Activity = 52
|
||||
// IntegrationUpdated indicates that the user updated an integration
|
||||
IntegrationUpdated
|
||||
IntegrationUpdated Activity = 53
|
||||
// IntegrationDeleted indicates that the user deleted an integration
|
||||
IntegrationDeleted
|
||||
IntegrationDeleted Activity = 54
|
||||
// AccountPeerApprovalEnabled indicates that the user enabled peer approval for the account
|
||||
AccountPeerApprovalEnabled
|
||||
AccountPeerApprovalEnabled Activity = 55
|
||||
// AccountPeerApprovalDisabled indicates that the user disabled peer approval for the account
|
||||
AccountPeerApprovalDisabled
|
||||
AccountPeerApprovalDisabled Activity = 56
|
||||
// PeerApproved indicates that the peer has been approved
|
||||
PeerApproved
|
||||
PeerApproved Activity = 57
|
||||
// PeerApprovalRevoked indicates that the peer approval has been revoked
|
||||
PeerApprovalRevoked
|
||||
PeerApprovalRevoked Activity = 58
|
||||
// TransferredOwnerRole indicates that the user transferred the owner role of the account
|
||||
TransferredOwnerRole
|
||||
TransferredOwnerRole Activity = 59
|
||||
// PostureCheckCreated indicates that the user created a posture check
|
||||
PostureCheckCreated
|
||||
PostureCheckCreated Activity = 60
|
||||
// PostureCheckUpdated indicates that the user updated a posture check
|
||||
PostureCheckUpdated
|
||||
PostureCheckUpdated Activity = 61
|
||||
// PostureCheckDeleted indicates that the user deleted a posture check
|
||||
PostureCheckDeleted
|
||||
PostureCheckDeleted Activity = 62
|
||||
)
|
||||
|
||||
var activityMap = map[Activity]Code{
|
||||
|
||||
@@ -54,7 +54,7 @@ func initAccountsTestData(account *server.Account, admin *server.User) *Accounts
|
||||
|
||||
func TestAccounts_AccountsHandler(t *testing.T) {
|
||||
accountID := "test_account"
|
||||
adminUser := server.NewAdminUser("test_user")
|
||||
adminUser := server.NewAdminUser("test_user", "account_id")
|
||||
|
||||
sr := func(v string) *string { return &v }
|
||||
br := func(v bool) *bool { return &v }
|
||||
|
||||
@@ -34,7 +34,7 @@ var testingDNSSettingsAccount = &server.Account{
|
||||
Id: testDNSSettingsAccountID,
|
||||
Domain: "hotmail.com",
|
||||
Users: map[string]*server.User{
|
||||
testDNSSettingsUserID: server.NewAdminUser("test_user"),
|
||||
testDNSSettingsUserID: server.NewAdminUser("test_user", "account_id"),
|
||||
},
|
||||
DNSSettings: baseExistingDNSSettings,
|
||||
}
|
||||
|
||||
@@ -196,7 +196,7 @@ func TestEvents_GetEvents(t *testing.T) {
|
||||
},
|
||||
}
|
||||
accountID := "test_account"
|
||||
adminUser := server.NewAdminUser("test_user")
|
||||
adminUser := server.NewAdminUser("test_user", "account_id")
|
||||
events := generateEvents(accountID, adminUser.Id)
|
||||
handler := initEventsTestData(accountID, adminUser, events...)
|
||||
|
||||
|
||||
@@ -42,7 +42,7 @@ func initGeolocationTestData(t *testing.T) *GeolocationsHandler {
|
||||
return &GeolocationsHandler{
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
user := server.NewAdminUser("test_user")
|
||||
user := server.NewAdminUser("test_user", "account_id")
|
||||
return &server.Account{
|
||||
Id: claims.AccountId,
|
||||
Users: map[string]*server.User{
|
||||
|
||||
@@ -124,7 +124,7 @@ func TestGetGroup(t *testing.T) {
|
||||
Name: "Group",
|
||||
}
|
||||
|
||||
adminUser := server.NewAdminUser("test_user")
|
||||
adminUser := server.NewAdminUser("test_user", "account_id")
|
||||
p := initGroupTestData(adminUser, group)
|
||||
|
||||
for _, tc := range tt {
|
||||
@@ -246,7 +246,7 @@ func TestWriteGroup(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
adminUser := server.NewAdminUser("test_user")
|
||||
adminUser := server.NewAdminUser("test_user", "account_id")
|
||||
p := initGroupTestData(adminUser)
|
||||
|
||||
for _, tc := range tt {
|
||||
@@ -324,7 +324,7 @@ func TestDeleteGroup(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
adminUser := server.NewAdminUser("test_user")
|
||||
adminUser := server.NewAdminUser("test_user", "account_id")
|
||||
p := initGroupTestData(adminUser)
|
||||
|
||||
for _, tc := range tt {
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
s "github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
"github.com/netbirdio/netbird/management/server/http/middleware"
|
||||
"github.com/netbirdio/netbird/management/server/integrated_validator"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
@@ -38,7 +39,7 @@ type emptyObject struct {
|
||||
}
|
||||
|
||||
// APIHandler creates the Management service HTTP API handler registering all the available endpoints.
|
||||
func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationManager *geolocation.Geolocation, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg AuthCfg) (http.Handler, error) {
|
||||
func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationManager *geolocation.Geolocation, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg AuthCfg, integratedValidator integrated_validator.IntegratedValidator) (http.Handler, error) {
|
||||
claimsExtractor := jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithAudience(authCfg.Audience),
|
||||
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
|
||||
@@ -75,7 +76,7 @@ func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationMa
|
||||
AuthCfg: authCfg,
|
||||
}
|
||||
|
||||
if _, err := integrations.RegisterHandlers(ctx, prefix, api.Router, accountManager, claimsExtractor); err != nil {
|
||||
if _, err := integrations.RegisterHandlers(ctx, prefix, api.Router, accountManager, claimsExtractor, integratedValidator); err != nil {
|
||||
return nil, fmt.Errorf("register integrations endpoints: %w", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -32,7 +32,7 @@ var testingNSAccount = &server.Account{
|
||||
Id: testNSGroupAccountID,
|
||||
Domain: "hotmail.com",
|
||||
Users: map[string]*server.User{
|
||||
"test_user": server.NewAdminUser("test_user"),
|
||||
"test_user": server.NewAdminUser("test_user", "account_id"),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -59,7 +59,7 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
|
||||
return "netbird.selfhosted"
|
||||
},
|
||||
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
user := server.NewAdminUser("test_user")
|
||||
user := server.NewAdminUser("test_user", "account_id")
|
||||
return &server.Account{
|
||||
Id: claims.AccountId,
|
||||
Domain: "hotmail.com",
|
||||
|
||||
@@ -45,7 +45,7 @@ func initPoliciesTestData(policies ...*server.Policy) *Policies {
|
||||
return nil
|
||||
},
|
||||
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
user := server.NewAdminUser("test_user")
|
||||
user := server.NewAdminUser("test_user", "account_id")
|
||||
return &server.Account{
|
||||
Id: claims.AccountId,
|
||||
Domain: "hotmail.com",
|
||||
|
||||
@@ -62,7 +62,7 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH
|
||||
return accountPostureChecks, nil
|
||||
},
|
||||
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
user := server.NewAdminUser("test_user")
|
||||
user := server.NewAdminUser("test_user", "account_id")
|
||||
return &server.Account{
|
||||
Id: claims.AccountId,
|
||||
Users: map[string]*server.User{
|
||||
|
||||
@@ -75,7 +75,7 @@ var testingAccount = &server.Account{
|
||||
},
|
||||
},
|
||||
Users: map[string]*server.User{
|
||||
"test_user": server.NewAdminUser("test_user"),
|
||||
"test_user": server.NewAdminUser("test_user", "account_id"),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -97,7 +97,7 @@ func TestSetupKeysHandlers(t *testing.T) {
|
||||
defaultSetupKey := server.GenerateDefaultSetupKey()
|
||||
defaultSetupKey.Id = existingSetupKeyID
|
||||
|
||||
adminUser := server.NewAdminUser("test_user")
|
||||
adminUser := server.NewAdminUser("test_user", "account_id")
|
||||
|
||||
newSetupKey := server.GenerateSetupKey(newSetupKeyName, server.SetupKeyReusable, 0, []string{"group-1"},
|
||||
server.SetupKeyUnlimitedUsage, true)
|
||||
|
||||
@@ -95,18 +95,18 @@ func (wm *DefaultScheduler) Schedule(in time.Duration, ID string, job func() (ne
|
||||
case <-ticker.C:
|
||||
select {
|
||||
case <-cancel:
|
||||
log.Debugf("scheduled job %s was canceled, stop timer", ID)
|
||||
log.Tracef("scheduled job %s was canceled, stop timer", ID)
|
||||
ticker.Stop()
|
||||
return
|
||||
default:
|
||||
log.Debugf("time to do a scheduled job %s", ID)
|
||||
log.Tracef("time to do a scheduled job %s", ID)
|
||||
}
|
||||
runIn, reschedule := job()
|
||||
if !reschedule {
|
||||
wm.mu.Lock()
|
||||
defer wm.mu.Unlock()
|
||||
delete(wm.jobs, ID)
|
||||
log.Debugf("job %s is not scheduled to run again", ID)
|
||||
log.Tracef("job %s is not scheduled to run again", ID)
|
||||
ticker.Stop()
|
||||
return
|
||||
}
|
||||
@@ -115,7 +115,7 @@ func (wm *DefaultScheduler) Schedule(in time.Duration, ID string, job func() (ne
|
||||
ticker.Reset(runIn)
|
||||
}
|
||||
case <-cancel:
|
||||
log.Debugf("job %s was canceled, stopping timer", ID)
|
||||
log.Tracef("job %s was canceled, stopping timer", ID)
|
||||
ticker.Stop()
|
||||
return
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -134,72 +135,139 @@ func (s *SqliteStore) AcquireAccountLock(accountID string) (unlock func()) {
|
||||
return unlock
|
||||
}
|
||||
|
||||
func batchInsert(records interface{}, batchSize int, tx *gorm.DB) error {
|
||||
// Get the reflect.Value of the records slice
|
||||
v := reflect.ValueOf(records)
|
||||
if v.Kind() != reflect.Slice {
|
||||
return fmt.Errorf("provided input is not a slice")
|
||||
}
|
||||
|
||||
// Insert records in batches
|
||||
for i := 0; i < v.Len(); i += batchSize {
|
||||
end := i + batchSize
|
||||
if end > v.Len() {
|
||||
end = v.Len()
|
||||
}
|
||||
// Use reflect.Slice to get a slice of the records for the current batch
|
||||
batch := v.Slice(i, end).Interface()
|
||||
if err := tx.CreateInBatches(batch, end-i).Debug().Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqliteStore) SaveAccount(account *Account) error {
|
||||
start := time.Now()
|
||||
|
||||
for _, key := range account.SetupKeys {
|
||||
account.SetupKeysG = append(account.SetupKeysG, *key)
|
||||
// operate over a fresh copy as we will modify its fields
|
||||
accCopy := account.Copy()
|
||||
accCopy.SetupKeysG = make([]SetupKey, 0, len(accCopy.SetupKeys))
|
||||
for _, key := range accCopy.SetupKeys {
|
||||
//we need an explicit reference to the account for gorm
|
||||
key.AccountID = accCopy.Id
|
||||
accCopy.SetupKeysG = append(accCopy.SetupKeysG, *key)
|
||||
}
|
||||
|
||||
for id, peer := range account.Peers {
|
||||
accCopy.PeersG = make([]nbpeer.Peer, 0, len(accCopy.Peers))
|
||||
for id, peer := range accCopy.Peers {
|
||||
peer.ID = id
|
||||
account.PeersG = append(account.PeersG, *peer)
|
||||
//we need an explicit reference to the account for gorm
|
||||
peer.AccountID = accCopy.Id
|
||||
accCopy.PeersG = append(accCopy.PeersG, *peer)
|
||||
}
|
||||
|
||||
for id, user := range account.Users {
|
||||
accCopy.UsersG = make([]User, 0, len(accCopy.Users))
|
||||
for id, user := range accCopy.Users {
|
||||
user.Id = id
|
||||
//we need an explicit reference to the account for gorm
|
||||
user.AccountID = accCopy.Id
|
||||
user.PATsG = make([]PersonalAccessToken, 0, len(user.PATs))
|
||||
for id, pat := range user.PATs {
|
||||
pat.ID = id
|
||||
user.PATsG = append(user.PATsG, *pat)
|
||||
}
|
||||
account.UsersG = append(account.UsersG, *user)
|
||||
accCopy.UsersG = append(accCopy.UsersG, *user)
|
||||
}
|
||||
|
||||
for id, group := range account.Groups {
|
||||
accCopy.GroupsG = make([]nbgroup.Group, 0, len(accCopy.Groups))
|
||||
for id, group := range accCopy.Groups {
|
||||
group.ID = id
|
||||
account.GroupsG = append(account.GroupsG, *group)
|
||||
//we need an explicit reference to the account for gorm
|
||||
group.AccountID = accCopy.Id
|
||||
accCopy.GroupsG = append(accCopy.GroupsG, *group)
|
||||
}
|
||||
|
||||
for id, route := range account.Routes {
|
||||
accCopy.RoutesG = make([]route.Route, 0, len(accCopy.Routes))
|
||||
for id, route := range accCopy.Routes {
|
||||
route.ID = id
|
||||
account.RoutesG = append(account.RoutesG, *route)
|
||||
//we need an explicit reference to the account for gorm
|
||||
route.AccountID = accCopy.Id
|
||||
accCopy.RoutesG = append(accCopy.RoutesG, *route)
|
||||
}
|
||||
|
||||
for id, ns := range account.NameServerGroups {
|
||||
accCopy.NameServerGroupsG = make([]nbdns.NameServerGroup, 0, len(accCopy.NameServerGroups))
|
||||
for id, ns := range accCopy.NameServerGroups {
|
||||
ns.ID = id
|
||||
account.NameServerGroupsG = append(account.NameServerGroupsG, *ns)
|
||||
//we need an explicit reference to the account for gorm
|
||||
ns.AccountID = accCopy.Id
|
||||
accCopy.NameServerGroupsG = append(accCopy.NameServerGroupsG, *ns)
|
||||
}
|
||||
|
||||
err := s.db.Transaction(func(tx *gorm.DB) error {
|
||||
result := tx.Select(clause.Associations).Delete(account.Policies, "account_id = ?", account.Id)
|
||||
result := tx.Select(clause.Associations).Delete(accCopy.Policies, "account_id = ?", accCopy.Id)
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
result = tx.Select(clause.Associations).Delete(account.UsersG, "account_id = ?", account.Id)
|
||||
result = tx.Select(clause.Associations).Delete(accCopy.UsersG, "account_id = ?", accCopy.Id)
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
result = tx.Select(clause.Associations).Delete(account)
|
||||
result = tx.Select(clause.Associations).Delete(accCopy)
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
result = tx.
|
||||
Session(&gorm.Session{FullSaveAssociations: true}).
|
||||
Clauses(clause.OnConflict{UpdateAll: true}).Create(account)
|
||||
Clauses(clause.OnConflict{UpdateAll: true}).
|
||||
Omit("PeersG", "GroupsG", "UsersG", "SetupKeysG", "RoutesG", "NameServerGroupsG").
|
||||
Create(accCopy)
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
return nil
|
||||
|
||||
const batchSize = 500
|
||||
err := batchInsert(accCopy.PeersG, batchSize, tx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = batchInsert(accCopy.UsersG, batchSize, tx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = batchInsert(accCopy.GroupsG, batchSize, tx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = batchInsert(accCopy.RoutesG, batchSize, tx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = batchInsert(accCopy.SetupKeysG, batchSize, tx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return batchInsert(accCopy.NameServerGroupsG, batchSize, tx)
|
||||
})
|
||||
|
||||
took := time.Since(start)
|
||||
if s.metrics != nil {
|
||||
s.metrics.StoreMetrics().CountPersistenceDuration(took)
|
||||
}
|
||||
log.Debugf("took %d ms to persist an account to the SQLite", took.Milliseconds())
|
||||
log.Debugf("took %d ms to persist an account %s to the SQLite store", took.Milliseconds(), accCopy.Id)
|
||||
|
||||
return err
|
||||
}
|
||||
@@ -207,6 +275,19 @@ func (s *SqliteStore) SaveAccount(account *Account) error {
|
||||
func (s *SqliteStore) DeleteAccount(account *Account) error {
|
||||
start := time.Now()
|
||||
|
||||
account.UsersG = make([]User, 0, len(account.Users))
|
||||
for id, user := range account.Users {
|
||||
user.Id = id
|
||||
//we need an explicit reference to an account as it is missing for some reason
|
||||
user.AccountID = account.Id
|
||||
user.PATsG = make([]PersonalAccessToken, 0, len(user.PATs))
|
||||
for id, pat := range user.PATs {
|
||||
pat.ID = id
|
||||
user.PATsG = append(user.PATsG, *pat)
|
||||
}
|
||||
account.UsersG = append(account.UsersG, *user)
|
||||
}
|
||||
|
||||
err := s.db.Transaction(func(tx *gorm.DB) error {
|
||||
result := tx.Select(clause.Associations).Delete(account.Policies, "account_id = ?", account.Id)
|
||||
if result.Error != nil {
|
||||
|
||||
@@ -2,7 +2,12 @@ package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
route2 "github.com/netbirdio/netbird/route"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/netip"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
@@ -29,6 +34,141 @@ func TestSqlite_NewStore(t *testing.T) {
|
||||
t.Errorf("expected to create a new empty Accounts map when creating a new FileStore")
|
||||
}
|
||||
}
|
||||
func TestSqlite_SaveAccount_Large(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||
}
|
||||
|
||||
store := newSqliteStore(t)
|
||||
|
||||
account := newAccountWithId("account_id", "testuser", "")
|
||||
groupALL, err := account.GetGroupAll()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
setupKey := GenerateDefaultSetupKey()
|
||||
account.SetupKeys[setupKey.Key] = setupKey
|
||||
const numPerAccount = 2000
|
||||
for n := 0; n < numPerAccount; n++ {
|
||||
netIP := randomIPv4()
|
||||
peerID := fmt.Sprintf("%s-peer-%d", account.Id, n)
|
||||
|
||||
peer := &nbpeer.Peer{
|
||||
ID: peerID,
|
||||
Key: peerID,
|
||||
SetupKey: "",
|
||||
IP: netIP,
|
||||
Name: peerID,
|
||||
DNSLabel: peerID,
|
||||
UserID: userID,
|
||||
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()},
|
||||
SSHEnabled: false,
|
||||
}
|
||||
account.Peers[peerID] = peer
|
||||
group, _ := account.GetGroupAll()
|
||||
group.Peers = append(group.Peers, peerID)
|
||||
user := &User{
|
||||
Id: fmt.Sprintf("%s-user-%d", account.Id, n),
|
||||
AccountID: account.Id,
|
||||
}
|
||||
account.Users[user.Id] = user
|
||||
route := &route2.Route{
|
||||
ID: fmt.Sprintf("network-id-%d", n),
|
||||
Description: "base route",
|
||||
NetID: fmt.Sprintf("network-id-%d", n),
|
||||
Network: netip.MustParsePrefix(netIP.String() + "/24"),
|
||||
NetworkType: route2.IPv4Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
Enabled: true,
|
||||
Groups: []string{groupALL.ID},
|
||||
}
|
||||
account.Routes[route.ID] = route
|
||||
|
||||
group = &nbgroup.Group{
|
||||
ID: fmt.Sprintf("group-id-%d", n),
|
||||
AccountID: account.Id,
|
||||
Name: fmt.Sprintf("group-id-%d", n),
|
||||
Issued: "api",
|
||||
Peers: nil,
|
||||
}
|
||||
account.Groups[group.ID] = group
|
||||
|
||||
nameserver := &nbdns.NameServerGroup{
|
||||
ID: fmt.Sprintf("nameserver-id-%d", n),
|
||||
AccountID: account.Id,
|
||||
Name: fmt.Sprintf("nameserver-id-%d", n),
|
||||
Description: "",
|
||||
NameServers: []nbdns.NameServer{{IP: netip.MustParseAddr(netIP.String()), NSType: nbdns.UDPNameServerType}},
|
||||
Groups: []string{group.ID},
|
||||
Primary: false,
|
||||
Domains: nil,
|
||||
Enabled: false,
|
||||
SearchDomainsEnabled: false,
|
||||
}
|
||||
account.NameServerGroups[nameserver.ID] = nameserver
|
||||
|
||||
setupKey := GenerateDefaultSetupKey()
|
||||
account.SetupKeys[setupKey.Key] = setupKey
|
||||
}
|
||||
|
||||
err = store.SaveAccount(account)
|
||||
require.NoError(t, err)
|
||||
|
||||
if len(store.GetAllAccounts()) != 1 {
|
||||
t.Errorf("expecting 1 Accounts to be stored after SaveAccount()")
|
||||
}
|
||||
|
||||
a, err := store.GetAccount(account.Id)
|
||||
if a == nil {
|
||||
t.Errorf("expecting Account to be stored after SaveAccount(): %v", err)
|
||||
}
|
||||
|
||||
if a != nil && len(a.Policies) != 1 {
|
||||
t.Errorf("expecting Account to have one policy stored after SaveAccount(), got %d", len(a.Policies))
|
||||
}
|
||||
|
||||
if a != nil && len(a.Policies[0].Rules) != 1 {
|
||||
t.Errorf("expecting Account to have one policy rule stored after SaveAccount(), got %d", len(a.Policies[0].Rules))
|
||||
return
|
||||
}
|
||||
|
||||
if a != nil && len(a.Peers) != numPerAccount {
|
||||
t.Errorf("expecting Account to have %d peers stored after SaveAccount(), got %d",
|
||||
numPerAccount, len(a.Peers))
|
||||
return
|
||||
}
|
||||
|
||||
if a != nil && len(a.Users) != numPerAccount+1 {
|
||||
t.Errorf("expecting Account to have %d users stored after SaveAccount(), got %d",
|
||||
numPerAccount+1, len(a.Users))
|
||||
return
|
||||
}
|
||||
|
||||
if a != nil && len(a.Routes) != numPerAccount {
|
||||
t.Errorf("expecting Account to have %d routes stored after SaveAccount(), got %d",
|
||||
numPerAccount, len(a.Routes))
|
||||
return
|
||||
}
|
||||
|
||||
if a != nil && len(a.NameServerGroups) != numPerAccount {
|
||||
t.Errorf("expecting Account to have %d NameServerGroups stored after SaveAccount(), got %d",
|
||||
numPerAccount, len(a.NameServerGroups))
|
||||
return
|
||||
}
|
||||
|
||||
if a != nil && len(a.NameServerGroups) != numPerAccount {
|
||||
t.Errorf("expecting Account to have %d NameServerGroups stored after SaveAccount(), got %d",
|
||||
numPerAccount, len(a.NameServerGroups))
|
||||
return
|
||||
}
|
||||
|
||||
if a != nil && len(a.SetupKeys) != numPerAccount+1 {
|
||||
t.Errorf("expecting Account to have %d SetupKeys stored after SaveAccount(), got %d",
|
||||
numPerAccount+1, len(a.SetupKeys))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlite_SaveAccount(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
@@ -48,6 +188,12 @@ func TestSqlite_SaveAccount(t *testing.T) {
|
||||
Name: "peer name",
|
||||
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
|
||||
}
|
||||
admin := account.Users["testuser"]
|
||||
admin.PATs = map[string]*PersonalAccessToken{"testtoken": {
|
||||
ID: "testtoken",
|
||||
Name: "test token",
|
||||
HashedToken: "hashed token",
|
||||
}}
|
||||
|
||||
err := store.SaveAccount(account)
|
||||
require.NoError(t, err)
|
||||
@@ -110,7 +256,7 @@ func TestSqlite_DeleteAccount(t *testing.T) {
|
||||
store := newSqliteStore(t)
|
||||
|
||||
testUserID := "testuser"
|
||||
user := NewAdminUser(testUserID)
|
||||
user := NewAdminUser(testUserID, "account_id")
|
||||
user.PATs = map[string]*PersonalAccessToken{"testtoken": {
|
||||
ID: "testtoken",
|
||||
Name: "test token",
|
||||
@@ -393,3 +539,12 @@ func newAccount(store Store, id int) error {
|
||||
|
||||
return store.SaveAccount(account)
|
||||
}
|
||||
|
||||
func randomIPv4() net.IP {
|
||||
rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
b := make([]byte, 4)
|
||||
for i := range b {
|
||||
b[i] = byte(rand.Intn(256))
|
||||
}
|
||||
return net.IP(b)
|
||||
}
|
||||
|
||||
@@ -180,9 +180,11 @@ func (u *User) Copy() *User {
|
||||
}
|
||||
|
||||
// NewUser creates a new user
|
||||
func NewUser(id string, role UserRole, isServiceUser bool, nonDeletable bool, serviceUserName string, autoGroups []string, issued string) *User {
|
||||
func NewUser(ID string, role UserRole, isServiceUser bool, nonDeletable bool, serviceUserName string, autoGroups []string, issued string,
|
||||
accountID string) *User {
|
||||
return &User{
|
||||
Id: id,
|
||||
Id: ID,
|
||||
AccountID: accountID,
|
||||
Role: role,
|
||||
IsServiceUser: isServiceUser,
|
||||
NonDeletable: nonDeletable,
|
||||
@@ -194,22 +196,26 @@ func NewUser(id string, role UserRole, isServiceUser bool, nonDeletable bool, se
|
||||
}
|
||||
|
||||
// NewRegularUser creates a new user with role UserRoleUser
|
||||
func NewRegularUser(id string) *User {
|
||||
return NewUser(id, UserRoleUser, false, false, "", []string{}, UserIssuedAPI)
|
||||
func NewRegularUser(ID, accountID string) *User {
|
||||
return NewUser(ID, UserRoleUser, false, false, "", []string{}, UserIssuedAPI,
|
||||
accountID)
|
||||
}
|
||||
|
||||
// NewAdminUser creates a new user with role UserRoleAdmin
|
||||
func NewAdminUser(id string) *User {
|
||||
return NewUser(id, UserRoleAdmin, false, false, "", []string{}, UserIssuedAPI)
|
||||
func NewAdminUser(ID, accountID string) *User {
|
||||
return NewUser(ID, UserRoleAdmin, false, false, "", []string{}, UserIssuedAPI,
|
||||
accountID)
|
||||
}
|
||||
|
||||
// NewOwnerUser creates a new user with role UserRoleOwner
|
||||
func NewOwnerUser(id string) *User {
|
||||
return NewUser(id, UserRoleOwner, false, false, "", []string{}, UserIssuedAPI)
|
||||
func NewOwnerUser(ID, accountID string) *User {
|
||||
return NewUser(ID, UserRoleOwner, false, false, "", []string{}, UserIssuedAPI,
|
||||
accountID)
|
||||
}
|
||||
|
||||
// createServiceUser creates a new service user under the given account.
|
||||
func (am *DefaultAccountManager) createServiceUser(accountID string, initiatorUserID string, role UserRole, serviceUserName string, nonDeletable bool, autoGroups []string) (*UserInfo, error) {
|
||||
func (am *DefaultAccountManager) createServiceUser(accountID string, initiatorUserID string, role UserRole,
|
||||
serviceUserName string, nonDeletable bool, autoGroups []string) (*UserInfo, error) {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
@@ -231,7 +237,7 @@ func (am *DefaultAccountManager) createServiceUser(accountID string, initiatorUs
|
||||
}
|
||||
|
||||
newUserID := uuid.New().String()
|
||||
newUser := NewUser(newUserID, role, true, nonDeletable, serviceUserName, autoGroups, UserIssuedAPI)
|
||||
newUser := NewUser(newUserID, role, true, nonDeletable, serviceUserName, autoGroups, UserIssuedAPI, accountID)
|
||||
log.Debugf("New User: %v", newUser)
|
||||
account.Users[newUserID] = newUser
|
||||
|
||||
|
||||
@@ -679,8 +679,8 @@ func TestDefaultAccountManager_GetUser(t *testing.T) {
|
||||
func TestDefaultAccountManager_ListUsers(t *testing.T) {
|
||||
store := newStore(t)
|
||||
account := newAccountWithId(mockAccountID, mockUserID, "")
|
||||
account.Users["normal_user1"] = NewRegularUser("normal_user1")
|
||||
account.Users["normal_user2"] = NewRegularUser("normal_user2")
|
||||
account.Users["normal_user1"] = NewRegularUser("normal_user1", mockAccountID)
|
||||
account.Users["normal_user2"] = NewRegularUser("normal_user2", mockAccountID)
|
||||
|
||||
err := store.SaveAccount(account)
|
||||
if err != nil {
|
||||
@@ -760,7 +760,7 @@ func TestDefaultAccountManager_ListUsers_DashboardPermissions(t *testing.T) {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
account := newAccountWithId(mockAccountID, mockUserID, "")
|
||||
account.Users["normal_user1"] = NewUser("normal_user1", testCase.role, false, false, "", []string{}, UserIssuedAPI)
|
||||
account.Users["normal_user1"] = NewUser("normal_user1", testCase.role, false, false, "", []string{}, UserIssuedAPI, mockAccountID)
|
||||
account.Settings.RegularUsersViewBlocked = testCase.limitedViewSettings
|
||||
delete(account.Users, mockUserID)
|
||||
|
||||
@@ -844,10 +844,10 @@ func TestDefaultAccountManager_ExternalCache(t *testing.T) {
|
||||
|
||||
func TestUser_IsAdmin(t *testing.T) {
|
||||
|
||||
user := NewAdminUser(mockUserID)
|
||||
user := NewAdminUser(mockUserID, mockAccountID)
|
||||
assert.True(t, user.HasAdminPower())
|
||||
|
||||
user = NewRegularUser(mockUserID)
|
||||
user = NewRegularUser(mockUserID, mockAccountID)
|
||||
assert.False(t, user.HasAdminPower())
|
||||
}
|
||||
|
||||
@@ -1055,8 +1055,8 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) {
|
||||
}
|
||||
|
||||
// create other users
|
||||
account.Users[regularUserID] = NewRegularUser(regularUserID)
|
||||
account.Users[adminUserID] = NewAdminUser(adminUserID)
|
||||
account.Users[regularUserID] = NewRegularUser(regularUserID, account.Id)
|
||||
account.Users[adminUserID] = NewAdminUser(adminUserID, account.Id)
|
||||
account.Users[serviceUserID] = &User{IsServiceUser: true, Id: serviceUserID, Role: UserRoleAdmin, ServiceUserName: "service"}
|
||||
err = manager.Store.SaveAccount(account)
|
||||
if err != nil {
|
||||
|
||||
@@ -56,7 +56,7 @@ func UnMarshalCredential(msg *proto.Message) (*Credential, error) {
|
||||
|
||||
// MarshalCredential marshal a Credential instance and returns a Message object
|
||||
func MarshalCredential(myKey wgtypes.Key, myPort int, remoteKey wgtypes.Key, credential *Credential, t proto.Body_Type,
|
||||
rosenpassPubKey []byte, rosenpassAddr, relayedAddress, serverRefIP string) (*proto.Message, error) {
|
||||
rosenpassPubKey []byte, rosenpassAddr string) (*proto.Message, error) {
|
||||
return &proto.Message{
|
||||
Key: myKey.PublicKey().String(),
|
||||
RemoteKey: remoteKey.String(),
|
||||
@@ -69,10 +69,6 @@ func MarshalCredential(myKey wgtypes.Key, myPort int, remoteKey wgtypes.Key, cre
|
||||
RosenpassPubKey: rosenpassPubKey,
|
||||
RosenpassServerAddr: rosenpassAddr,
|
||||
},
|
||||
Relay: &proto.Relay{
|
||||
RelayedAddress: relayedAddress,
|
||||
SrvRefAddress: serverRefIP,
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -215,21 +215,16 @@ type Body struct {
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
Type Body_Type `protobuf:"varint,1,opt,name=type,proto3,enum=signalexchange.Body_Type" json:"type,omitempty"`
|
||||
// these will be set in OFFER, ANSWER, CANDIDATE only
|
||||
Payload string `protobuf:"bytes,2,opt,name=payload,proto3" json:"payload,omitempty"`
|
||||
Type Body_Type `protobuf:"varint,1,opt,name=type,proto3,enum=signalexchange.Body_Type" json:"type,omitempty"`
|
||||
Payload string `protobuf:"bytes,2,opt,name=payload,proto3" json:"payload,omitempty"`
|
||||
// wgListenPort is an actual WireGuard listen port
|
||||
// these will be set in OFFER, ANSWER, CANDIDATE only
|
||||
WgListenPort uint32 `protobuf:"varint,3,opt,name=wgListenPort,proto3" json:"wgListenPort,omitempty"`
|
||||
// these will be set in OFFER, ANSWER, CANDIDATE only
|
||||
WgListenPort uint32 `protobuf:"varint,3,opt,name=wgListenPort,proto3" json:"wgListenPort,omitempty"`
|
||||
NetBirdVersion string `protobuf:"bytes,4,opt,name=netBirdVersion,proto3" json:"netBirdVersion,omitempty"`
|
||||
Mode *Mode `protobuf:"bytes,5,opt,name=mode,proto3" json:"mode,omitempty"`
|
||||
// featuresSupported list of supported features by the client of this protocol
|
||||
FeaturesSupported []uint32 `protobuf:"varint,6,rep,packed,name=featuresSupported,proto3" json:"featuresSupported,omitempty"`
|
||||
// RosenpassConfig is a Rosenpass config of the remote peer our peer tries to connect to
|
||||
// is this optional or mandatory?
|
||||
RosenpassConfig *RosenpassConfig `protobuf:"bytes,7,opt,name=rosenpassConfig,proto3" json:"rosenpassConfig,omitempty"`
|
||||
Relay *Relay `protobuf:"bytes,8,opt,name=relay,proto3" json:"relay,omitempty"`
|
||||
}
|
||||
|
||||
func (x *Body) Reset() {
|
||||
@@ -313,18 +308,13 @@ func (x *Body) GetRosenpassConfig() *RosenpassConfig {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *Body) GetRelay() *Relay {
|
||||
if x != nil {
|
||||
return x.Relay
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Mode indicates a connection mode
|
||||
type Mode struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
Direct *bool `protobuf:"varint,1,opt,name=direct,proto3,oneof" json:"direct,omitempty"`
|
||||
}
|
||||
|
||||
func (x *Mode) Reset() {
|
||||
@@ -359,59 +349,11 @@ func (*Mode) Descriptor() ([]byte, []int) {
|
||||
return file_signalexchange_proto_rawDescGZIP(), []int{3}
|
||||
}
|
||||
|
||||
type Relay struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
RelayedAddress string `protobuf:"bytes,1,opt,name=relayedAddress,proto3" json:"relayedAddress,omitempty"`
|
||||
SrvRefAddress string `protobuf:"bytes,2,opt,name=srvRefAddress,proto3" json:"srvRefAddress,omitempty"`
|
||||
}
|
||||
|
||||
func (x *Relay) Reset() {
|
||||
*x = Relay{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_signalexchange_proto_msgTypes[4]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
func (x *Mode) GetDirect() bool {
|
||||
if x != nil && x.Direct != nil {
|
||||
return *x.Direct
|
||||
}
|
||||
}
|
||||
|
||||
func (x *Relay) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*Relay) ProtoMessage() {}
|
||||
|
||||
func (x *Relay) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_signalexchange_proto_msgTypes[4]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use Relay.ProtoReflect.Descriptor instead.
|
||||
func (*Relay) Descriptor() ([]byte, []int) {
|
||||
return file_signalexchange_proto_rawDescGZIP(), []int{4}
|
||||
}
|
||||
|
||||
func (x *Relay) GetRelayedAddress() string {
|
||||
if x != nil {
|
||||
return x.RelayedAddress
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *Relay) GetSrvRefAddress() string {
|
||||
if x != nil {
|
||||
return x.SrvRefAddress
|
||||
}
|
||||
return ""
|
||||
return false
|
||||
}
|
||||
|
||||
type RosenpassConfig struct {
|
||||
@@ -427,7 +369,7 @@ type RosenpassConfig struct {
|
||||
func (x *RosenpassConfig) Reset() {
|
||||
*x = RosenpassConfig{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_signalexchange_proto_msgTypes[5]
|
||||
mi := &file_signalexchange_proto_msgTypes[4]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
@@ -440,7 +382,7 @@ func (x *RosenpassConfig) String() string {
|
||||
func (*RosenpassConfig) ProtoMessage() {}
|
||||
|
||||
func (x *RosenpassConfig) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_signalexchange_proto_msgTypes[5]
|
||||
mi := &file_signalexchange_proto_msgTypes[4]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
@@ -453,7 +395,7 @@ func (x *RosenpassConfig) ProtoReflect() protoreflect.Message {
|
||||
|
||||
// Deprecated: Use RosenpassConfig.ProtoReflect.Descriptor instead.
|
||||
func (*RosenpassConfig) Descriptor() ([]byte, []int) {
|
||||
return file_signalexchange_proto_rawDescGZIP(), []int{5}
|
||||
return file_signalexchange_proto_rawDescGZIP(), []int{4}
|
||||
}
|
||||
|
||||
func (x *RosenpassConfig) GetRosenpassPubKey() []byte {
|
||||
@@ -489,7 +431,7 @@ var file_signalexchange_proto_rawDesc = []byte{
|
||||
0x52, 0x09, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x4b, 0x65, 0x79, 0x12, 0x28, 0x0a, 0x04, 0x62,
|
||||
0x6f, 0x64, 0x79, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x73, 0x69, 0x67, 0x6e,
|
||||
0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x42, 0x6f, 0x64, 0x79, 0x52,
|
||||
0x04, 0x62, 0x6f, 0x64, 0x79, 0x22, 0xa3, 0x03, 0x0a, 0x04, 0x42, 0x6f, 0x64, 0x79, 0x12, 0x2d,
|
||||
0x04, 0x62, 0x6f, 0x64, 0x79, 0x22, 0xf6, 0x02, 0x0a, 0x04, 0x42, 0x6f, 0x64, 0x79, 0x12, 0x2d,
|
||||
0x0a, 0x04, 0x74, 0x79, 0x70, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x19, 0x2e, 0x73,
|
||||
0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x42, 0x6f,
|
||||
0x64, 0x79, 0x2e, 0x54, 0x79, 0x70, 0x65, 0x52, 0x04, 0x74, 0x79, 0x70, 0x65, 0x12, 0x18, 0x0a,
|
||||
@@ -509,39 +451,33 @@ var file_signalexchange_proto_rawDesc = []byte{
|
||||
0x20, 0x01, 0x28, 0x0b, 0x32, 0x1f, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63,
|
||||
0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x52, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x43,
|
||||
0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73,
|
||||
0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x2b, 0x0a, 0x05, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x18,
|
||||
0x08, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78,
|
||||
0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x52, 0x65, 0x6c, 0x61, 0x79, 0x52, 0x05, 0x72, 0x65,
|
||||
0x6c, 0x61, 0x79, 0x22, 0x36, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x09, 0x0a, 0x05, 0x4f,
|
||||
0x46, 0x46, 0x45, 0x52, 0x10, 0x00, 0x12, 0x0a, 0x0a, 0x06, 0x41, 0x4e, 0x53, 0x57, 0x45, 0x52,
|
||||
0x10, 0x01, 0x12, 0x0d, 0x0a, 0x09, 0x43, 0x41, 0x4e, 0x44, 0x49, 0x44, 0x41, 0x54, 0x45, 0x10,
|
||||
0x02, 0x12, 0x08, 0x0a, 0x04, 0x4d, 0x4f, 0x44, 0x45, 0x10, 0x04, 0x22, 0x06, 0x0a, 0x04, 0x4d,
|
||||
0x6f, 0x64, 0x65, 0x22, 0x55, 0x0a, 0x05, 0x52, 0x65, 0x6c, 0x61, 0x79, 0x12, 0x26, 0x0a, 0x0e,
|
||||
0x72, 0x65, 0x6c, 0x61, 0x79, 0x65, 0x64, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x01,
|
||||
0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x65, 0x64, 0x41, 0x64, 0x64,
|
||||
0x72, 0x65, 0x73, 0x73, 0x12, 0x24, 0x0a, 0x0d, 0x73, 0x72, 0x76, 0x52, 0x65, 0x66, 0x41, 0x64,
|
||||
0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x73, 0x72, 0x76,
|
||||
0x52, 0x65, 0x66, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x22, 0x6d, 0x0a, 0x0f, 0x52, 0x6f,
|
||||
0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x28, 0x0a,
|
||||
0x0f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79,
|
||||
0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73,
|
||||
0x73, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x30, 0x0a, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e,
|
||||
0x70, 0x61, 0x73, 0x73, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64, 0x64, 0x72, 0x18, 0x02,
|
||||
0x20, 0x01, 0x28, 0x09, 0x52, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x53,
|
||||
0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64, 0x64, 0x72, 0x32, 0xb9, 0x01, 0x0a, 0x0e, 0x53, 0x69,
|
||||
0x67, 0x6e, 0x61, 0x6c, 0x45, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x12, 0x4c, 0x0a, 0x04,
|
||||
0x53, 0x65, 0x6e, 0x64, 0x12, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63,
|
||||
0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d,
|
||||
0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65,
|
||||
0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65,
|
||||
0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x59, 0x0a, 0x0d, 0x43, 0x6f,
|
||||
0x6e, 0x6e, 0x65, 0x63, 0x74, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x12, 0x20, 0x2e, 0x73, 0x69,
|
||||
0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63,
|
||||
0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x20, 0x2e,
|
||||
0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45,
|
||||
0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22,
|
||||
0x00, 0x28, 0x01, 0x30, 0x01, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62,
|
||||
0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
|
||||
0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0x36, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x09,
|
||||
0x0a, 0x05, 0x4f, 0x46, 0x46, 0x45, 0x52, 0x10, 0x00, 0x12, 0x0a, 0x0a, 0x06, 0x41, 0x4e, 0x53,
|
||||
0x57, 0x45, 0x52, 0x10, 0x01, 0x12, 0x0d, 0x0a, 0x09, 0x43, 0x41, 0x4e, 0x44, 0x49, 0x44, 0x41,
|
||||
0x54, 0x45, 0x10, 0x02, 0x12, 0x08, 0x0a, 0x04, 0x4d, 0x4f, 0x44, 0x45, 0x10, 0x04, 0x22, 0x2e,
|
||||
0x0a, 0x04, 0x4d, 0x6f, 0x64, 0x65, 0x12, 0x1b, 0x0a, 0x06, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74,
|
||||
0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x48, 0x00, 0x52, 0x06, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74,
|
||||
0x88, 0x01, 0x01, 0x42, 0x09, 0x0a, 0x07, 0x5f, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x22, 0x6d,
|
||||
0x0a, 0x0f, 0x52, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x43, 0x6f, 0x6e, 0x66, 0x69,
|
||||
0x67, 0x12, 0x28, 0x0a, 0x0f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x75,
|
||||
0x62, 0x4b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0f, 0x72, 0x6f, 0x73, 0x65,
|
||||
0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x30, 0x0a, 0x13, 0x72,
|
||||
0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64,
|
||||
0x64, 0x72, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70,
|
||||
0x61, 0x73, 0x73, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64, 0x64, 0x72, 0x32, 0xb9, 0x01,
|
||||
0x0a, 0x0e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x45, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65,
|
||||
0x12, 0x4c, 0x0a, 0x04, 0x53, 0x65, 0x6e, 0x64, 0x12, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61,
|
||||
0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70,
|
||||
0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x20, 0x2e, 0x73, 0x69, 0x67,
|
||||
0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72,
|
||||
0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x59,
|
||||
0x0a, 0x0d, 0x43, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x12,
|
||||
0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65,
|
||||
0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67,
|
||||
0x65, 0x1a, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e,
|
||||
0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73,
|
||||
0x61, 0x67, 0x65, 0x22, 0x00, 0x28, 0x01, 0x30, 0x01, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72,
|
||||
0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -557,31 +493,29 @@ func file_signalexchange_proto_rawDescGZIP() []byte {
|
||||
}
|
||||
|
||||
var file_signalexchange_proto_enumTypes = make([]protoimpl.EnumInfo, 1)
|
||||
var file_signalexchange_proto_msgTypes = make([]protoimpl.MessageInfo, 6)
|
||||
var file_signalexchange_proto_msgTypes = make([]protoimpl.MessageInfo, 5)
|
||||
var file_signalexchange_proto_goTypes = []interface{}{
|
||||
(Body_Type)(0), // 0: signalexchange.Body.Type
|
||||
(*EncryptedMessage)(nil), // 1: signalexchange.EncryptedMessage
|
||||
(*Message)(nil), // 2: signalexchange.Message
|
||||
(*Body)(nil), // 3: signalexchange.Body
|
||||
(*Mode)(nil), // 4: signalexchange.Mode
|
||||
(*Relay)(nil), // 5: signalexchange.Relay
|
||||
(*RosenpassConfig)(nil), // 6: signalexchange.RosenpassConfig
|
||||
(*RosenpassConfig)(nil), // 5: signalexchange.RosenpassConfig
|
||||
}
|
||||
var file_signalexchange_proto_depIdxs = []int32{
|
||||
3, // 0: signalexchange.Message.body:type_name -> signalexchange.Body
|
||||
0, // 1: signalexchange.Body.type:type_name -> signalexchange.Body.Type
|
||||
4, // 2: signalexchange.Body.mode:type_name -> signalexchange.Mode
|
||||
6, // 3: signalexchange.Body.rosenpassConfig:type_name -> signalexchange.RosenpassConfig
|
||||
5, // 4: signalexchange.Body.relay:type_name -> signalexchange.Relay
|
||||
1, // 5: signalexchange.SignalExchange.Send:input_type -> signalexchange.EncryptedMessage
|
||||
1, // 6: signalexchange.SignalExchange.ConnectStream:input_type -> signalexchange.EncryptedMessage
|
||||
1, // 7: signalexchange.SignalExchange.Send:output_type -> signalexchange.EncryptedMessage
|
||||
1, // 8: signalexchange.SignalExchange.ConnectStream:output_type -> signalexchange.EncryptedMessage
|
||||
7, // [7:9] is the sub-list for method output_type
|
||||
5, // [5:7] is the sub-list for method input_type
|
||||
5, // [5:5] is the sub-list for extension type_name
|
||||
5, // [5:5] is the sub-list for extension extendee
|
||||
0, // [0:5] is the sub-list for field type_name
|
||||
5, // 3: signalexchange.Body.rosenpassConfig:type_name -> signalexchange.RosenpassConfig
|
||||
1, // 4: signalexchange.SignalExchange.Send:input_type -> signalexchange.EncryptedMessage
|
||||
1, // 5: signalexchange.SignalExchange.ConnectStream:input_type -> signalexchange.EncryptedMessage
|
||||
1, // 6: signalexchange.SignalExchange.Send:output_type -> signalexchange.EncryptedMessage
|
||||
1, // 7: signalexchange.SignalExchange.ConnectStream:output_type -> signalexchange.EncryptedMessage
|
||||
6, // [6:8] is the sub-list for method output_type
|
||||
4, // [4:6] is the sub-list for method input_type
|
||||
4, // [4:4] is the sub-list for extension type_name
|
||||
4, // [4:4] is the sub-list for extension extendee
|
||||
0, // [0:4] is the sub-list for field type_name
|
||||
}
|
||||
|
||||
func init() { file_signalexchange_proto_init() }
|
||||
@@ -639,18 +573,6 @@ func file_signalexchange_proto_init() {
|
||||
}
|
||||
}
|
||||
file_signalexchange_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} {
|
||||
switch v := v.(*Relay); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_signalexchange_proto_msgTypes[5].Exporter = func(v interface{}, i int) interface{} {
|
||||
switch v := v.(*RosenpassConfig); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
@@ -663,13 +585,14 @@ func file_signalexchange_proto_init() {
|
||||
}
|
||||
}
|
||||
}
|
||||
file_signalexchange_proto_msgTypes[3].OneofWrappers = []interface{}{}
|
||||
type x struct{}
|
||||
out := protoimpl.TypeBuilder{
|
||||
File: protoimpl.DescBuilder{
|
||||
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
||||
RawDescriptor: file_signalexchange_proto_rawDesc,
|
||||
NumEnums: 1,
|
||||
NumMessages: 6,
|
||||
NumMessages: 5,
|
||||
NumExtensions: 0,
|
||||
NumServices: 1,
|
||||
},
|
||||
|
||||
@@ -49,33 +49,22 @@ message Body {
|
||||
MODE = 4;
|
||||
}
|
||||
Type type = 1;
|
||||
// these will be set in OFFER, ANSWER, CANDIDATE only
|
||||
string payload = 2;
|
||||
// wgListenPort is an actual WireGuard listen port
|
||||
// these will be set in OFFER, ANSWER, CANDIDATE only
|
||||
uint32 wgListenPort = 3;
|
||||
// these will be set in OFFER, ANSWER, CANDIDATE only
|
||||
string netBirdVersion = 4;
|
||||
|
||||
Mode mode = 5;
|
||||
|
||||
// featuresSupported list of supported features by the client of this protocol
|
||||
repeated uint32 featuresSupported = 6;
|
||||
|
||||
// RosenpassConfig is a Rosenpass config of the remote peer our peer tries to connect to
|
||||
// is this optional or mandatory?
|
||||
RosenpassConfig rosenpassConfig = 7;
|
||||
|
||||
Relay relay = 8;
|
||||
}
|
||||
|
||||
// Mode indicates a connection mode
|
||||
message Mode {
|
||||
}
|
||||
|
||||
message Relay {
|
||||
string relayedAddress = 1;
|
||||
string srvRefAddress = 2;
|
||||
optional bool direct = 1;
|
||||
}
|
||||
|
||||
message RosenpassConfig {
|
||||
|
||||
@@ -49,6 +49,10 @@ func RemoveDialerHooks() {
|
||||
|
||||
// DialContext wraps the net.Dialer's DialContext method to use the custom connection
|
||||
func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
if CustomRoutingDisabled() {
|
||||
return d.Dialer.DialContext(ctx, network, address)
|
||||
}
|
||||
|
||||
var resolver *net.Resolver
|
||||
if d.Resolver != nil {
|
||||
resolver = d.Resolver
|
||||
@@ -123,6 +127,10 @@ func callDialerHooks(ctx context.Context, connID ConnectionID, address string, r
|
||||
}
|
||||
|
||||
func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) {
|
||||
if CustomRoutingDisabled() {
|
||||
return net.DialUDP(network, laddr, raddr)
|
||||
}
|
||||
|
||||
dialer := NewDialer()
|
||||
dialer.LocalAddr = laddr
|
||||
|
||||
@@ -143,6 +151,10 @@ func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) {
|
||||
}
|
||||
|
||||
func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) {
|
||||
if CustomRoutingDisabled() {
|
||||
return net.DialTCP(network, laddr, raddr)
|
||||
}
|
||||
|
||||
dialer := NewDialer()
|
||||
dialer.LocalAddr = laddr
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/pion/transport/v3"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
@@ -52,6 +53,10 @@ func RemoveListenerHooks() {
|
||||
// ListenPacket listens on the network address and returns a PacketConn
|
||||
// which includes support for write hooks.
|
||||
func (l *ListenerConfig) ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) {
|
||||
if CustomRoutingDisabled() {
|
||||
return l.ListenConfig.ListenPacket(ctx, network, address)
|
||||
}
|
||||
|
||||
pc, err := l.ListenConfig.ListenPacket(ctx, network, address)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("listen packet: %w", err)
|
||||
@@ -144,7 +149,11 @@ func closeConn(id ConnectionID, conn net.PacketConn) error {
|
||||
|
||||
// ListenUDP listens on the network address and returns a transport.UDPConn
|
||||
// which includes support for write and close hooks.
|
||||
func ListenUDP(network string, laddr *net.UDPAddr) (*UDPConn, error) {
|
||||
func ListenUDP(network string, laddr *net.UDPAddr) (transport.UDPConn, error) {
|
||||
if CustomRoutingDisabled() {
|
||||
return net.ListenUDP(network, laddr)
|
||||
}
|
||||
|
||||
conn, err := NewListener().ListenPacket(context.Background(), network, laddr.String())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("listen UDP: %w", err)
|
||||
|
||||
@@ -1,10 +1,16 @@
|
||||
package net
|
||||
|
||||
import "github.com/google/uuid"
|
||||
import (
|
||||
"os"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const (
|
||||
// NetbirdFwmark is the fwmark value used by Netbird via wireguard
|
||||
NetbirdFwmark = 0x1BD00
|
||||
|
||||
envDisableCustomRouting = "NB_DISABLE_CUSTOM_ROUTING"
|
||||
)
|
||||
|
||||
// ConnectionID provides a globally unique identifier for network connections.
|
||||
@@ -15,3 +21,7 @@ type ConnectionID string
|
||||
func GenerateConnID() ConnectionID {
|
||||
return ConnectionID(uuid.NewString())
|
||||
}
|
||||
|
||||
func CustomRoutingDisabled() bool {
|
||||
return os.Getenv(envDisableCustomRouting) == "true"
|
||||
}
|
||||
|
||||
@@ -21,7 +21,7 @@ func SetRawSocketMark(conn syscall.RawConn) error {
|
||||
var setErr error
|
||||
|
||||
err := conn.Control(func(fd uintptr) {
|
||||
setErr = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_MARK, NetbirdFwmark)
|
||||
setErr = SetSocketOpt(int(fd))
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("control: %w", err)
|
||||
@@ -33,3 +33,11 @@ func SetRawSocketMark(conn syscall.RawConn) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func SetSocketOpt(fd int) error {
|
||||
if CustomRoutingDisabled() {
|
||||
return nil
|
||||
}
|
||||
|
||||
return syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_MARK, NetbirdFwmark)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user