Compare commits

...

123 Commits

Author SHA1 Message Date
Zoltán Papp
cb77ff4661 Fix relay instance address indication 2024-07-10 22:33:15 +02:00
Zoltán Papp
83b83ccfd2 Add relay server address for the status 2024-07-10 22:17:54 +02:00
Zoltán Papp
4e75e15ea1 Add relay address to signal OFFER 2024-07-10 18:39:24 +02:00
Zoltán Papp
06afe64aff Fix deadlock 2024-07-10 18:34:04 +02:00
Zoltán Papp
7acaef1152 Try to fix wgproxy reference 2024-07-10 16:51:38 +02:00
Zoltán Papp
469be3442d Remove hardcoded debug lines 2024-07-10 14:17:50 +02:00
Zoltán Papp
d1b6387803 Fix token sending 2024-07-10 13:21:50 +02:00
Zoltán Papp
820e2feec9 Move to relay address config to object
Add test for mgm config parser
2024-07-10 11:30:02 +02:00
Zoltán Papp
e0d086a8a8 Implement dummy RemoteAddr on client conn 2024-07-10 10:12:49 +02:00
Zoltán Papp
1f95467b02 Sonar fix 2024-07-09 16:51:40 +02:00
Zoltán Papp
6553d8ce03 Sonar fix 2024-07-09 16:50:29 +02:00
Zoltán Papp
f0c829afac Sonar fix 2024-07-09 16:48:50 +02:00
Zoltán Papp
86f14523e4 Add comment 2024-07-09 16:46:43 +02:00
Zoltán Papp
6cefcbfe5d Add comment 2024-07-09 16:44:12 +02:00
Zoltán Papp
19103031ee Optimisation for sonar 2024-07-09 16:38:50 +02:00
Zoltán Papp
7369f4bc38 Optimisation for sonar 2024-07-09 16:29:38 +02:00
Zoltán Papp
d9d275a7ce Optimisation for sonar 2024-07-09 16:27:20 +02:00
Zoltán Papp
57b85f4f8d Optimisation for sonar 2024-07-09 16:15:25 +02:00
Zoltán Papp
7ef191903e Fix logging in handshaker 2024-07-09 16:06:36 +02:00
Zoltan Papp
3bd15dd1c4 Fix moc interface 2024-07-09 10:34:13 +02:00
Zoltan Papp
1065e0a6c5 Fix moc interface 2024-07-09 10:22:38 +02:00
Zoltan Papp
d4ff55e6fe Fix typo 2024-07-09 10:09:09 +02:00
Zoltan Papp
5625d83c3f Fix lint 2024-07-09 09:44:23 +02:00
Zoltan Papp
63f2f51614 Fix typo 2024-07-08 23:14:09 +02:00
Zoltan Papp
defdcb631e Add sleep time for tests 2024-07-08 22:42:30 +02:00
Zoltan Papp
7bf0d04bed Remove unused function 2024-07-08 22:19:18 +02:00
Zoltan Papp
e4ec1fd757 Add sleep time after server started 2024-07-08 22:13:31 +02:00
Zoltan Papp
dab50f35d7 Fix ipv6 issue on tests 2024-07-08 21:56:15 +02:00
Zoltan Papp
2d7e797e08 Fix body close 2024-07-08 21:55:03 +02:00
Zoltan Papp
c3e8187a47 Fix lint issues 2024-07-08 21:53:20 +02:00
Zoltan Papp
cfac8c4762 fix test timing 2024-07-08 21:34:39 +02:00
Zoltan Papp
d9dfae625b Fix manager_test 2024-07-08 21:18:19 +02:00
Zoltán Papp
a9e6742d9a - Remove heartbeat logs
- Fix relay client tests
- Fix auth ID unmarshalling
- Add magic header check
2024-07-08 17:55:48 +02:00
Zoltán Papp
931f165c9a Remove garbage 2024-07-08 17:38:23 +02:00
Zoltán Papp
2803e1669b Remove meaningless tests 2024-07-08 17:24:49 +02:00
Zoltán Papp
f28a657a1d extend wginterface func with windows related things 2024-07-08 17:08:54 +02:00
Zoltán Papp
1f949f8cee Fix parameters of tests 2024-07-08 17:01:11 +02:00
Zoltán Papp
75f5b75bc4 Mock wginterface 2024-07-08 16:15:04 +02:00
Zoltán Papp
48a2f6e69d Mock wginterface 2024-07-08 16:12:08 +02:00
Zoltan Papp
b3715b5fad - Revert typos in turnCfg string
- merge main
2024-07-08 15:05:29 +02:00
Zoltan Papp
836072098b Integrate the relay authentication 2024-07-05 16:12:30 +02:00
Zoltan Papp
8845e8fbc7 replace bson to gob 2024-07-04 13:42:27 +02:00
Zoltan Papp
1a5ee744a8 - add file based cert
- print out the exposed address
- handle empty exposed address
2024-07-03 15:03:57 +02:00
Zoltan Papp
15a7b7629b Add exposed address 2024-07-02 11:57:17 +02:00
Zoltán Papp
d3785dc1fa Fix ssl configuration 2024-07-01 11:50:18 +02:00
Zoltán Papp
ed82ef7fe4 Fix error logging 2024-06-30 10:43:12 +02:00
Zoltán Papp
aa55fba5ee Add client side heartbeat handling 2024-06-29 14:13:05 +02:00
Zoltán Papp
faeae52329 Support exit node in ws client 2024-06-28 11:44:50 +02:00
Zoltán Papp
9ae03046e7 rename file 2024-06-28 11:17:21 +02:00
Zoltán Papp
98aa830831 Rename client ws package 2024-06-28 11:17:06 +02:00
Zoltán Papp
c94c949173 Add comment 2024-06-28 11:12:53 +02:00
Zoltán Papp
183f746158 Order the source code 2024-06-27 18:42:40 +02:00
Zoltán Papp
dd0d15c9d4 Add healthcheck code 2024-06-27 18:40:12 +02:00
Zoltán Papp
4d0e16f2d0 - Remove WaitForExitAcceptedConns logic from server
- Implement thread safe gracefully close logic
- organise the server code
2024-06-27 02:36:44 +02:00
Zoltán Papp
3fcdb51376 Error handling 2024-06-26 16:23:50 +02:00
Zoltán Papp
c0efce6556 Fix msg delivery timeouts 2024-06-26 16:22:26 +02:00
Zoltán Papp
f0eb004582 Single thread on server sending 2024-06-26 15:26:19 +02:00
Zoltán Papp
0a59f12012 Env var to force relay usage 2024-06-26 15:25:32 +02:00
Zoltán Papp
745e4f76b1 Remove gorilla lib 2024-06-26 15:25:01 +02:00
Zoltán Papp
085d072b17 - Add sha prefix for peer id in protocol
- Add magic cookie in hello msg
- Add tests
2024-06-25 17:36:04 +02:00
Zoltán Papp
0a67f5be1a Fix logic 2024-06-25 15:13:08 +02:00
Zoltán Papp
f72e852ccb Remove duplicated code 2024-06-24 18:54:03 +02:00
Zoltán Papp
54dc78aab8 Remove debug log 2024-06-24 15:30:25 +02:00
Zoltán Papp
69d8d5aa86 Fix the active conn type logic 2024-06-21 19:13:41 +02:00
Zoltán Papp
7581bbd925 Handle on offer listener in handshaker 2024-06-21 15:35:15 +02:00
Zoltán Papp
4d67d72785 Use permanent credentials 2024-06-21 15:02:54 +02:00
Zoltán Papp
4a08f1a1e9 Refactor handshaker loop 2024-06-21 12:35:28 +02:00
Zoltán Papp
bfe60c01ba Close proxy reading in case of eof 2024-06-21 00:55:30 +02:00
Zoltán Papp
06ceac65de - Fix reconnect guard
- Avoid double client creation
2024-06-21 00:55:07 +02:00
Zoltán Papp
6801dcb3f6 Fallback to relay conn 2024-06-20 18:17:30 +02:00
Zoltán Papp
c7db2c0524 Moc signal message support 2024-06-19 18:40:49 +02:00
Zoltán Papp
4f890ff712 Typo fix 2024-06-19 18:17:52 +02:00
Zoltán Papp
f7e6aa9b8f Change logging logic 2024-06-19 18:16:43 +02:00
Zoltán Papp
81f2330d49 Fix remote address in ws client 2024-06-19 18:16:23 +02:00
Zoltán Papp
0261e15aad Extend the cmd with argument handling
- add cobra to relay server
- add logger instance for handshaker
2024-06-19 17:40:16 +02:00
Zoltán Papp
11de2ec42e Fix open connection 2024-06-19 12:18:58 +02:00
Zoltán Papp
4d2a25b728 Code cleaning 2024-06-19 11:53:21 +02:00
Zoltán Papp
2f32e0d8cf Fix chicken-egg problem in the ice agent creation 2024-06-19 11:28:01 +02:00
Zoltán Papp
48310ef99c Fix engine test 2024-06-19 09:59:01 +02:00
Zoltán Papp
24f71bc68a Fix and extend test 2024-06-19 09:40:43 +02:00
Zoltán Papp
e26e2c3a75 Add conn status handling and protect agent 2024-06-18 17:40:37 +02:00
Zoltán Papp
a5e664d83d Code cleaning 2024-06-18 11:27:18 +02:00
Zoltán Papp
d8ab3c1632 Call peer.Open from engine 2024-06-18 11:23:39 +02:00
Zoltán Papp
63b4041e9c Rename connector to worker 2024-06-18 11:22:40 +02:00
Zoltán Papp
f7d8d03e55 Fix timers 2024-06-18 11:20:01 +02:00
Zoltán Papp
5b86a7f3f2 Fix relay mode evaulation 2024-06-18 11:10:55 +02:00
Zoltán Papp
deb8203f06 fix circle import 2024-06-17 18:02:52 +02:00
Zoltán Papp
e407fe02c5 Separate lifecircle of handshake, ice, relay connections
- fix Stun, Turn address update thread safety issue
- move conn worker login into peer package
2024-06-17 17:52:22 +02:00
Zoltán Papp
a7760bf0a7 Configurable relay address with env variable 2024-06-14 15:43:18 +02:00
Zoltan Papp
64f949abbb Integrate relay into peer conn
- extend mgm with relay address
- extend signaling with remote peer's relay address
- start setup relay connection before engine start
2024-06-14 14:40:31 +02:00
Zoltan Papp
38f2a59d1b Add comment 2024-06-12 10:56:21 +02:00
Zoltan Papp
9504012920 Set the proper buffer size in the client code 2024-06-09 21:10:57 +02:00
Zoltan Papp
5e93d117cf Use buf pool
- eliminate reader function generation
- fix write to closed channel panic
2024-06-09 20:33:35 +02:00
Zoltan Papp
8c70b7d7ff Replace ws lib on client side 2024-06-09 12:41:52 +02:00
Zoltan Papp
ed8def4d9b Protect ws writing in Gorilla ws 2024-06-07 16:07:35 +02:00
Zoltán Papp
1e115e3893 Merge branch 'main' into feature/relay 2024-06-06 13:38:40 +02:00
Zoltan Papp
fed9e587af Add close message type 2024-06-05 19:49:30 +02:00
Zoltan Papp
a40d4d2f32 - add comments
- avoid double closing messages
- add cleanup routine for relay manager
2024-06-04 14:40:35 +02:00
Zoltán Papp
15818b72c6 Add alternative ws server implementation 2024-06-03 21:38:37 +02:00
Zoltán Papp
0556dc1860 Avoid nil pointer exception in test in case of err 2024-06-03 21:36:46 +02:00
Zoltán Papp
2b369cd28f Add quic transporter 2024-06-03 20:17:43 +02:00
Zoltán Papp
9d44a476c6 Fix double unlock in client.go 2024-06-03 20:14:39 +02:00
Zoltán Papp
57ddb5f262 Add comment 2024-06-03 11:22:16 +02:00
Zoltan Papp
4ced07dd8d Fix close conn threading issue 2024-06-03 01:37:56 +02:00
Zoltán Papp
3430b81622 Add relay server tracking 2024-06-01 11:48:15 +02:00
Zoltán Papp
fd4ad15c83 Move reconnection logic to separated struct 2024-06-01 11:25:00 +02:00
Zoltán Papp
4ff069a102 Support multiple server 2024-05-29 16:40:26 +02:00
Zoltán Papp
7cc3964a4d Use mux for http server
Without it can not start multiple http
server instances for unit tests
2024-05-29 16:11:58 +02:00
Zoltan Papp
6d627f1923 Code cleaning 2024-05-28 01:27:53 +02:00
Zoltan Papp
076ce69a24 Add reconnect logic 2024-05-28 01:00:25 +02:00
Zoltán Papp
645a1f31a7 Fix writing/reading to a closed conn 2024-05-27 10:25:08 +02:00
Zoltán Papp
b4aa7e50f9 Close sockets on server cmd 2024-05-27 09:42:27 +02:00
Zoltán Papp
173ca25dac Fix in client the close event 2024-05-26 22:14:33 +02:00
Zoltán Papp
36b2cd16cc Remove channel binding logic 2024-05-23 13:24:02 +02:00
Zoltán Papp
0a05f8b4d4 Use buffer pool and protect exported functions 2024-05-22 00:38:41 +02:00
Zoltán Papp
e82c0a55a3 Set to blocking the message queue 2024-05-21 16:21:29 +02:00
Zoltán Papp
13eb457132 Add registration response message to the communication 2024-05-21 15:51:37 +02:00
Zoltan Papp
1c9c9ae47e Remove sync.pool 2024-05-20 11:38:23 +02:00
Zoltan Papp
9ac5a1ed3f Add udp listener and did some change for debug purpose. 2024-05-19 12:41:06 +02:00
Zoltan Papp
d4eaec5cbd Followup messages modification 2024-05-17 23:41:47 +02:00
Zoltan Papp
6ae7a790f2 Fix buffer handling 2024-05-17 23:29:47 +02:00
Zoltan Papp
49dfbc82d9 Add relay cmd 2024-05-17 20:24:06 +02:00
Zoltan Papp
57a89cf0cc Add initial relay code 2024-05-17 17:43:28 +02:00
91 changed files with 6675 additions and 1498 deletions

View File

@@ -92,7 +92,7 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
if err != nil {
t.Fatal(err)
}
turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, "")
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, nil)
if err != nil {
t.Fatal(err)

View File

@@ -163,7 +163,10 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
ctx, cancel = context.WithCancel(ctx)
SetupCloseHandler(ctx, cancel)
connectClient := internal.NewConnectClient(ctx, config, peer.NewRecorder(config.ManagementURL.String()))
r := peer.NewRecorder(config.ManagementURL.String())
r.GetFullStatus()
connectClient := internal.NewConnectClient(ctx, config, r)
return connectClient.Run()
}

View File

@@ -26,6 +26,8 @@ import (
"github.com/netbirdio/netbird/iface"
mgm "github.com/netbirdio/netbird/management/client"
mgmProto "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/relay/auth/hmac"
relayClient "github.com/netbirdio/netbird/relay/client"
signal "github.com/netbirdio/netbird/signal/client"
"github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/version"
@@ -244,6 +246,20 @@ func (c *ConnectClient) run(
c.statusRecorder.MarkSignalConnected()
relayURL, token := parseRelayInfo(loginResp)
relayManager := relayClient.NewManager(engineCtx, relayURL, myPrivateKey.PublicKey().String())
if relayURL != "" {
if token != nil {
relayManager.UpdateToken(token)
}
log.Infof("connecting to the Relay service %s", relayURL)
if err = relayManager.Serve(); err != nil {
log.Error(err)
return wrapErr(err)
}
c.statusRecorder.SetRelayMgr(relayManager)
}
peerConfig := loginResp.GetPeerConfig()
engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig)
@@ -255,7 +271,7 @@ func (c *ConnectClient) run(
checks := loginResp.GetChecks()
c.engineMutex.Lock()
c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, engineConfig, mobileDependency, c.statusRecorder, mgmProbe, signalProbe, relayProbe, wgProbe, checks)
c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, mgmProbe, signalProbe, relayProbe, wgProbe, checks)
c.engineMutex.Unlock()
err = c.engine.Start()
@@ -299,6 +315,30 @@ func (c *ConnectClient) run(
return nil
}
func parseRelayInfo(resp *mgmProto.LoginResponse) (string, *hmac.Token) {
// todo remove this
if ra := peer.ForcedRelayAddress(); ra != "" {
return ra, nil
}
msg := resp.GetWiretrusteeConfig().GetRelay()
if msg == nil {
return "", nil
}
var url string
if msg.GetUrls() != nil && len(msg.GetUrls()) > 0 {
url = msg.GetUrls()[0]
}
token := &hmac.Token{
Payload: msg.GetTokenPayload(),
Signature: msg.GetTokenSignature(),
}
return url, token
}
func (c *ConnectClient) Engine() *Engine {
var e *Engine
c.engineMutex.Lock()

View File

@@ -13,6 +13,7 @@ import (
"slices"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/pion/ice/v3"
@@ -24,6 +25,7 @@ import (
"github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/acl"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/networkmonitor"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/relay"
@@ -39,6 +41,8 @@ import (
mgm "github.com/netbirdio/netbird/management/client"
"github.com/netbirdio/netbird/management/domain"
mgmProto "github.com/netbirdio/netbird/management/proto"
auth "github.com/netbirdio/netbird/relay/auth/hmac"
relayClient "github.com/netbirdio/netbird/relay/client"
"github.com/netbirdio/netbird/route"
signal "github.com/netbirdio/netbird/signal/client"
sProto "github.com/netbirdio/netbird/signal/proto"
@@ -101,7 +105,8 @@ type EngineConfig struct {
// Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers.
type Engine struct {
// signal is a Signal Service client
signal signal.Client
signal signal.Client
signaler *peer.Signaler
// mgmClient is a Management Service client
mgmClient mgm.Client
// peerConns is a map that holds all the peers that are known to this peer
@@ -122,7 +127,8 @@ type Engine struct {
// STUNs is a list of STUN servers used by ICE
STUNs []*stun.URI
// TURNs is a list of STUN servers used by ICE
TURNs []*stun.URI
TURNs []*stun.URI
StunTurn atomic.Value
// clientRoutes is the most recent list of clientRoutes received from the Management Service
clientRoutes route.HAMap
@@ -134,7 +140,7 @@ type Engine struct {
ctx context.Context
cancel context.CancelFunc
wgInterface *iface.WGIface
wgInterface iface.IWGIface
wgProxyFactory *wgproxy.Factory
udpMux *bind.UniversalUDPMuxDefault
@@ -160,10 +166,10 @@ type Engine struct {
relayProbe *Probe
wgProbe *Probe
wgConnWorker sync.WaitGroup
// checks are the client-applied posture checks that need to be evaluated on the client
checks []*mgmProto.Checks
relayManager *relayClient.Manager
}
// Peer is an instance of the Connection Peer
@@ -178,6 +184,7 @@ func NewEngine(
clientCancel context.CancelFunc,
signalClient signal.Client,
mgmClient mgm.Client,
relayManager *relayClient.Manager,
config *EngineConfig,
mobileDep MobileDependency,
statusRecorder *peer.Status,
@@ -188,6 +195,7 @@ func NewEngine(
clientCancel,
signalClient,
mgmClient,
relayManager,
config,
mobileDep,
statusRecorder,
@@ -205,6 +213,7 @@ func NewEngineWithProbes(
clientCancel context.CancelFunc,
signalClient signal.Client,
mgmClient mgm.Client,
relayManager *relayClient.Manager,
config *EngineConfig,
mobileDep MobileDependency,
statusRecorder *peer.Status,
@@ -214,12 +223,13 @@ func NewEngineWithProbes(
wgProbe *Probe,
checks []*mgmProto.Checks,
) *Engine {
return &Engine{
clientCtx: clientCtx,
clientCancel: clientCancel,
signal: signalClient,
signaler: peer.NewSignaler(signalClient, config.WgPrivateKey),
mgmClient: mgmClient,
relayManager: relayManager,
peerConns: make(map[string]*peer.Conn),
syncMsgMux: &sync.Mutex{},
config: config,
@@ -265,7 +275,6 @@ func (e *Engine) Stop() error {
time.Sleep(500 * time.Millisecond)
e.close()
e.wgConnWorker.Wait()
log.Infof("stopped Netbird Engine")
return nil
}
@@ -465,80 +474,42 @@ func (e *Engine) removePeer(peerKey string) error {
conn, exists := e.peerConns[peerKey]
if exists {
delete(e.peerConns, peerKey)
err := conn.Close()
if err != nil {
switch err.(type) {
case *peer.ConnectionAlreadyClosedError:
return nil
default:
return err
}
}
conn.Close()
}
return nil
}
func signalCandidate(candidate ice.Candidate, myKey wgtypes.Key, remoteKey wgtypes.Key, s signal.Client) error {
err := s.Send(&sProto.Message{
Key: myKey.PublicKey().String(),
RemoteKey: remoteKey.String(),
Body: &sProto.Body{
Type: sProto.Body_CANDIDATE,
Payload: candidate.Marshal(),
},
})
if err != nil {
return err
}
return nil
}
func sendSignal(message *sProto.Message, s signal.Client) error {
return s.Send(message)
}
// SignalOfferAnswer signals either an offer or an answer to remote peer
func SignalOfferAnswer(offerAnswer peer.OfferAnswer, myKey wgtypes.Key, remoteKey wgtypes.Key, s signal.Client,
isAnswer bool) error {
var t sProto.Body_Type
if isAnswer {
t = sProto.Body_ANSWER
} else {
t = sProto.Body_OFFER
}
msg, err := signal.MarshalCredential(myKey, offerAnswer.WgListenPort, remoteKey, &signal.Credential{
UFrag: offerAnswer.IceCredentials.UFrag,
Pwd: offerAnswer.IceCredentials.Pwd,
}, t, offerAnswer.RosenpassPubKey, offerAnswer.RosenpassAddr)
if err != nil {
return err
}
err = s.Send(msg)
if err != nil {
return err
}
return nil
}
func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
if update.GetWiretrusteeConfig() != nil {
err := e.updateTURNs(update.GetWiretrusteeConfig().GetTurns())
wCfg := update.GetWiretrusteeConfig()
err := e.updateTURNs(wCfg.GetTurns())
if err != nil {
return err
}
err = e.updateSTUNs(update.GetWiretrusteeConfig().GetStuns())
err = e.updateSTUNs(wCfg.GetStuns())
if err != nil {
return err
}
var stunTurn []*stun.URI
stunTurn = append(stunTurn, e.STUNs...)
stunTurn = append(stunTurn, e.TURNs...)
e.StunTurn.Store(stunTurn)
relayMsg := wCfg.GetRelay()
if relayMsg != nil {
c := &auth.Token{
Payload: relayMsg.GetTokenPayload(),
Signature: relayMsg.GetTokenSignature(),
}
e.relayManager.UpdateToken(c)
}
// todo update relay address in the relay manager
// todo update signal
}
@@ -934,68 +905,13 @@ func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error {
log.Warnf("error adding peer %s to status recorder, got error: %v", peerKey, err)
}
e.wgConnWorker.Add(1)
go e.connWorker(conn, peerKey)
conn.Open()
}
return nil
}
func (e *Engine) connWorker(conn *peer.Conn, peerKey string) {
defer e.wgConnWorker.Done()
for {
// randomize starting time a bit
min := 500
max := 2000
duration := time.Duration(rand.Intn(max-min)+min) * time.Millisecond
select {
case <-e.ctx.Done():
return
case <-time.After(duration):
}
// if peer has been removed -> give up
if !e.peerExists(peerKey) {
log.Debugf("peer %s doesn't exist anymore, won't retry connection", peerKey)
return
}
if !e.signal.Ready() {
log.Infof("signal client isn't ready, skipping connection attempt %s", peerKey)
continue
}
// we might have received new STUN and TURN servers meanwhile, so update them
e.syncMsgMux.Lock()
conn.UpdateStunTurn(append(e.STUNs, e.TURNs...))
e.syncMsgMux.Unlock()
err := conn.Open(e.ctx)
if err != nil {
log.Debugf("connection to peer %s failed: %v", peerKey, err)
var connectionClosedError *peer.ConnectionClosedError
switch {
case errors.As(err, &connectionClosedError):
// conn has been forced to close, so we exit the loop
return
default:
}
}
}
}
func (e *Engine) peerExists(peerKey string) bool {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
_, ok := e.peerConns[peerKey]
return ok
}
func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, error) {
log.Debugf("creating peer connection %s", pubKey)
var stunTurn []*stun.URI
stunTurn = append(stunTurn, e.STUNs...)
stunTurn = append(stunTurn, e.TURNs...)
wgConfig := peer.WgConfig{
RemoteKey: pubKey,
@@ -1028,52 +944,29 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, e
// randomize connection timeout
timeout := time.Duration(rand.Intn(PeerConnectionTimeoutMax-PeerConnectionTimeoutMin)+PeerConnectionTimeoutMin) * time.Millisecond
config := peer.ConnConfig{
Key: pubKey,
LocalKey: e.config.WgPrivateKey.PublicKey().String(),
StunTurn: stunTurn,
InterfaceBlackList: e.config.IFaceBlackList,
DisableIPv6Discovery: e.config.DisableIPv6Discovery,
Timeout: timeout,
UDPMux: e.udpMux.UDPMuxDefault,
UDPMuxSrflx: e.udpMux,
WgConfig: wgConfig,
LocalWgPort: e.config.WgPort,
NATExternalIPs: e.parseNATExternalIPMappings(),
RosenpassPubKey: e.getRosenpassPubKey(),
RosenpassAddr: e.getRosenpassAddr(),
Key: pubKey,
LocalKey: e.config.WgPrivateKey.PublicKey().String(),
Timeout: timeout,
WgConfig: wgConfig,
LocalWgPort: e.config.WgPort,
RosenpassPubKey: e.getRosenpassPubKey(),
RosenpassAddr: e.getRosenpassAddr(),
ICEConfig: peer.ICEConfig{
StunTurn: e.StunTurn,
InterfaceBlackList: e.config.IFaceBlackList,
DisableIPv6Discovery: e.config.DisableIPv6Discovery,
UDPMux: e.udpMux.UDPMuxDefault,
UDPMuxSrflx: e.udpMux,
NATExternalIPs: e.parseNATExternalIPMappings(),
},
}
peerConn, err := peer.NewConn(config, e.statusRecorder, e.wgProxyFactory, e.mobileDep.TunAdapter, e.mobileDep.IFaceDiscover)
peerConn, err := peer.NewConn(e.ctx, config, e.statusRecorder, e.wgProxyFactory, e.signaler, e.mobileDep.IFaceDiscover, e.relayManager)
if err != nil {
return nil, err
}
wgPubKey, err := wgtypes.ParseKey(pubKey)
if err != nil {
return nil, err
}
signalOffer := func(offerAnswer peer.OfferAnswer) error {
return SignalOfferAnswer(offerAnswer, e.config.WgPrivateKey, wgPubKey, e.signal, false)
}
signalCandidate := func(candidate ice.Candidate) error {
return signalCandidate(candidate, e.config.WgPrivateKey, wgPubKey, e.signal)
}
signalAnswer := func(offerAnswer peer.OfferAnswer) error {
return SignalOfferAnswer(offerAnswer, e.config.WgPrivateKey, wgPubKey, e.signal, true)
}
peerConn.SetSignalCandidate(signalCandidate)
peerConn.SetSignalOffer(signalOffer)
peerConn.SetSignalAnswer(signalAnswer)
peerConn.SetSendSignalMessage(func(message *sProto.Message) error {
return sendSignal(message, e.signal)
})
if e.rpManager != nil {
peerConn.SetOnConnected(e.rpManager.OnConnected)
peerConn.SetOnDisconnected(e.rpManager.OnDisconnected)
}
@@ -1116,6 +1009,7 @@ func (e *Engine) receiveSignalEvents() {
Version: msg.GetBody().GetNetBirdVersion(),
RosenpassPubKey: rosenpassPubKey,
RosenpassAddr: rosenpassAddr,
RelaySrvAddress: msg.GetBody().GetRelayServerAddress(),
})
case sProto.Body_ANSWER:
remoteCred, err := signal.UnMarshalCredential(msg)
@@ -1138,6 +1032,7 @@ func (e *Engine) receiveSignalEvents() {
Version: msg.GetBody().GetNetBirdVersion(),
RosenpassPubKey: rosenpassPubKey,
RosenpassAddr: rosenpassAddr,
RelaySrvAddress: msg.GetBody().GetRelayServerAddress(),
})
case sProto.Body_CANDIDATE:
candidate, err := ice.UnmarshalCandidate(msg.GetBody().Payload)
@@ -1146,7 +1041,7 @@ func (e *Engine) receiveSignalEvents() {
return err
}
conn.OnRemoteCandidate(candidate, e.GetClientRoutes())
go conn.OnRemoteCandidate(candidate, e.GetClientRoutes())
case sProto.Body_MODE:
}
@@ -1442,7 +1337,7 @@ func (e *Engine) receiveProbeEvents() {
for _, peer := range e.peerConns {
key := peer.GetKey()
wgStats, err := peer.GetConf().WgConfig.WgInterface.GetStats(key)
wgStats, err := peer.WgConfig().WgInterface.GetStats(key)
if err != nil {
log.Debugf("failed to get wg stats for peer %s: %s", key, err)
}

View File

@@ -36,6 +36,7 @@ import (
mgmtProto "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
relayClient "github.com/netbirdio/netbird/relay/client"
"github.com/netbirdio/netbird/route"
signal "github.com/netbirdio/netbird/signal/client"
"github.com/netbirdio/netbird/signal/proto"
@@ -57,6 +58,12 @@ var (
}
)
func TestMain(m *testing.M) {
_ = util.InitLog("debug", "console")
code := m.Run()
os.Exit(code)
}
func TestEngine_SSH(t *testing.T) {
// todo resolve test execution on freebsd
if runtime.GOOS == "windows" || runtime.GOOS == "freebsd" {
@@ -72,13 +79,23 @@ func TestEngine_SSH(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, &EngineConfig{
WgIfaceName: "utun101",
WgAddr: "100.64.0.1/24",
WgPrivateKey: key,
WgPort: 33100,
ServerSSHAllowed: true,
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
relayMgr := relayClient.NewManager(ctx, "", key.PublicKey().String())
engine := NewEngine(
ctx, cancel,
&signal.MockClient{},
&mgmt.MockClient{},
relayMgr,
&EngineConfig{
WgIfaceName: "utun101",
WgAddr: "100.64.0.1/24",
WgPrivateKey: key,
WgPort: 33100,
ServerSSHAllowed: true,
},
MobileDependency{},
peer.NewRecorder("https://mgm"),
nil,
)
engine.dnsServer = &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
@@ -207,20 +224,28 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, &EngineConfig{
WgIfaceName: "utun102",
WgAddr: "100.64.0.1/24",
WgPrivateKey: key,
WgPort: 33100,
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", engine.config.WgPort, key.String(), iface.DefaultMTU, newNet, nil, nil)
if err != nil {
t.Fatal(err)
relayMgr := relayClient.NewManager(ctx, "", key.PublicKey().String())
engine := NewEngine(
ctx, cancel,
&signal.MockClient{},
&mgmt.MockClient{},
relayMgr,
&EngineConfig{
WgIfaceName: "utun102",
WgAddr: "100.64.0.1/24",
WgPrivateKey: key,
WgPort: 33100,
},
MobileDependency{},
peer.NewRecorder("https://mgm"),
nil)
wgIface := &iface.MockWGIface{
RemovePeerFunc: func(peerKey string) error {
return nil
},
}
engine.wgInterface = wgIface
engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), time.Minute, engine.wgInterface, engine.statusRecorder, nil)
engine.dnsServer = &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
@@ -403,8 +428,8 @@ func TestEngine_Sync(t *testing.T) {
}
return nil
}
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{SyncFunc: syncFunc}, &EngineConfig{
relayMgr := relayClient.NewManager(ctx, "", key.PublicKey().String())
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{SyncFunc: syncFunc}, relayMgr, &EngineConfig{
WgIfaceName: "utun103",
WgAddr: "100.64.0.1/24",
WgPrivateKey: key,
@@ -563,7 +588,8 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
wgIfaceName := fmt.Sprintf("utun%d", 104+n)
wgAddr := fmt.Sprintf("100.66.%d.1/24", n)
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, &EngineConfig{
relayMgr := relayClient.NewManager(ctx, "", key.PublicKey().String())
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{
WgIfaceName: wgIfaceName,
WgAddr: wgAddr,
WgPrivateKey: key,
@@ -733,7 +759,8 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
wgIfaceName := fmt.Sprintf("utun%d", 104+n)
wgAddr := fmt.Sprintf("100.66.%d.1/24", n)
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, &EngineConfig{
relayMgr := relayClient.NewManager(ctx, "", key.PublicKey().String())
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{
WgIfaceName: wgIfaceName,
WgAddr: wgAddr,
WgPrivateKey: key,
@@ -1009,7 +1036,8 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
WgPort: wgPort,
}
e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil), nil
relayMgr := relayClient.NewManager(ctx, "", key.PublicKey().String())
e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, relayMgr, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil), nil
e.ctx = ctx
return e, err
}
@@ -1073,7 +1101,7 @@ func startManagement(t *testing.T, dataDir string) (*grpc.Server, string, error)
if err != nil {
return nil, "", err
}
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, "")
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, nil)
if err != nil {
return nil, "", err

File diff suppressed because it is too large Load Diff

View File

@@ -2,25 +2,33 @@ package peer
import (
"context"
"os"
"sync"
"testing"
"time"
"github.com/magiconair/properties/assert"
"github.com/pion/stun/v2"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/internal/wgproxy"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/util"
)
var connConf = ConnConfig{
Key: "LLHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
LocalKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
StunTurn: []*stun.URI{},
InterfaceBlackList: nil,
Timeout: time.Second,
LocalWgPort: 51820,
Key: "LLHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
LocalKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
Timeout: time.Second,
LocalWgPort: 51820,
ICEConfig: ICEConfig{
InterfaceBlackList: nil,
},
}
func TestMain(m *testing.M) {
_ = util.InitLog("trace", "console")
code := m.Run()
os.Exit(code)
}
func TestNewConn_interfaceFilter(t *testing.T) {
@@ -40,7 +48,7 @@ func TestConn_GetKey(t *testing.T) {
defer func() {
_ = wgProxyFactory.Free()
}()
conn, err := NewConn(connConf, nil, wgProxyFactory, nil, nil)
conn, err := NewConn(context.Background(), connConf, nil, wgProxyFactory, nil, nil, nil)
if err != nil {
return
}
@@ -55,7 +63,7 @@ func TestConn_OnRemoteOffer(t *testing.T) {
defer func() {
_ = wgProxyFactory.Free()
}()
conn, err := NewConn(connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil)
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil, nil)
if err != nil {
return
}
@@ -63,7 +71,7 @@ func TestConn_OnRemoteOffer(t *testing.T) {
wg := sync.WaitGroup{}
wg.Add(2)
go func() {
<-conn.remoteOffersCh
<-conn.handshaker.remoteOffersCh
wg.Done()
}()
@@ -92,7 +100,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
defer func() {
_ = wgProxyFactory.Free()
}()
conn, err := NewConn(connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil)
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil, nil)
if err != nil {
return
}
@@ -100,7 +108,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
wg := sync.WaitGroup{}
wg.Add(2)
go func() {
<-conn.remoteAnswerCh
<-conn.handshaker.remoteAnswerCh
wg.Done()
}()
@@ -128,58 +136,33 @@ func TestConn_Status(t *testing.T) {
defer func() {
_ = wgProxyFactory.Free()
}()
conn, err := NewConn(connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil)
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil, nil)
if err != nil {
return
}
tables := []struct {
name string
status ConnStatus
want ConnStatus
name string
statusIce ConnStatus
statusRelay ConnStatus
want ConnStatus
}{
{"StatusConnected", StatusConnected, StatusConnected},
{"StatusDisconnected", StatusDisconnected, StatusDisconnected},
{"StatusConnecting", StatusConnecting, StatusConnecting},
{"StatusConnected", StatusConnected, StatusConnected, StatusConnected},
{"StatusDisconnected", StatusDisconnected, StatusDisconnected, StatusDisconnected},
{"StatusConnecting", StatusConnecting, StatusConnecting, StatusConnecting},
{"StatusConnectingIce", StatusConnecting, StatusDisconnected, StatusConnecting},
{"StatusConnectingIceAlternative", StatusConnecting, StatusConnected, StatusConnected},
{"StatusConnectingRelay", StatusDisconnected, StatusConnecting, StatusConnecting},
{"StatusConnectingRelayAlternative", StatusConnected, StatusConnecting, StatusConnected},
}
for _, table := range tables {
t.Run(table.name, func(t *testing.T) {
conn.status = table.status
conn.statusICE = table.statusIce
conn.statusRelay = table.statusRelay
got := conn.Status()
assert.Equal(t, got, table.want, "they should be equal")
})
}
}
func TestConn_Close(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
defer func() {
_ = wgProxyFactory.Free()
}()
conn, err := NewConn(connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil)
if err != nil {
return
}
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
<-conn.closeCh
wg.Done()
}()
go func() {
for {
err := conn.Close()
if err != nil {
continue
} else {
return
}
}
}()
wg.Wait()
}

View File

@@ -16,6 +16,13 @@ const (
envICEForceRelayConn = "NB_ICE_FORCE_RELAY_CONN"
)
func ForcedRelayAddress() string {
if envRelay := os.Getenv("NB_RELAY_ADDRESS"); envRelay != "" {
return envRelay
}
return ""
}
func iceKeepAlive() time.Duration {
keepAliveEnv := os.Getenv(envICEKeepAliveIntervalSec)
if keepAliveEnv == "" {

View File

@@ -0,0 +1,195 @@
package peer
import (
"context"
"errors"
"sync"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/version"
)
var (
ErrSignalIsNotReady = errors.New("signal is not ready")
)
// IceCredentials ICE protocol credentials struct
type IceCredentials struct {
UFrag string
Pwd string
}
// OfferAnswer represents a session establishment offer or answer
type OfferAnswer struct {
IceCredentials IceCredentials
// WgListenPort is a remote WireGuard listen port.
// This field is used when establishing a direct WireGuard connection without any proxy.
// We can set the remote peer's endpoint with this port.
WgListenPort int
// Version of NetBird Agent
Version string
// RosenpassPubKey is the Rosenpass public key of the remote peer when receiving this message
// This value is the local Rosenpass server public key when sending the message
RosenpassPubKey []byte
// RosenpassAddr is the Rosenpass server address (IP:port) of the remote peer when receiving this message
// This value is the local Rosenpass server address when sending the message
RosenpassAddr string
// relay server address
RelaySrvAddress string
}
type HandshakeArgs struct {
IceUFrag string
IcePwd string
RelayAddr string
}
type Handshaker struct {
mu sync.Mutex
ctx context.Context
log *log.Entry
config ConnConfig
signaler *Signaler
onNewOfferListeners []func(*OfferAnswer)
// remoteOffersCh is a channel used to wait for remote credentials to proceed with the connection
remoteOffersCh chan OfferAnswer
// remoteAnswerCh is a channel used to wait for remote credentials answer (confirmation of our offer) to proceed with the connection
remoteAnswerCh chan OfferAnswer
lastOfferArgs HandshakeArgs
}
func NewHandshaker(ctx context.Context, log *log.Entry, config ConnConfig, signaler *Signaler) *Handshaker {
return &Handshaker{
ctx: ctx,
log: log,
config: config,
signaler: signaler,
remoteOffersCh: make(chan OfferAnswer),
remoteAnswerCh: make(chan OfferAnswer),
}
}
func (h *Handshaker) AddOnNewOfferListener(offer func(remoteOfferAnswer *OfferAnswer)) {
h.onNewOfferListeners = append(h.onNewOfferListeners, offer)
}
func (h *Handshaker) Listen() {
for {
h.log.Debugf("wait for remote offer confirmation")
remoteOfferAnswer, err := h.waitForRemoteOfferConfirmation()
if err != nil {
if _, ok := err.(*ConnectionClosedError); ok {
h.log.Tracef("stop handshaker")
return
}
h.log.Errorf("failed to received remote offer confirmation: %s", err)
continue
}
h.log.Debugf("received connection confirmation, running version %s and with remote WireGuard listen port %d", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort)
for _, listener := range h.onNewOfferListeners {
go listener(remoteOfferAnswer)
}
}
}
func (h *Handshaker) SendOffer(args HandshakeArgs) error {
h.mu.Lock()
defer h.mu.Unlock()
err := h.sendOffer(args)
if err != nil {
return err
}
h.lastOfferArgs = args
return nil
}
// OnRemoteOffer handles an offer from the remote peer and returns true if the message was accepted, false otherwise
// doesn't block, discards the message if connection wasn't ready
func (h *Handshaker) OnRemoteOffer(offer OfferAnswer) bool {
// todo remove this if signaling can support relay
if ForcedRelayAddress() != "" {
offer.RelaySrvAddress = ForcedRelayAddress()
}
select {
case h.remoteOffersCh <- offer:
return true
default:
h.log.Debugf("OnRemoteOffer skipping message because is not ready")
// connection might not be ready yet to receive so we ignore the message
return false
}
}
// OnRemoteAnswer handles an offer from the remote peer and returns true if the message was accepted, false otherwise
// doesn't block, discards the message if connection wasn't ready
func (h *Handshaker) OnRemoteAnswer(answer OfferAnswer) bool {
// todo remove this if signaling can support relay
if ForcedRelayAddress() != "" {
answer.RelaySrvAddress = ForcedRelayAddress()
}
select {
case h.remoteAnswerCh <- answer:
return true
default:
// connection might not be ready yet to receive so we ignore the message
h.log.Debugf("OnRemoteAnswer skipping message because is not ready")
return false
}
}
func (h *Handshaker) waitForRemoteOfferConfirmation() (*OfferAnswer, error) {
select {
case remoteOfferAnswer := <-h.remoteOffersCh:
// received confirmation from the remote peer -> ready to proceed
err := h.sendAnswer()
if err != nil {
return nil, err
}
return &remoteOfferAnswer, nil
case remoteOfferAnswer := <-h.remoteAnswerCh:
return &remoteOfferAnswer, nil
case <-h.ctx.Done():
// closed externally
return nil, NewConnectionClosedError(h.config.Key)
}
}
// sendOffer prepares local user credentials and signals them to the remote peer
func (h *Handshaker) sendOffer(args HandshakeArgs) error {
offer := OfferAnswer{
IceCredentials: IceCredentials{args.IceUFrag, args.IcePwd},
WgListenPort: h.config.LocalWgPort,
Version: version.NetbirdVersion(),
RosenpassPubKey: h.config.RosenpassPubKey,
RosenpassAddr: h.config.RosenpassAddr,
RelaySrvAddress: args.RelayAddr,
}
return h.signaler.SignalOffer(offer, h.config.Key)
}
func (h *Handshaker) sendAnswer() error {
h.log.Debugf("sending answer")
answer := OfferAnswer{
IceCredentials: IceCredentials{h.lastOfferArgs.IceUFrag, h.lastOfferArgs.IcePwd},
WgListenPort: h.config.LocalWgPort,
Version: version.NetbirdVersion(),
RosenpassPubKey: h.config.RosenpassPubKey,
RosenpassAddr: h.config.RosenpassAddr,
RelaySrvAddress: h.lastOfferArgs.RelayAddr,
}
err := h.signaler.SignalAnswer(answer, h.config.Key)
if err != nil {
return err
}
return nil
}

View File

@@ -0,0 +1,70 @@
package peer
import (
"github.com/pion/ice/v3"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
signal "github.com/netbirdio/netbird/signal/client"
sProto "github.com/netbirdio/netbird/signal/proto"
)
type Signaler struct {
signal signal.Client
wgPrivateKey wgtypes.Key
}
func NewSignaler(signal signal.Client, wgPrivateKey wgtypes.Key) *Signaler {
return &Signaler{
signal: signal,
wgPrivateKey: wgPrivateKey,
}
}
func (s *Signaler) SignalOffer(offer OfferAnswer, remoteKey string) error {
return s.signalOfferAnswer(offer, remoteKey, sProto.Body_OFFER)
}
func (s *Signaler) SignalAnswer(offer OfferAnswer, remoteKey string) error {
return s.signalOfferAnswer(offer, remoteKey, sProto.Body_ANSWER)
}
func (s *Signaler) SignalICECandidate(candidate ice.Candidate, remoteKey string) error {
return s.signal.Send(&sProto.Message{
Key: s.wgPrivateKey.PublicKey().String(),
RemoteKey: remoteKey,
Body: &sProto.Body{
Type: sProto.Body_CANDIDATE,
Payload: candidate.Marshal(),
},
})
}
func (s *Signaler) Ready() bool {
return s.signal.Ready()
}
// SignalOfferAnswer signals either an offer or an answer to remote peer
func (s *Signaler) signalOfferAnswer(offerAnswer OfferAnswer, remoteKey string, bodyType sProto.Body_Type) error {
msg, err := signal.MarshalCredential(
s.wgPrivateKey,
offerAnswer.WgListenPort,
remoteKey,
&signal.Credential{
UFrag: offerAnswer.IceCredentials.UFrag,
Pwd: offerAnswer.IceCredentials.Pwd,
},
bodyType,
offerAnswer.RosenpassPubKey,
offerAnswer.RosenpassAddr,
offerAnswer.RelaySrvAddress)
if err != nil {
return err
}
err = s.signal.Send(msg)
if err != nil {
return err
}
return nil
}

View File

@@ -13,6 +13,7 @@ import (
"github.com/netbirdio/netbird/client/internal/relay"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/management/domain"
relayClient "github.com/netbirdio/netbird/relay/client"
)
// State contains the latest state of a peer
@@ -142,6 +143,8 @@ type Status struct {
// Some Peer actions mostly used by in a batch when the network map has been synchronized. In these type of events
// set to true this variable and at the end of the processing we will reset it by the FinishPeerListModifications()
peerListChangedForNotification bool
relayMgr *relayClient.Manager
}
// NewRecorder returns a new Status instance
@@ -156,6 +159,12 @@ func NewRecorder(mgmAddress string) *Status {
}
}
func (d *Status) SetRelayMgr(manager *relayClient.Manager) {
d.mux.Lock()
defer d.mux.Unlock()
d.relayMgr = manager
}
// ReplaceOfflinePeers replaces
func (d *Status) ReplaceOfflinePeers(replacement []State) {
d.mux.Lock()
@@ -503,7 +512,28 @@ func (d *Status) GetSignalState() SignalState {
}
func (d *Status) GetRelayStates() []relay.ProbeResult {
return d.relayStates
if d.relayMgr == nil {
return d.relayStates
}
// extend the list of stun, turn servers with relay address
relaysState := make([]relay.ProbeResult, len(d.relayStates), len(d.relayStates)+1)
copy(relaysState, d.relayStates)
relayState := relay.ProbeResult{}
// if the server connection is not established then we will use the general address
// in case of connection we will use the instance specific address
instanceAddr, err := d.relayMgr.RelayInstanceAddress()
if err != nil {
relayState.URI = d.relayMgr.ServerURL()
relayState.Err = err
} else {
relayState.URI = instanceAddr
}
relaysState = append(relaysState, relayState)
return relaysState
}
func (d *Status) GetDNSStates() []NSGroupState {

View File

@@ -2,8 +2,8 @@ package peer
import (
"errors"
"testing"
"sync"
"testing"
"github.com/stretchr/testify/assert"
)
@@ -43,7 +43,7 @@ func TestUpdatePeerState(t *testing.T) {
status := NewRecorder("https://mgm")
peerState := State{
PubKey: key,
Mux: new(sync.RWMutex),
Mux: new(sync.RWMutex),
}
status.peers[key] = peerState
@@ -64,7 +64,7 @@ func TestStatus_UpdatePeerFQDN(t *testing.T) {
status := NewRecorder("https://mgm")
peerState := State{
PubKey: key,
Mux: new(sync.RWMutex),
Mux: new(sync.RWMutex),
}
status.peers[key] = peerState
@@ -83,7 +83,7 @@ func TestGetPeerStateChangeNotifierLogic(t *testing.T) {
status := NewRecorder("https://mgm")
peerState := State{
PubKey: key,
Mux: new(sync.RWMutex),
Mux: new(sync.RWMutex),
}
status.peers[key] = peerState
@@ -108,7 +108,7 @@ func TestRemovePeer(t *testing.T) {
status := NewRecorder("https://mgm")
peerState := State{
PubKey: key,
Mux: new(sync.RWMutex),
Mux: new(sync.RWMutex),
}
status.peers[key] = peerState

View File

@@ -6,6 +6,6 @@ import (
"github.com/netbirdio/netbird/client/internal/stdnet"
)
func (conn *Conn) newStdNet() (*stdnet.Net, error) {
return stdnet.NewNet(conn.config.InterfaceBlackList)
func (w *WorkerICE) newStdNet() (*stdnet.Net, error) {
return stdnet.NewNet(w.configICE.InterfaceBlackList)
}

View File

@@ -2,6 +2,6 @@ package peer
import "github.com/netbirdio/netbird/client/internal/stdnet"
func (conn *Conn) newStdNet() (*stdnet.Net, error) {
return stdnet.NewNetWithDiscover(conn.iFaceDiscover, conn.config.InterfaceBlackList)
func (w *WorkerICE) newStdNet() (*stdnet.Net, error) {
return stdnet.NewNetWithDiscover(w.iFaceDiscover, w.configICE.InterfaceBlackList)
}

View File

@@ -0,0 +1,461 @@
package peer
import (
"context"
"fmt"
"net"
"net/netip"
"runtime"
"sync"
"sync/atomic"
"time"
"github.com/pion/ice/v3"
"github.com/pion/randutil"
"github.com/pion/stun/v2"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/iface/bind"
"github.com/netbirdio/netbird/route"
)
const (
iceKeepAliveDefault = 4 * time.Second
iceDisconnectedTimeoutDefault = 6 * time.Second
// iceRelayAcceptanceMinWaitDefault is the same as in the Pion ICE package
iceRelayAcceptanceMinWaitDefault = 2 * time.Second
lenUFrag = 16
lenPwd = 32
runesAlpha = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
)
var (
failedTimeout = 6 * time.Second
)
type ICEConfig struct {
// StunTurn is a list of STUN and TURN URLs
StunTurn atomic.Value // []*stun.URI
// InterfaceBlackList is a list of machine interfaces that should be filtered out by ICE Candidate gathering
// (e.g. if eth0 is in the list, host candidate of this interface won't be used)
InterfaceBlackList []string
DisableIPv6Discovery bool
UDPMux ice.UDPMux
UDPMuxSrflx ice.UniversalUDPMux
NATExternalIPs []string
}
type ICEConnInfo struct {
RemoteConn net.Conn
RosenpassPubKey []byte
RosenpassAddr string
LocalIceCandidateType string
RemoteIceCandidateType string
RemoteIceCandidateEndpoint string
LocalIceCandidateEndpoint string
Direct bool
Relayed bool
RelayedOnLocal bool
}
type WorkerICECallbacks struct {
OnConnReady func(ConnPriority, ICEConnInfo)
OnStatusChanged func(ConnStatus)
}
type WorkerICE struct {
ctx context.Context
log *log.Entry
config ConnConfig
configICE ICEConfig
signaler *Signaler
iFaceDiscover stdnet.ExternalIFaceDiscover
statusRecorder *Status
hasRelayOnLocally bool
conn WorkerICECallbacks
selectedPriority ConnPriority
agent *ice.Agent
muxAgent sync.Mutex
StunTurn []*stun.URI
sentExtraSrflx bool
localUfrag string
localPwd string
}
func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, configICE ICEConfig, signaler *Signaler, ifaceDiscover stdnet.ExternalIFaceDiscover, statusRecorder *Status, hasRelayOnLocally bool, callBacks WorkerICECallbacks) (*WorkerICE, error) {
w := &WorkerICE{
ctx: ctx,
log: log,
config: config,
configICE: configICE,
signaler: signaler,
iFaceDiscover: ifaceDiscover,
statusRecorder: statusRecorder,
hasRelayOnLocally: hasRelayOnLocally,
conn: callBacks,
}
localUfrag, localPwd, err := generateICECredentials()
if err != nil {
return nil, err
}
w.localUfrag = localUfrag
w.localPwd = localPwd
return w, nil
}
func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
w.log.Debugf("OnNewOffer for ICE")
w.muxAgent.Lock()
if w.agent != nil {
w.log.Debugf("agent already exists, skipping the offer")
w.muxAgent.Unlock()
return
}
var preferredCandidateTypes []ice.CandidateType
if w.hasRelayOnLocally && remoteOfferAnswer.RelaySrvAddress != "" {
w.selectedPriority = connPriorityICEP2P
preferredCandidateTypes = candidateTypesP2P()
} else {
w.selectedPriority = connPriorityICETurn
preferredCandidateTypes = candidateTypes()
}
w.log.Debugf("recreate ICE agent")
agentCtx, agentCancel := context.WithCancel(w.ctx)
agent, err := w.reCreateAgent(agentCancel, preferredCandidateTypes)
if err != nil {
w.log.Errorf("failed to recreate ICE Agent: %s", err)
w.muxAgent.Unlock()
return
}
w.agent = agent
w.muxAgent.Unlock()
w.log.Debugf("gather candidates")
err = w.agent.GatherCandidates()
if err != nil {
w.log.Debugf("failed to gather candidates: %s", err)
return
}
// will block until connection succeeded
// but it won't release if ICE Agent went into Disconnected or Failed state,
// so we have to cancel it with the provided context once agent detected a broken connection
w.log.Debugf("turn agent dial")
remoteConn, err := w.turnAgentDial(agentCtx, remoteOfferAnswer)
if err != nil {
w.log.Debugf("failed to dial the remote peer: %s", err)
return
}
w.log.Debugf("agent dial succeeded")
pair, err := w.agent.GetSelectedCandidatePair()
if err != nil {
return
}
if !isRelayCandidate(pair.Local) {
// dynamically set remote WireGuard port if other side specified a different one from the default one
remoteWgPort := iface.DefaultWgPort
if remoteOfferAnswer.WgListenPort != 0 {
remoteWgPort = remoteOfferAnswer.WgListenPort
}
// To support old version's with direct mode we attempt to punch an additional role with the remote WireGuard port
go w.punchRemoteWGPort(pair, remoteWgPort)
}
ci := ICEConnInfo{
RemoteConn: remoteConn,
RosenpassPubKey: remoteOfferAnswer.RosenpassPubKey,
RosenpassAddr: remoteOfferAnswer.RosenpassAddr,
LocalIceCandidateType: pair.Local.Type().String(),
RemoteIceCandidateType: pair.Remote.Type().String(),
LocalIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Local.Address(), pair.Local.Port()),
RemoteIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Remote.Address(), pair.Remote.Port()),
Direct: !isRelayCandidate(pair.Local),
Relayed: isRelayed(pair),
RelayedOnLocal: isRelayCandidate(pair.Local),
}
w.log.Debugf("on ICE conn read to use ready")
go w.conn.OnConnReady(w.selectedPriority, ci)
}
// OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer.
func (w *WorkerICE) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMap) {
w.muxAgent.Lock()
defer w.muxAgent.Unlock()
w.log.Debugf("OnRemoteCandidate from peer %s -> %s", w.config.Key, candidate.String())
if w.agent == nil {
w.log.Warnf("ICE Agent is not initialized yet")
return
}
if candidateViaRoutes(candidate, haRoutes) {
return
}
err := w.agent.AddRemoteCandidate(candidate)
if err != nil {
w.log.Errorf("error while handling remote candidate")
return
}
}
func (w *WorkerICE) GetLocalUserCredentials() (frag string, pwd string) {
w.muxAgent.Lock()
defer w.muxAgent.Unlock()
return w.localUfrag, w.localPwd
}
func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, relaySupport []ice.CandidateType) (*ice.Agent, error) {
transportNet, err := w.newStdNet()
if err != nil {
w.log.Errorf("failed to create pion's stdnet: %s", err)
}
iceKeepAlive := iceKeepAlive()
iceDisconnectedTimeout := iceDisconnectedTimeout()
iceRelayAcceptanceMinWait := iceRelayAcceptanceMinWait()
agentConfig := &ice.AgentConfig{
MulticastDNSMode: ice.MulticastDNSModeDisabled,
NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6},
Urls: w.configICE.StunTurn.Load().([]*stun.URI),
CandidateTypes: relaySupport,
InterfaceFilter: stdnet.InterfaceFilter(w.configICE.InterfaceBlackList),
UDPMux: w.configICE.UDPMux,
UDPMuxSrflx: w.configICE.UDPMuxSrflx,
NAT1To1IPs: w.configICE.NATExternalIPs,
Net: transportNet,
FailedTimeout: &failedTimeout,
DisconnectedTimeout: &iceDisconnectedTimeout,
KeepaliveInterval: &iceKeepAlive,
RelayAcceptanceMinWait: &iceRelayAcceptanceMinWait,
LocalUfrag: w.localUfrag,
LocalPwd: w.localPwd,
}
if w.configICE.DisableIPv6Discovery {
agentConfig.NetworkTypes = []ice.NetworkType{ice.NetworkTypeUDP4}
}
w.sentExtraSrflx = false
agent, err := ice.NewAgent(agentConfig)
if err != nil {
return nil, err
}
err = agent.OnCandidate(w.onICECandidate)
if err != nil {
return nil, err
}
err = agent.OnConnectionStateChange(func(state ice.ConnectionState) {
w.log.Debugf("ICE ConnectionState has changed to %s", state.String())
if state == ice.ConnectionStateFailed || state == ice.ConnectionStateDisconnected {
w.conn.OnStatusChanged(StatusDisconnected)
w.muxAgent.Lock()
agentCancel()
_ = agent.Close()
w.agent = nil
w.muxAgent.Unlock()
}
})
if err != nil {
return nil, err
}
err = agent.OnSelectedCandidatePairChange(w.onICESelectedCandidatePair)
if err != nil {
return nil, err
}
err = agent.OnSuccessfulSelectedPairBindingResponse(func(p *ice.CandidatePair) {
err := w.statusRecorder.UpdateLatency(w.config.Key, p.Latency())
if err != nil {
w.log.Debugf("failed to update latency for peer: %s", err)
return
}
})
if err != nil {
return nil, fmt.Errorf("failed setting binding response callback: %w", err)
}
return agent, nil
}
func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) {
// wait local endpoint configuration
time.Sleep(time.Second)
addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", pair.Remote.Address(), remoteWgPort))
if err != nil {
w.log.Warnf("got an error while resolving the udp address, err: %s", err)
return
}
mux, ok := w.configICE.UDPMuxSrflx.(*bind.UniversalUDPMuxDefault)
if !ok {
w.log.Warn("invalid udp mux conversion")
return
}
_, err = mux.GetSharedConn().WriteTo([]byte{0x6e, 0x62}, addr)
if err != nil {
w.log.Warnf("got an error while sending the punch packet, err: %s", err)
}
}
// onICECandidate is a callback attached to an ICE Agent to receive new local connection candidates
// and then signals them to the remote peer
func (w *WorkerICE) onICECandidate(candidate ice.Candidate) {
// nil means candidate gathering has been ended
if candidate == nil {
return
}
// TODO: reported port is incorrect for CandidateTypeHost, makes understanding ICE use via logs confusing as port is ignored
w.log.Debugf("discovered local candidate %s", candidate.String())
go func() {
err := w.signaler.SignalICECandidate(candidate, w.config.Key)
if err != nil {
w.log.Errorf("failed signaling candidate to the remote peer %s %s", w.config.Key, err)
}
}()
if !w.shouldSendExtraSrflxCandidate(candidate) {
return
}
// sends an extra server reflexive candidate to the remote peer with our related port (usually the wireguard port)
// this is useful when network has an existing port forwarding rule for the wireguard port and this peer
extraSrflx, err := extraSrflxCandidate(candidate)
if err != nil {
w.log.Errorf("failed creating extra server reflexive candidate %s", err)
return
}
w.sentExtraSrflx = true
go func() {
err = w.signaler.SignalICECandidate(extraSrflx, w.config.Key)
if err != nil {
w.log.Errorf("failed signaling the extra server reflexive candidate: %s", err)
}
}()
}
func (w *WorkerICE) onICESelectedCandidatePair(c1 ice.Candidate, c2 ice.Candidate) {
w.log.Debugf("selected candidate pair [local <-> remote] -> [%s <-> %s], peer %s", c1.String(), c2.String(),
w.config.Key)
}
func (w *WorkerICE) shouldSendExtraSrflxCandidate(candidate ice.Candidate) bool {
if !w.sentExtraSrflx && candidate.Type() == ice.CandidateTypeServerReflexive && candidate.Port() != candidate.RelatedAddress().Port {
return true
}
return false
}
func (w *WorkerICE) turnAgentDial(ctx context.Context, remoteOfferAnswer *OfferAnswer) (*ice.Conn, error) {
isControlling := w.config.LocalKey > w.config.Key
if isControlling {
return w.agent.Dial(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
} else {
return w.agent.Accept(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
}
}
func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive, error) {
relatedAdd := candidate.RelatedAddress()
return ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{
Network: candidate.NetworkType().String(),
Address: candidate.Address(),
Port: relatedAdd.Port,
Component: candidate.Component(),
RelAddr: relatedAdd.Address,
RelPort: relatedAdd.Port,
})
}
func candidateViaRoutes(candidate ice.Candidate, clientRoutes route.HAMap) bool {
var routePrefixes []netip.Prefix
for _, routes := range clientRoutes {
if len(routes) > 0 && routes[0] != nil {
routePrefixes = append(routePrefixes, routes[0].Network)
}
}
addr, err := netip.ParseAddr(candidate.Address())
if err != nil {
log.Errorf("Failed to parse IP address %s: %v", candidate.Address(), err)
return false
}
for _, prefix := range routePrefixes {
// default route is
if prefix.Bits() == 0 {
continue
}
if prefix.Contains(addr) {
log.Debugf("Ignoring candidate [%s], its address is part of routed network %s", candidate.String(), prefix)
return true
}
}
return false
}
func candidateTypes() []ice.CandidateType {
if hasICEForceRelayConn() {
return []ice.CandidateType{ice.CandidateTypeRelay}
}
// TODO: remove this once we have refactored userspace proxy into the bind package
if runtime.GOOS == "ios" {
return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive}
}
return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive, ice.CandidateTypeRelay}
}
func candidateTypesP2P() []ice.CandidateType {
return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive}
}
func isRelayCandidate(candidate ice.Candidate) bool {
return candidate.Type() == ice.CandidateTypeRelay
}
func isRelayed(pair *ice.CandidatePair) bool {
if pair.Local.Type() == ice.CandidateTypeRelay || pair.Remote.Type() == ice.CandidateTypeRelay {
return true
}
return false
}
func generateICECredentials() (string, string, error) {
ufrag, err := randutil.GenerateCryptoRandomString(lenUFrag, runesAlpha)
if err != nil {
return "", "", err
}
pwd, err := randutil.GenerateCryptoRandomString(lenPwd, runesAlpha)
if err != nil {
return "", "", err
}
return ufrag, pwd, nil
}

View File

@@ -0,0 +1,100 @@
package peer
import (
"context"
"errors"
"net"
log "github.com/sirupsen/logrus"
relayClient "github.com/netbirdio/netbird/relay/client"
)
type RelayConnInfo struct {
relayedConn net.Conn
rosenpassPubKey []byte
rosenpassAddr string
}
type WorkerRelayCallbacks struct {
OnConnReady func(RelayConnInfo)
OnDisconnected func()
}
type WorkerRelay struct {
ctx context.Context
log *log.Entry
config ConnConfig
relayManager *relayClient.Manager
conn WorkerRelayCallbacks
}
func NewWorkerRelay(ctx context.Context, log *log.Entry, config ConnConfig, relayManager *relayClient.Manager, callbacks WorkerRelayCallbacks) *WorkerRelay {
return &WorkerRelay{
ctx: ctx,
log: log,
config: config,
relayManager: relayManager,
conn: callbacks,
}
}
func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
if !w.isRelaySupported(remoteOfferAnswer) {
w.log.Infof("Relay is not supported by remote peer")
return
}
// the relayManager will return with error in case if the connection has lost with relay server
currentRelayAddress, err := w.relayManager.RelayInstanceAddress()
if err != nil {
w.log.Infof("local Relay connection is lost, skipping connection attempt")
return
}
srv := w.preferredRelayServer(currentRelayAddress, remoteOfferAnswer.RelaySrvAddress)
relayedConn, err := w.relayManager.OpenConn(srv, w.config.Key, w.conn.OnDisconnected)
if err != nil {
// todo handle all type errors
if errors.Is(err, relayClient.ErrConnAlreadyExists) {
w.log.Infof("do not need to reopen relay connection")
return
}
w.log.Infof("do not need to reopen relay connection: %s", err)
return
}
w.log.Debugf("Relay connection established with %s", srv)
go w.conn.OnConnReady(RelayConnInfo{
relayedConn: relayedConn,
rosenpassPubKey: remoteOfferAnswer.RosenpassPubKey,
rosenpassAddr: remoteOfferAnswer.RosenpassAddr,
})
}
func (w *WorkerRelay) RelayInstanceAddress() (string, error) {
return w.relayManager.RelayInstanceAddress()
}
func (w *WorkerRelay) IsController() bool {
return w.config.LocalKey > w.config.Key
}
func (w *WorkerRelay) RelayIsSupportedLocally() bool {
return w.relayManager.HasRelayAddress()
}
func (w *WorkerRelay) isRelaySupported(answer *OfferAnswer) bool {
if !w.relayManager.HasRelayAddress() {
return false
}
return answer.RelaySrvAddress != ""
}
func (w *WorkerRelay) preferredRelayServer(myRelayAddress, remoteRelayAddress string) string {
if w.IsController() {
return myRelayAddress
}
return remoteRelayAddress
}

View File

@@ -17,7 +17,7 @@ import (
// ProbeResult holds the info about the result of a relay probe request
type ProbeResult struct {
URI *stun.URI
URI string
Err error
Addr string
}
@@ -176,7 +176,7 @@ func ProbeAll(
wg.Add(1)
go func(res *ProbeResult, stunURI *stun.URI) {
defer wg.Done()
res.URI = stunURI
res.URI = stunURI.String()
res.Addr, res.Err = fn(ctx, stunURI)
}(&results[i], uri)
}

View File

@@ -42,7 +42,7 @@ type clientNetwork struct {
ctx context.Context
cancel context.CancelFunc
statusRecorder *peer.Status
wgInterface *iface.WGIface
wgInterface iface.IWGIface
routes map[route.ID]*route.Route
routeUpdate chan routesUpdate
peerStateUpdate chan struct{}
@@ -52,7 +52,7 @@ type clientNetwork struct {
updateSerial uint64
}
func newClientNetworkWatcher(ctx context.Context, dnsRouteInterval time.Duration, wgInterface *iface.WGIface, statusRecorder *peer.Status, rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter) *clientNetwork {
func newClientNetworkWatcher(ctx context.Context, dnsRouteInterval time.Duration, wgInterface iface.IWGIface, statusRecorder *peer.Status, rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter) *clientNetwork {
ctx, cancel := context.WithCancel(ctx)
client := &clientNetwork{

View File

@@ -48,7 +48,7 @@ type DefaultManager struct {
serverRouter serverRouter
sysOps *systemops.SysOps
statusRecorder *peer.Status
wgInterface *iface.WGIface
wgInterface iface.IWGIface
pubKey string
notifier *notifier
routeRefCounter *refcounter.RouteRefCounter
@@ -60,7 +60,7 @@ func NewManager(
ctx context.Context,
pubKey string,
dnsRouteInterval time.Duration,
wgInterface *iface.WGIface,
wgInterface iface.IWGIface,
statusRecorder *peer.Status,
initialRoutes []*route.Route,
) *DefaultManager {

View File

@@ -11,6 +11,6 @@ import (
"github.com/netbirdio/netbird/iface"
)
func newServerRouter(context.Context, *iface.WGIface, firewall.Manager, *peer.Status) (serverRouter, error) {
func newServerRouter(context.Context, iface.IWGIface, firewall.Manager, *peer.Status) (serverRouter, error) {
return nil, fmt.Errorf("server route not supported on this os")
}

View File

@@ -22,11 +22,11 @@ type defaultServerRouter struct {
ctx context.Context
routes map[route.ID]*route.Route
firewall firewall.Manager
wgInterface *iface.WGIface
wgInterface iface.IWGIface
statusRecorder *peer.Status
}
func newServerRouter(ctx context.Context, wgInterface *iface.WGIface, firewall firewall.Manager, statusRecorder *peer.Status) (serverRouter, error) {
func newServerRouter(ctx context.Context, wgInterface iface.IWGIface, firewall firewall.Manager, statusRecorder *peer.Status) (serverRouter, error) {
return &defaultServerRouter{
ctx: ctx,
routes: make(map[route.ID]*route.Route),

View File

@@ -22,7 +22,7 @@ const (
)
// Setup configures sysctl settings for RP filtering and source validation.
func Setup(wgIface *iface.WGIface) (map[string]int, error) {
func Setup(wgIface iface.IWGIface) (map[string]int, error) {
keys := map[string]int{}
var result *multierror.Error

View File

@@ -17,10 +17,10 @@ type ExclusionCounter = refcounter.Counter[any, Nexthop]
type SysOps struct {
refCounter *ExclusionCounter
wgInterface *iface.WGIface
wgInterface iface.IWGIface
}
func NewSysOps(wgInterface *iface.WGIface) *SysOps {
func NewSysOps(wgInterface iface.IWGIface) *SysOps {
return &SysOps{
wgInterface: wgInterface,
}

View File

@@ -122,7 +122,7 @@ func (r *SysOps) addRouteForCurrentDefaultGateway(prefix netip.Prefix) error {
// addRouteToNonVPNIntf adds a new route to the routing table for the given prefix and returns the next hop and interface.
// If the next hop or interface is pointing to the VPN interface, it will return the initial values.
func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf *iface.WGIface, initialNextHop Nexthop) (Nexthop, error) {
func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf iface.IWGIface, initialNextHop Nexthop) (Nexthop, error) {
addr := prefix.Addr()
switch {
case addr.IsLoopback(),

View File

@@ -3,6 +3,7 @@ package wgproxy
import (
"context"
"fmt"
"io"
"net"
log "github.com/sirupsen/logrus"
@@ -64,7 +65,6 @@ func (p *WGUserSpaceProxy) Free() error {
// proxyToRemote proxies everything from Wireguard to the RemoteKey peer
// blocks
func (p *WGUserSpaceProxy) proxyToRemote() {
buf := make([]byte, 1500)
for {
select {
@@ -73,11 +73,17 @@ func (p *WGUserSpaceProxy) proxyToRemote() {
default:
n, err := p.localConn.Read(buf)
if err != nil {
log.Debugf("failed to read from wg interface conn: %s", err)
continue
}
_, err = p.remoteConn.Write(buf[:n])
if err != nil {
if err == io.EOF {
p.cancel()
} else {
log.Debugf("failed to write to remote conn: %s", err)
}
continue
}
}
@@ -96,11 +102,17 @@ func (p *WGUserSpaceProxy) proxyToLocal() {
default:
n, err := p.remoteConn.Read(buf)
if err != nil {
if err == io.EOF {
p.cancel()
return
}
log.Errorf("failed to read from remote conn: %s", err)
continue
}
_, err = p.localConn.Write(buf[:n])
if err != nil {
log.Debugf("failed to write to wg interface conn: %s", err)
continue
}
}

View File

@@ -174,8 +174,8 @@ func seedFromStatus(a *anonymize.Anonymizer, status *peer.FullStatus) {
}
for _, relay := range status.Relays {
if relay.URI != nil {
a.AnonymizeURI(relay.URI.String())
if relay.URI != "" {
a.AnonymizeURI(relay.URI)
}
}
}

View File

@@ -745,7 +745,7 @@ func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus {
for _, relayState := range fullStatus.Relays {
pbRelayState := &proto.RelayState{
URI: relayState.URI.String(),
URI: relayState.URI,
Available: relayState.Err == nil,
}
if err := relayState.Err; err != nil {

View File

@@ -124,7 +124,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
if err != nil {
return nil, "", err
}
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, "")
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, nil)
if err != nil {
return nil, "", err

19
encryption/cert.go Normal file
View File

@@ -0,0 +1,19 @@
package encryption
import "crypto/tls"
func LoadTLSConfig(certFile, keyFile string) (*tls.Config, error) {
serverCert, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return nil, err
}
config := &tls.Config{
Certificates: []tls.Certificate{serverCert},
ClientAuth: tls.NoClientCert,
NextProtos: []string{
"h2", "http/1.1", // enable HTTP/2
},
}
return config, nil
}

View File

@@ -9,7 +9,7 @@ import (
)
// CreateCertManager wraps common logic of generating Let's encrypt certificate.
func CreateCertManager(datadir string, letsencryptDomain string) (*autocert.Manager, error) {
func CreateCertManager(datadir string, letsencryptDomain ...string) (*autocert.Manager, error) {
certDir := filepath.Join(datadir, "letsencrypt")
if _, err := os.Stat(certDir); os.IsNotExist(err) {
@@ -24,7 +24,7 @@ func CreateCertManager(datadir string, letsencryptDomain string) (*autocert.Mana
certManager := &autocert.Manager{
Prompt: autocert.AcceptTOS,
Cache: autocert.DirCache(certDir),
HostPolicy: autocert.HostWhitelist(letsencryptDomain),
HostPolicy: autocert.HostWhitelist(letsencryptDomain...),
}
return certManager, nil

12
go.mod
View File

@@ -12,7 +12,7 @@ require (
github.com/gorilla/mux v1.8.0
github.com/kardianos/service v1.2.1-0.20210728001519-a323c3813bc7
github.com/onsi/ginkgo v1.16.5
github.com/onsi/gomega v1.23.0
github.com/onsi/gomega v1.27.6
github.com/pion/ice/v3 v3.0.2
github.com/rs/cors v1.8.0
github.com/sirupsen/logrus v1.9.3
@@ -62,10 +62,12 @@ require (
github.com/oschwald/maxminddb-golang v1.12.0
github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/pion/logging v0.2.2
github.com/pion/randutil v0.1.0
github.com/pion/stun/v2 v2.0.0
github.com/pion/transport/v3 v3.0.1
github.com/pion/turn/v3 v3.0.1
github.com/prometheus/client_golang v1.19.1
github.com/quic-go/quic-go v0.45.0
github.com/rs/xid v1.3.0
github.com/shirou/gopsutil/v3 v3.24.4
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
@@ -92,6 +94,7 @@ require (
gorm.io/driver/postgres v1.5.7
gorm.io/driver/sqlite v1.5.3
gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde
nhooyr.io/websocket v1.8.11
)
require (
@@ -133,10 +136,12 @@ require (
github.com/go-ole/go-ole v1.3.0 // indirect
github.com/go-redis/redis/v8 v8.11.5 // indirect
github.com/go-stack/stack v1.8.0 // indirect
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/goki/freetype v0.0.0-20181231101311-fa8a33aabaff // indirect
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/google/btree v1.0.1 // indirect
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 // indirect
github.com/google/s2a-go v0.1.7 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect
github.com/googleapis/gax-go/v2 v2.12.3 // indirect
@@ -162,13 +167,13 @@ require (
github.com/moby/term v0.5.0 // indirect
github.com/morikuni/aec v1.0.0 // indirect
github.com/nxadm/tail v1.4.8 // indirect
github.com/onsi/ginkgo/v2 v2.9.5 // indirect
github.com/opencontainers/go-digest v1.0.0 // indirect
github.com/opencontainers/image-spec v1.1.0 // indirect
github.com/oxtoacart/bpool v0.0.0-20190530202638-03653db5a59c // indirect
github.com/pegasus-kv/thrift v0.13.0 // indirect
github.com/pion/dtls/v2 v2.2.10 // indirect
github.com/pion/mdns v0.0.12 // indirect
github.com/pion/randutil v0.1.0 // indirect
github.com/pion/transport/v2 v2.2.4 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
@@ -188,9 +193,12 @@ require (
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 // indirect
go.opentelemetry.io/otel/sdk v1.26.0 // indirect
go.opentelemetry.io/otel/trace v1.26.0 // indirect
go.uber.org/mock v0.4.0 // indirect
golang.org/x/image v0.18.0 // indirect
golang.org/x/mod v0.17.0 // indirect
golang.org/x/text v0.16.0 // indirect
golang.org/x/time v0.5.0 // indirect
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20240509183442-62759503f434 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240515191416-fc5f0ca64291 // indirect

28
go.sum
View File

@@ -50,6 +50,9 @@ github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA
github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI=
github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI=
github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU=
github.com/cilium/ebpf v0.15.0 h1:7NxJhNiBT3NG8pZJ3c+yfrVdHY8ScgKD27sScgjLMMk=
github.com/cilium/ebpf v0.15.0/go.mod h1:DHp1WyrLeiBh19Cf/tfiSMhqheEiK8fXFZ4No0P1Hso=
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
@@ -154,6 +157,8 @@ github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq
github.com/go-stack/stack v1.8.0 h1:5SgMzNM5HxrEjV0ww2lTmX6E2Izsfxas4+YHWRs3Lsk=
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE=
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI=
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/godbus/dbus/v5 v5.1.0 h1:4KLkAxT3aOY8Li4FRJe/KvhoNFFxo0m6fNuFUO8QJUk=
github.com/godbus/dbus/v5 v5.1.0/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
@@ -211,6 +216,8 @@ github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF
github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo=
github.com/google/nftables v0.0.0-20220808154552-2eca00135732 h1:csc7dT82JiSLvq4aMyQMIQDL7986NH6Wxf/QrvOj55A=
github.com/google/nftables v0.0.0-20220808154552-2eca00135732/go.mod h1:b97ulCCFipUC+kSin+zygkvUVpx0vyIAwxXFdY3PlNc=
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 h1:yAJXTCF9TqKcTiHJAE8dj7HMvPfh66eeA2JYW7eFpSE=
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
github.com/google/s2a-go v0.1.7 h1:60BLSyTrOV4/haCDW4zb1guZItoSq8foHCXrAnjBo/o=
github.com/google/s2a-go v0.1.7/go.mod h1:50CgR4k1jNlWBu4UfS4AcfhVe1r6pdZPygJ3R8F0Qdw=
github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
@@ -245,6 +252,7 @@ github.com/hashicorp/go-version v1.6.0 h1:feTTfFNnjP967rlCxM/I9g701jU+RN74YKx2mO
github.com/hashicorp/go-version v1.6.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA=
github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
@@ -354,14 +362,14 @@ github.com/onsi/ginkgo v1.10.1/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+
github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk=
github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU=
github.com/onsi/ginkgo/v2 v2.4.0 h1:+Ig9nvqgS5OBSACXNk15PLdp0U9XPYROt9CFzVdFGIs=
github.com/onsi/ginkgo/v2 v2.4.0/go.mod h1:iHkDK1fKGcBoEHT5W7YBq4RFWaQulw+caOMkAt4OrFo=
github.com/onsi/ginkgo/v2 v2.9.5 h1:+6Hr4uxzP4XIUyAkg61dWBw8lb/gc4/X5luuxN/EC+Q=
github.com/onsi/ginkgo/v2 v2.9.5/go.mod h1:tvAoo1QUJwNEU2ITftXTpR7R1RbCzoZUOs3RonqW57k=
github.com/onsi/gomega v0.0.0-20170829124025-dcabb60a477c/go.mod h1:C1qb7wdrVGGVU+Z6iS04AVkA3Q65CEZX59MT0QO5uiA=
github.com/onsi/gomega v1.7.0/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY=
github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY=
github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo=
github.com/onsi/gomega v1.23.0 h1:/oxKu9c2HVap+F3PfKort2Hw5DEU+HGlW8n+tguWsys=
github.com/onsi/gomega v1.23.0/go.mod h1:Z/NWtiqwBrwUt4/2loMmHL63EDLnYHmVbuBpDr2vQAg=
github.com/onsi/gomega v1.27.6 h1:ENqfyGeS5AX/rlXDd/ETokDz93u0YufY1Pgxuy/PvWE=
github.com/onsi/gomega v1.27.6/go.mod h1:PIQNjfQwkP3aQAH7lf7j87O/5FiNr+ZR8+ipb+qQlhg=
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
github.com/opencontainers/image-spec v1.1.0 h1:8SG7/vwALn54lVB/0yZ/MMwhFrPYtpEHQb2IpWsCzug=
@@ -413,6 +421,8 @@ github.com/prometheus/common v0.53.0 h1:U2pL9w9nmJwJDa4qqLQ3ZaePJ6ZTwt7cMD3AG3+a
github.com/prometheus/common v0.53.0/go.mod h1:BrxBKv3FWBIGXw89Mg1AeBq7FSyRzXWI3l3e7W3RN5U=
github.com/prometheus/procfs v0.15.0 h1:A82kmvXJq2jTu5YUhSGNlYoxh85zLnKgPz4bMZgI5Ek=
github.com/prometheus/procfs v0.15.0/go.mod h1:Y0RJ/Y5g5wJpkTisOtqwDSo4HwhGmLB4VQSw2sQJLHk=
github.com/quic-go/quic-go v0.45.0 h1:OHmkQGM37luZITyTSu6ff03HP/2IrwDX1ZFiNEhSFUE=
github.com/quic-go/quic-go v0.45.0/go.mod h1:1dLehS7TIR64+vxGR70GDcatWTOtMX2PUtnKsjbTurI=
github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M=
github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA=
github.com/rs/cors v1.8.0 h1:P2KMzcFwrPoSjkF1WLRPsp3UMLyql8L4v9hQpVeK5so=
@@ -460,6 +470,7 @@ github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXf
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
@@ -522,6 +533,8 @@ go.opentelemetry.io/otel/trace v1.26.0 h1:1ieeAUb4y0TE26jUFrCIXKpTuVK7uJGN9/Z/2L
go.opentelemetry.io/otel/trace v1.26.0/go.mod h1:4iDxvGDQuUkHve82hJJ8UqrwswHYsZuWCBllGV2U2y0=
go.opentelemetry.io/proto/otlp v1.0.0 h1:T0TX0tmXU8a3CbNXzEKGeU5mIVOdf0oykP+u2lIVU/I=
go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v80hjKIs5JXpM=
go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU=
go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc=
goauthentik.io/api/v3 v3.2023051.3 h1:NebAhD/TeTWNo/9X3/Uj+rM5fG1HaiLOlKTNLQv9Qq4=
goauthentik.io/api/v3 v3.2023051.3/go.mod h1:nYECml4jGbp/541hj8GcylKQG1gVBsKppHy4+7G8u4U=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
@@ -554,6 +567,8 @@ golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA=
golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/net v0.0.0-20170114055629-f2499483f923/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
@@ -607,6 +622,7 @@ golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191010194322-b09406accb47/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200217220822-9197077df867/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200515095857-1151b9dac4a9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -674,6 +690,8 @@ golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d h1:vU5i/LfpvrRCpgM/VPfJLg5KjxD3E+hfT1SH+d9zLwg=
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
@@ -768,5 +786,7 @@ k8s.io/gengo v0.0.0-20190128074634-0689ccc1d7d6/go.mod h1:ezvh/TsK7cY6rbqRK0oQQ8
k8s.io/klog v0.0.0-20181102134211-b9b56d5dfc92/go.mod h1:Gq+BEi5rUBO/HRz0bTSXDUcqjScdoY3a9IHpCEIOOfk=
k8s.io/klog v1.0.0/go.mod h1:4Bi6QPql/J/LkTDqv7R/cd3hPo4k2DG6Ptcz060Ez5I=
k8s.io/kube-openapi v0.0.0-20191107075043-30be4d16710a/go.mod h1:1TqjTSzOxsLGIKfj0lK8EeCP7K1iUG65v09OM0/WG5E=
nhooyr.io/websocket v1.8.11 h1:f/qXNc2/3DpoSZkHt1DQu6rj4zGC8JmkkLkWss0MgN0=
nhooyr.io/websocket v1.8.11/go.mod h1:rN9OFWIUwuxg4fR5tELlYC04bXYowCP9GX47ivo2l+c=
sigs.k8s.io/structured-merge-diff v0.0.0-20190525122527-15d366b2352e/go.mod h1:wWxsB5ozmmv/SG7nM11ayaAW51xMvak/t1r0CSlcokI=
sigs.k8s.io/yaml v1.1.0/go.mod h1:UJmg0vDUVViEyp3mgSv9WPwZCDxu4rQW1olrI1uml+o=

103
iface/iface_moc.go Normal file
View File

@@ -0,0 +1,103 @@
package iface
import (
"net"
"time"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/iface/bind"
)
type MockWGIface struct {
CreateFunc func() error
CreateOnAndroidFunc func(routeRange []string, ip string, domains []string) error
IsUserspaceBindFunc func() bool
NameFunc func() string
AddressFunc func() WGAddress
ToInterfaceFunc func() *net.Interface
UpFunc func() (*bind.UniversalUDPMuxDefault, error)
UpdateAddrFunc func(newAddr string) error
UpdatePeerFunc func(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
RemovePeerFunc func(peerKey string) error
AddAllowedIPFunc func(peerKey string, allowedIP string) error
RemoveAllowedIPFunc func(peerKey string, allowedIP string) error
CloseFunc func() error
SetFilterFunc func(filter PacketFilter) error
GetFilterFunc func() PacketFilter
GetDeviceFunc func() *DeviceWrapper
GetStatsFunc func(peerKey string) (WGStats, error)
GetInterfaceGUIDStringFunc func() (string, error)
}
func (m *MockWGIface) GetInterfaceGUIDString() (string, error) {
return m.GetInterfaceGUIDStringFunc()
}
func (m *MockWGIface) Create() error {
return m.CreateFunc()
}
func (m *MockWGIface) CreateOnAndroid(routeRange []string, ip string, domains []string) error {
return m.CreateOnAndroidFunc(routeRange, ip, domains)
}
func (m *MockWGIface) IsUserspaceBind() bool {
return m.IsUserspaceBindFunc()
}
func (m *MockWGIface) Name() string {
return m.NameFunc()
}
func (m *MockWGIface) Address() WGAddress {
return m.AddressFunc()
}
func (m *MockWGIface) ToInterface() *net.Interface {
return m.ToInterfaceFunc()
}
func (m *MockWGIface) Up() (*bind.UniversalUDPMuxDefault, error) {
return m.UpFunc()
}
func (m *MockWGIface) UpdateAddr(newAddr string) error {
return m.UpdateAddrFunc(newAddr)
}
func (m *MockWGIface) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
return m.UpdatePeerFunc(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
}
func (m *MockWGIface) RemovePeer(peerKey string) error {
return m.RemovePeerFunc(peerKey)
}
func (m *MockWGIface) AddAllowedIP(peerKey string, allowedIP string) error {
return m.AddAllowedIPFunc(peerKey, allowedIP)
}
func (m *MockWGIface) RemoveAllowedIP(peerKey string, allowedIP string) error {
return m.RemoveAllowedIPFunc(peerKey, allowedIP)
}
func (m *MockWGIface) Close() error {
return m.CloseFunc()
}
func (m *MockWGIface) SetFilter(filter PacketFilter) error {
return m.SetFilterFunc(filter)
}
func (m *MockWGIface) GetFilter() PacketFilter {
return m.GetFilterFunc()
}
func (m *MockWGIface) GetDevice() *DeviceWrapper {
return m.GetDeviceFunc()
}
func (m *MockWGIface) GetStats(peerKey string) (WGStats, error) {
return m.GetStatsFunc(peerKey)
}

32
iface/iwginterface.go Normal file
View File

@@ -0,0 +1,32 @@
//go:build !windows
package iface
import (
"net"
"time"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/iface/bind"
)
type IWGIface interface {
Create() error
CreateOnAndroid(routeRange []string, ip string, domains []string) error
IsUserspaceBind() bool
Name() string
Address() WGAddress
ToInterface() *net.Interface
Up() (*bind.UniversalUDPMuxDefault, error)
UpdateAddr(newAddr string) error
UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
RemovePeer(peerKey string) error
AddAllowedIP(peerKey string, allowedIP string) error
RemoveAllowedIP(peerKey string, allowedIP string) error
Close() error
SetFilter(filter PacketFilter) error
GetFilter() PacketFilter
GetDevice() *DeviceWrapper
GetStats(peerKey string) (WGStats, error)
}

View File

@@ -0,0 +1,31 @@
package iface
import (
"net"
"time"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/iface/bind"
)
type IWGIface interface {
Create() error
CreateOnAndroid(routeRange []string, ip string, domains []string) error
IsUserspaceBind() bool
Name() string
Address() WGAddress
ToInterface() *net.Interface
Up() (*bind.UniversalUDPMuxDefault, error)
UpdateAddr(newAddr string) error
UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
RemovePeer(peerKey string) error
AddAllowedIP(peerKey string, allowedIP string) error
RemoveAllowedIP(peerKey string, allowedIP string) error
Close() error
SetFilter(filter PacketFilter) error
GetFilter() PacketFilter
GetDevice() *DeviceWrapper
GetStats(peerKey string) (WGStats, error)
GetInterfaceGUIDString() (string, error)
}

View File

@@ -75,7 +75,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
if err != nil {
t.Fatal(err)
}
turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, "")
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, nil)
if err != nil {
t.Fatal(err)

View File

@@ -195,7 +195,7 @@ var (
return fmt.Errorf("failed to build default manager: %v", err)
}
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
turnRelayTokenManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.RelayConfig)
trustedPeers := config.ReverseProxy.TrustedPeers
defaultTrustedPeers := []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0"), netip.MustParsePrefix("::/0")}
@@ -271,7 +271,7 @@ var (
ephemeralManager.LoadInitialPeers(ctx)
gRPCAPIHandler := grpc.NewServer(gRPCOpts...)
srv, err := server.NewServer(ctx, config, accountManager, peersUpdateManager, turnManager, appMetrics, ephemeralManager)
srv, err := server.NewServer(ctx, config, accountManager, peersUpdateManager, turnRelayTokenManager, appMetrics, ephemeralManager)
if err != nil {
return fmt.Errorf("failed creating gRPC API handler: %v", err)
}
@@ -538,6 +538,10 @@ func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*server.Config,
}
}
if loadedConfig.RelayConfig != nil {
log.Infof("Relay address: %v", loadedConfig.RelayConfig.Address)
}
return loadedConfig, err
}

View File

@@ -0,0 +1,54 @@
package cmd
import (
"context"
"os"
"testing"
)
const (
exampleConfig = `{
"RelayConfig": {
"Address": "rels://relay.stage.npeer.io"
},
"HttpConfig": {
"AuthAudience": "https://stageapp/",
"AuthIssuer": "https://something.eu.auth0.com/",
"OIDCConfigEndpoint": "https://something.eu.auth0.com/.well-known/openid-configuration"
}
}`
)
func Test_loadMgmtConfig(t *testing.T) {
tmpFile, err := createConfig()
if err != nil {
t.Fatalf("failed to create config: %s", err)
}
cfg, err := loadMgmtConfig(context.Background(), tmpFile)
if err != nil {
t.Fatalf("failed to load management config: %s", err)
}
if cfg.RelayConfig == nil {
t.Fatalf("config is nil")
}
if cfg.RelayConfig.Address == "" {
t.Fatalf("relay address is empty")
}
}
func createConfig() (string, error) {
tmpfile, err := os.CreateTemp("", "config.json")
if err != nil {
return "", err
}
_, err = tmpfile.Write([]byte(exampleConfig))
if err != nil {
return "", err
}
if err := tmpfile.Close(); err != nil {
return "", err
}
return tmpfile.Name(), nil
}

File diff suppressed because it is too large Load Diff

View File

@@ -177,6 +177,8 @@ message WiretrusteeConfig {
// a Signal server config
HostConfig signal = 3;
RelayConfig relay = 4;
}
// HostConfig describes connection properties of some server (e.g. STUN, Signal, Management)
@@ -193,6 +195,13 @@ message HostConfig {
DTLS = 4;
}
}
message RelayConfig {
repeated string urls = 1;
string tokenPayload = 2;
string tokenSignature = 3;
}
// ProtectedHostConfig is similar to HostConfig but has additional user and password
// Mostly used for TURN servers
message ProtectedHostConfig {

View File

@@ -32,9 +32,10 @@ const (
// Config of the Management service
type Config struct {
Stuns []*Host
TURNConfig *TURNConfig
Signal *Host
Stuns []*Host
TURNConfig *TURNConfig
RelayConfig *RelayConfig
Signal *Host
Datadir string
DataStoreEncryptionKey string
@@ -71,6 +72,10 @@ type TURNConfig struct {
Turns []*Host
}
type RelayConfig struct {
Address string
}
// HttpServerConfig is a config of the HTTP Management service server
type HttpServerConfig struct {
LetsEncryptDomain string

View File

@@ -16,13 +16,12 @@ import (
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
nbContext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/management/proto"
nbContext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/jwtclaims"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
internalStatus "github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/telemetry"
)
@@ -32,17 +31,17 @@ type GRPCServer struct {
accountManager AccountManager
wgKey wgtypes.Key
proto.UnimplementedManagementServiceServer
peersUpdateManager *PeersUpdateManager
config *Config
turnCredentialsManager TURNCredentialsManager
jwtValidator *jwtclaims.JWTValidator
jwtClaimsExtractor *jwtclaims.ClaimsExtractor
appMetrics telemetry.AppMetrics
ephemeralManager *EphemeralManager
peersUpdateManager *PeersUpdateManager
config *Config
turnRelayTokenManager TURNRelayTokenManager
jwtValidator *jwtclaims.JWTValidator
jwtClaimsExtractor *jwtclaims.ClaimsExtractor
appMetrics telemetry.AppMetrics
ephemeralManager *EphemeralManager
}
// NewServer creates a new Management server
func NewServer(ctx context.Context, config *Config, accountManager AccountManager, peersUpdateManager *PeersUpdateManager, turnCredentialsManager TURNCredentialsManager, appMetrics telemetry.AppMetrics, ephemeralManager *EphemeralManager) (*GRPCServer, error) {
func NewServer(ctx context.Context, config *Config, accountManager AccountManager, peersUpdateManager *PeersUpdateManager, turnRelayTokenManager TURNRelayTokenManager, appMetrics telemetry.AppMetrics, ephemeralManager *EphemeralManager) (*GRPCServer, error) {
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
return nil, err
@@ -88,14 +87,14 @@ func NewServer(ctx context.Context, config *Config, accountManager AccountManage
return &GRPCServer{
wgKey: key,
// peerKey -> event channel
peersUpdateManager: peersUpdateManager,
accountManager: accountManager,
config: config,
turnCredentialsManager: turnCredentialsManager,
jwtValidator: jwtValidator,
jwtClaimsExtractor: jwtClaimsExtractor,
appMetrics: appMetrics,
ephemeralManager: ephemeralManager,
peersUpdateManager: peersUpdateManager,
accountManager: accountManager,
config: config,
turnRelayTokenManager: turnRelayTokenManager,
jwtValidator: jwtValidator,
jwtClaimsExtractor: jwtClaimsExtractor,
appMetrics: appMetrics,
ephemeralManager: ephemeralManager,
}, nil
}
@@ -172,7 +171,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
s.ephemeralManager.OnPeerConnected(ctx, peer)
if s.config.TURNConfig.TimeBasedCredentials {
s.turnCredentialsManager.SetupRefresh(ctx, peer.ID)
s.turnRelayTokenManager.SetupRefresh(ctx, peer.ID)
}
if s.appMetrics != nil {
@@ -235,7 +234,7 @@ func (s *GRPCServer) sendUpdate(ctx context.Context, peerKey wgtypes.Key, peer *
func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, peer *nbpeer.Peer) {
s.peersUpdateManager.CloseChannel(ctx, peer.ID)
s.turnCredentialsManager.CancelRefresh(peer.ID)
s.turnRelayTokenManager.CancelRefresh(peer.ID)
_ = s.accountManager.CancelPeerRoutines(ctx, peer)
s.ephemeralManager.OnPeerDisconnected(ctx, peer)
}
@@ -421,9 +420,14 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
s.ephemeralManager.OnPeerDisconnected(ctx, peer)
}
trt, err := s.turnRelayTokenManager.Generate()
if err != nil {
log.Errorf("failed generating TURN and Relay token: %v", err)
}
// if peer has reached this point then it has logged in
loginResp := &proto.LoginResponse{
WiretrusteeConfig: toWiretrusteeConfig(s.config, nil),
WiretrusteeConfig: toWiretrusteeConfig(s.config, nil, trt),
PeerConfig: toPeerConfig(peer, netMap.Network, s.accountManager.GetDNSDomain()),
Checks: toProtocolChecks(ctx, postureChecks),
}
@@ -481,7 +485,7 @@ func ToResponseProto(configProto Protocol) proto.HostConfig_Protocol {
}
}
func toWiretrusteeConfig(config *Config, turnCredentials *TURNCredentials) *proto.WiretrusteeConfig {
func toWiretrusteeConfig(config *Config, turnCredentials *TURNRelayToken, relayToken *TURNRelayToken) *proto.WiretrusteeConfig {
if config == nil {
return nil
}
@@ -497,8 +501,8 @@ func toWiretrusteeConfig(config *Config, turnCredentials *TURNCredentials) *prot
var username string
var password string
if turnCredentials != nil {
username = turnCredentials.Username
password = turnCredentials.Password
username = turnCredentials.Payload
password = turnCredentials.Signature
} else {
username = turn.Username
password = turn.Password
@@ -513,6 +517,18 @@ func toWiretrusteeConfig(config *Config, turnCredentials *TURNCredentials) *prot
})
}
var relayCfg *proto.RelayConfig
if config.RelayConfig != nil && config.RelayConfig.Address != "" {
relayCfg = &proto.RelayConfig{
Urls: []string{config.RelayConfig.Address},
}
if relayToken != nil {
relayCfg.TokenPayload = relayToken.Payload
relayCfg.TokenSignature = relayToken.Signature
}
}
return &proto.WiretrusteeConfig{
Stuns: stuns,
Turns: turns,
@@ -520,6 +536,7 @@ func toWiretrusteeConfig(config *Config, turnCredentials *TURNCredentials) *prot
Uri: config.Signal.URI,
Protocol: ToResponseProto(config.Signal.Proto),
},
Relay: relayCfg,
}
}
@@ -547,8 +564,8 @@ func toRemotePeerConfig(peers []*nbpeer.Peer, dnsName string) []*proto.RemotePee
return remotePeers
}
func toSyncResponse(ctx context.Context, config *Config, peer *nbpeer.Peer, turnCredentials *TURNCredentials, networkMap *NetworkMap, dnsName string, checks []*posture.Checks) *proto.SyncResponse {
wtConfig := toWiretrusteeConfig(config, turnCredentials)
func toSyncResponse(ctx context.Context, config *Config, peer *nbpeer.Peer, turnCredentials *TURNRelayToken, relayCredentials *TURNRelayToken, networkMap *NetworkMap, dnsName string, checks []*posture.Checks) *proto.SyncResponse {
wtConfig := toWiretrusteeConfig(config, turnCredentials, relayCredentials)
pConfig := toPeerConfig(peer, networkMap.Network, dnsName)
@@ -590,14 +607,15 @@ func (s *GRPCServer) IsHealthy(ctx context.Context, req *proto.Empty) (*proto.Em
// sendInitialSync sends initial proto.SyncResponse to the peer requesting synchronization
func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, networkMap *NetworkMap, postureChecks []*posture.Checks, srv proto.ManagementService_SyncServer) error {
// make secret time based TURN credentials optional
var turnCredentials *TURNCredentials
if s.config.TURNConfig.TimeBasedCredentials {
creds := s.turnCredentialsManager.GenerateCredentials()
turnCredentials = &creds
} else {
turnCredentials = nil
var turnCredentials *TURNRelayToken
trt, err := s.turnRelayTokenManager.Generate()
if err != nil {
log.Errorf("failed generating TURN and Relay token: %v", err)
}
plainResp := toSyncResponse(ctx, s.config, peer, turnCredentials, networkMap, s.accountManager.GetDNSDomain(), postureChecks)
if s.config.TURNConfig.TimeBasedCredentials {
turnCredentials = trt
}
plainResp := toSyncResponse(ctx, s.config, peer, turnCredentials, trt, networkMap, s.accountManager.GetDNSDomain(), postureChecks)
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp)
if err != nil {

View File

@@ -419,7 +419,7 @@ func startManagement(t *testing.T, config *Config) (*grpc.Server, string, error)
if err != nil {
return nil, "", err
}
turnManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
turnManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, "")
ephemeralMgr := NewEphemeralManager(store, accountManager)
mgmtServer, err := NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, ephemeralMgr)

View File

@@ -546,7 +546,7 @@ func startServer(config *server.Config) (*grpc.Server, net.Listener) {
if err != nil {
log.Fatalf("failed creating a manager: %v", err)
}
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, "")
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, nil)
Expect(err).NotTo(HaveOccurred())
mgmtProto.RegisterManagementServiceServer(s, mgmtServer)

View File

@@ -931,7 +931,7 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account
postureChecks := am.getPeerPostureChecks(account, peer)
remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, peer.ID, am.dnsDomain, approvedPeersMap)
update := toSyncResponse(ctx, nil, peer, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks)
update := toSyncResponse(ctx, nil, peer, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks)
am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update})
}
}

View File

@@ -0,0 +1,132 @@
package server
import (
"context"
"fmt"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/proto"
auth "github.com/netbirdio/netbird/relay/auth/hmac"
)
// TURNRelayTokenManager used to manage TURN credentials
type TURNRelayTokenManager interface {
Generate() (*TURNRelayToken, error)
SetupRefresh(ctx context.Context, peerKey string)
CancelRefresh(peerKey string)
}
// TimeBasedAuthSecretsManager generates credentials with TTL and using pre-shared secret known to TURN server
type TimeBasedAuthSecretsManager struct {
mux sync.Mutex
turnCfg *TURNConfig
relayAddr string
hmacToken *auth.TimedHMAC
updateManager *PeersUpdateManager
cancelMap map[string]chan struct{}
}
type TURNRelayToken auth.Token
func NewTimeBasedAuthSecretsManager(updateManager *PeersUpdateManager, turnCfg *TURNConfig, relayConfig *RelayConfig) *TimeBasedAuthSecretsManager {
var relayAddr string
if relayConfig != nil {
relayAddr = relayConfig.Address
}
return &TimeBasedAuthSecretsManager{
mux: sync.Mutex{},
updateManager: updateManager,
turnCfg: turnCfg,
relayAddr: relayAddr,
hmacToken: auth.NewTimedHMAC(turnCfg.Secret, turnCfg.CredentialsTTL.Duration),
cancelMap: make(map[string]chan struct{}),
}
}
// Generate generates new time-based secret credentials - basically username is a unix timestamp and password is a HMAC hash of a timestamp with a preshared TURN secret
func (m *TimeBasedAuthSecretsManager) Generate() (*TURNRelayToken, error) {
token, err := m.hmacToken.GenerateToken()
if err != nil {
return nil, fmt.Errorf("failed to generate token: %s", err)
}
return (*TURNRelayToken)(token), nil
}
func (m *TimeBasedAuthSecretsManager) cancel(peerID string) {
if channel, ok := m.cancelMap[peerID]; ok {
close(channel)
delete(m.cancelMap, peerID)
}
}
// CancelRefresh cancels scheduled peer credentials refresh
func (m *TimeBasedAuthSecretsManager) CancelRefresh(peerID string) {
m.mux.Lock()
defer m.mux.Unlock()
m.cancel(peerID)
}
// SetupRefresh starts peer credentials refresh. Since credentials are expiring (TTL) it is necessary to always generate them and send to the peer.
// A goroutine is created and put into TimeBasedAuthSecretsManager.cancelMap. This routine should be cancelled if peer is gone.
func (m *TimeBasedAuthSecretsManager) SetupRefresh(ctx context.Context, peerID string) {
m.mux.Lock()
defer m.mux.Unlock()
m.cancel(peerID)
cancel := make(chan struct{}, 1)
m.cancelMap[peerID] = cancel
log.WithContext(ctx).Debugf("starting turn refresh for %s", peerID)
go func() {
// we don't want to regenerate credentials right on expiration, so we do it slightly before (at 3/4 of TTL)
ticker := time.NewTicker(m.turnCfg.CredentialsTTL.Duration / 4 * 3)
defer ticker.Stop()
for {
select {
case <-cancel:
log.WithContext(ctx).Debugf("stopping turn refresh for %s", peerID)
return
case <-ticker.C:
m.pushNewTokens(ctx, peerID)
}
}
}()
}
func (m *TimeBasedAuthSecretsManager) pushNewTokens(ctx context.Context, peerID string) {
token, err := m.hmacToken.GenerateToken()
if err != nil {
log.Errorf("failed to generate token for peer '%s': %s", peerID, err)
return
}
var turns []*proto.ProtectedHostConfig
for _, host := range m.turnCfg.Turns {
turns = append(turns, &proto.ProtectedHostConfig{
HostConfig: &proto.HostConfig{
Uri: host.URI,
Protocol: ToResponseProto(host.Proto),
},
User: token.Payload,
Password: token.Signature,
})
}
update := &proto.SyncResponse{
WiretrusteeConfig: &proto.WiretrusteeConfig{
Turns: turns,
Relay: &proto.RelayConfig{
Urls: []string{m.relayAddr},
TokenPayload: token.Payload,
TokenSignature: token.Signature,
},
},
}
log.WithContext(ctx).Debugf("sending new TURN credentials to peer %s", peerID)
m.updateManager.SendUpdate(ctx, peerID, &UpdateMessage{Update: update})
}

View File

@@ -27,18 +27,18 @@ func TestTimeBasedAuthSecretsManager_GenerateCredentials(t *testing.T) {
CredentialsTTL: ttl,
Secret: secret,
Turns: []*Host{TurnTestHost},
})
}, "")
credentials := tested.GenerateCredentials()
credentials, _ := tested.Generate()
if credentials.Username == "" {
if credentials.Payload == "" {
t.Errorf("expected generated TURN username not to be empty, got empty")
}
if credentials.Password == "" {
if credentials.Signature == "" {
t.Errorf("expected generated TURN password not to be empty, got empty")
}
validateMAC(t, credentials.Username, credentials.Password, []byte(secret))
validateMAC(t, credentials.Payload, credentials.Signature, []byte(secret))
}
@@ -53,7 +53,7 @@ func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) {
CredentialsTTL: ttl,
Secret: secret,
Turns: []*Host{TurnTestHost},
})
}, "")
tested.SetupRefresh(context.Background(), peer)
@@ -101,7 +101,7 @@ func TestTimeBasedAuthSecretsManager_CancelRefresh(t *testing.T) {
CredentialsTTL: ttl,
Secret: secret,
Turns: []*Host{TurnTestHost},
})
}, "")
tested.SetupRefresh(context.Background(), peer)
if _, ok := tested.cancelMap[peer]; !ok {

View File

@@ -1,126 +0,0 @@
package server
import (
"context"
"crypto/hmac"
"crypto/sha1"
"encoding/base64"
"fmt"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/proto"
)
// TURNCredentialsManager used to manage TURN credentials
type TURNCredentialsManager interface {
GenerateCredentials() TURNCredentials
SetupRefresh(ctx context.Context, peerKey string)
CancelRefresh(peerKey string)
}
// TimeBasedAuthSecretsManager generates credentials with TTL and using pre-shared secret known to TURN server
type TimeBasedAuthSecretsManager struct {
mux sync.Mutex
config *TURNConfig
updateManager *PeersUpdateManager
cancelMap map[string]chan struct{}
}
type TURNCredentials struct {
Username string
Password string
}
func NewTimeBasedAuthSecretsManager(updateManager *PeersUpdateManager, config *TURNConfig) *TimeBasedAuthSecretsManager {
return &TimeBasedAuthSecretsManager{
mux: sync.Mutex{},
config: config,
updateManager: updateManager,
cancelMap: make(map[string]chan struct{}),
}
}
// GenerateCredentials generates new time-based secret credentials - basically username is a unix timestamp and password is a HMAC hash of a timestamp with a preshared TURN secret
func (m *TimeBasedAuthSecretsManager) GenerateCredentials() TURNCredentials {
mac := hmac.New(sha1.New, []byte(m.config.Secret))
timeAuth := time.Now().Add(m.config.CredentialsTTL.Duration).Unix()
username := fmt.Sprint(timeAuth)
_, err := mac.Write([]byte(username))
if err != nil {
log.Errorln("Generating turn password failed with error: ", err)
}
bytePassword := mac.Sum(nil)
password := base64.StdEncoding.EncodeToString(bytePassword)
return TURNCredentials{
Username: username,
Password: password,
}
}
func (m *TimeBasedAuthSecretsManager) cancel(peerID string) {
if channel, ok := m.cancelMap[peerID]; ok {
close(channel)
delete(m.cancelMap, peerID)
}
}
// CancelRefresh cancels scheduled peer credentials refresh
func (m *TimeBasedAuthSecretsManager) CancelRefresh(peerID string) {
m.mux.Lock()
defer m.mux.Unlock()
m.cancel(peerID)
}
// SetupRefresh starts peer credentials refresh. Since credentials are expiring (TTL) it is necessary to always generate them and send to the peer.
// A goroutine is created and put into TimeBasedAuthSecretsManager.cancelMap. This routine should be cancelled if peer is gone.
func (m *TimeBasedAuthSecretsManager) SetupRefresh(ctx context.Context, peerID string) {
m.mux.Lock()
defer m.mux.Unlock()
m.cancel(peerID)
cancel := make(chan struct{}, 1)
m.cancelMap[peerID] = cancel
log.WithContext(ctx).Debugf("starting turn refresh for %s", peerID)
go func() {
// we don't want to regenerate credentials right on expiration, so we do it slightly before (at 3/4 of TTL)
ticker := time.NewTicker(m.config.CredentialsTTL.Duration / 4 * 3)
for {
select {
case <-cancel:
log.WithContext(ctx).Debugf("stopping turn refresh for %s", peerID)
return
case <-ticker.C:
c := m.GenerateCredentials()
var turns []*proto.ProtectedHostConfig
for _, host := range m.config.Turns {
turns = append(turns, &proto.ProtectedHostConfig{
HostConfig: &proto.HostConfig{
Uri: host.URI,
Protocol: ToResponseProto(host.Proto),
},
User: c.Username,
Password: c.Password,
})
}
update := &proto.SyncResponse{
WiretrusteeConfig: &proto.WiretrusteeConfig{
Turns: turns,
},
}
log.WithContext(ctx).Debugf("sending new TURN credentials to peer %s", peerID)
m.updateManager.SendUpdate(ctx, peerID, &UpdateMessage{Update: update})
}
}
}()
}

9
relay/auth/allow_all.go Normal file
View File

@@ -0,0 +1,9 @@
package auth
// AllowAllAuth is a Validator that allows all connections.
type AllowAllAuth struct {
}
func (a *AllowAllAuth) Validate(any) error {
return nil
}

34
relay/auth/hmac/store.go Normal file
View File

@@ -0,0 +1,34 @@
package hmac
import (
"sync"
log "github.com/sirupsen/logrus"
)
// TokenStore is a simple in-memory store for token
// With this can update the token in thread safe way
type TokenStore struct {
mu sync.Mutex
token []byte
}
func (a *TokenStore) UpdateToken(token *Token) {
a.mu.Lock()
defer a.mu.Unlock()
if token == nil {
return
}
t, err := marshalToken(*token)
if err != nil {
log.Errorf("failed to marshal token: %s", err)
}
a.token = t
}
func (a *TokenStore) TokenBinary() []byte {
a.mu.Lock()
defer a.mu.Unlock()
return a.token
}

102
relay/auth/hmac/token.go Normal file
View File

@@ -0,0 +1,102 @@
package hmac
import (
"bytes"
"crypto/hmac"
"crypto/sha1"
"encoding/base64"
"encoding/gob"
"fmt"
"strconv"
"time"
log "github.com/sirupsen/logrus"
)
type Token struct {
Payload string
Signature string
}
func marshalToken(token Token) ([]byte, error) {
buffer := bytes.NewBuffer([]byte{})
encoder := gob.NewEncoder(buffer)
err := encoder.Encode(token)
if err != nil {
log.Errorf("failed to marshal token: %s", err)
return nil, err
}
return buffer.Bytes(), nil
}
func unmarshalToken(payload []byte) (Token, error) {
var creds Token
buffer := bytes.NewBuffer(payload)
decoder := gob.NewDecoder(buffer)
err := decoder.Decode(&creds)
return creds, err
}
// TimedHMAC generates token with TTL and using pre-shared secret known to TURN server
type TimedHMAC struct {
secret string
timeToLive time.Duration
}
func NewTimedHMAC(secret string, timeToLive time.Duration) *TimedHMAC {
return &TimedHMAC{
secret: secret,
timeToLive: timeToLive,
}
}
// GenerateToken generates new time-based secret token - basically Payload is a unix timestamp and Signature is a HMAC hash of a timestamp with a preshared TURN secret
func (m *TimedHMAC) GenerateToken() (*Token, error) {
timeAuth := time.Now().Add(m.timeToLive).Unix()
timeStamp := fmt.Sprint(timeAuth)
checksum, err := m.generate(timeStamp)
if err != nil {
return nil, err
}
return &Token{
Payload: timeStamp,
Signature: base64.StdEncoding.EncodeToString(checksum),
}, nil
}
func (m *TimedHMAC) Validate(token Token) error {
expectedMAC, err := m.generate(token.Payload)
if err != nil {
return err
}
expectedSignature := base64.StdEncoding.EncodeToString(expectedMAC)
if !hmac.Equal([]byte(expectedSignature), []byte(token.Signature)) {
return fmt.Errorf("signature mismatch")
}
timeAuthInt, err := strconv.ParseInt(token.Payload, 10, 64)
if err != nil {
return fmt.Errorf("invalid payload: %s", err)
}
if time.Now().Unix() > timeAuthInt {
return fmt.Errorf("expired token")
}
return nil
}
func (m *TimedHMAC) generate(payload string) ([]byte, error) {
mac := hmac.New(sha1.New, []byte(m.secret))
_, err := mac.Write([]byte(payload))
if err != nil {
log.Errorf("failed to generate token: %s", err)
return nil, err
}
return mac.Sum(nil), nil
}

View File

@@ -0,0 +1,103 @@
package hmac
import (
"encoding/base64"
"strconv"
"testing"
"time"
)
func TestGenerateCredentials(t *testing.T) {
secret := "secret"
timeToLive := 1 * time.Hour
v := NewTimedHMAC(secret, timeToLive)
creds, err := v.GenerateToken()
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if creds.Payload == "" {
t.Fatalf("expected non-empty payload")
}
_, err = strconv.ParseInt(creds.Payload, 10, 64)
if err != nil {
t.Fatalf("expected payload to be a valid unix timestamp, got %v", err)
}
_, err = base64.StdEncoding.DecodeString(creds.Signature)
if err != nil {
t.Fatalf("expected signature to be base64 encoded, got %v", err)
}
}
func TestValidateCredentials(t *testing.T) {
secret := "supersecret"
timeToLive := 1 * time.Hour
manager := NewTimedHMAC(secret, timeToLive)
// Test valid token
creds, err := manager.GenerateToken()
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if err := manager.Validate(*creds); err != nil {
t.Fatalf("expected valid token: %s", err)
}
}
func TestInvalidSignature(t *testing.T) {
secret := "supersecret"
timeToLive := 1 * time.Hour
manager := NewTimedHMAC(secret, timeToLive)
creds, err := manager.GenerateToken()
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
invalidCreds := &Token{
Payload: creds.Payload,
Signature: "invalidsignature",
}
if err = manager.Validate(*invalidCreds); err == nil {
t.Fatalf("expected invalid token due to signature mismatch")
}
}
func TestExpired(t *testing.T) {
secret := "supersecret"
v := NewTimedHMAC(secret, -1*time.Hour)
expiredCreds, err := v.GenerateToken()
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if err = v.Validate(*expiredCreds); err == nil {
t.Fatalf("expected invalid token due to expiration")
}
}
func TestInvalidPayload(t *testing.T) {
secret := "supersecret"
timeToLive := 1 * time.Hour
v := NewTimedHMAC(secret, timeToLive)
creds, err := v.GenerateToken()
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
// Test invalid payload
invalidPayloadCreds := &Token{
Payload: "invalidtimestamp",
Signature: creds.Signature,
}
if err = v.Validate(*invalidPayloadCreds); err == nil {
t.Fatalf("expected invalid token due to invalid payload")
}
}

View File

@@ -0,0 +1,27 @@
package hmac
import (
log "github.com/sirupsen/logrus"
"time"
)
type TimedHMACValidator struct {
*TimedHMAC
}
func NewTimedHMACValidator(secret string, duration time.Duration) *TimedHMACValidator {
ta := NewTimedHMAC(secret, duration)
return &TimedHMACValidator{
ta,
}
}
func (a *TimedHMACValidator) Validate(credentials any) error {
b := credentials.([]byte)
c, err := unmarshalToken(b)
if err != nil {
log.Errorf("failed to unmarshal token: %s", err)
return err
}
return a.TimedHMAC.Validate(c)
}

5
relay/auth/validator.go Normal file
View File

@@ -0,0 +1,5 @@
package auth
type Validator interface {
Validate(any) error
}

520
relay/client/client.go Normal file
View File

@@ -0,0 +1,520 @@
package client
import (
"context"
"fmt"
"io"
"net"
"sync"
"time"
log "github.com/sirupsen/logrus"
auth "github.com/netbirdio/netbird/relay/auth/hmac"
"github.com/netbirdio/netbird/relay/client/dialer/ws"
"github.com/netbirdio/netbird/relay/healthcheck"
"github.com/netbirdio/netbird/relay/messages"
)
const (
bufferSize = 8820
serverResponseTimeout = 8 * time.Second
)
var (
ErrConnAlreadyExists = fmt.Errorf("connection already exists")
)
type internalStopFlag struct {
sync.Mutex
stop bool
}
func newInternalStopFlag() *internalStopFlag {
return &internalStopFlag{}
}
func (isf *internalStopFlag) set() {
isf.Lock()
defer isf.Unlock()
isf.stop = true
}
func (isf *internalStopFlag) isSet() bool {
isf.Lock()
defer isf.Unlock()
return isf.stop
}
// Msg carry the payload from the server to the client. With this sturct, the net.Conn can free the buffer.
type Msg struct {
Payload []byte
bufPool *sync.Pool
bufPtr *[]byte
}
func (m *Msg) Free() {
m.bufPool.Put(m.bufPtr)
}
type connContainer struct {
conn *Conn
messages chan Msg
msgChanLock sync.Mutex
closed bool // flag to check if channel is closed
}
func newConnContainer(conn *Conn, messages chan Msg) *connContainer {
return &connContainer{
conn: conn,
messages: messages,
}
}
func (cc *connContainer) writeMsg(msg Msg) {
cc.msgChanLock.Lock()
defer cc.msgChanLock.Unlock()
if cc.closed {
return
}
cc.messages <- msg
}
func (cc *connContainer) close() {
cc.msgChanLock.Lock()
defer cc.msgChanLock.Unlock()
if cc.closed {
return
}
close(cc.messages)
cc.closed = true
}
// Client is a client for the relay server. It is responsible for establishing a connection to the relay server and
// managing connections to other peers. All exported functions are safe to call concurrently. After close the connection,
// the client can be reused by calling Connect again. When the client is closed, all connections are closed too.
// While the Connect is in progress, the OpenConn function will block until the connection is established with relay server.
type Client struct {
log *log.Entry
parentCtx context.Context
connectionURL string
authTokenStore *auth.TokenStore
hashedID []byte
bufPool *sync.Pool
relayConn net.Conn
conns map[string]*connContainer
serviceIsRunning bool
mu sync.Mutex // protect serviceIsRunning and conns
readLoopMutex sync.Mutex
wgReadLoop sync.WaitGroup
instanceURL string
muInstanceURL sync.Mutex
onDisconnectListener func()
listenerMutex sync.Mutex
}
// NewClient creates a new client for the relay server. The client is not connected to the server until the Connect
func NewClient(ctx context.Context, serverURL string, authTokenStore *auth.TokenStore, peerID string) *Client {
hashedID, hashedStringId := messages.HashID(peerID)
return &Client{
log: log.WithField("client_id", hashedStringId),
parentCtx: ctx,
connectionURL: serverURL,
authTokenStore: authTokenStore,
hashedID: hashedID,
bufPool: &sync.Pool{
New: func() any {
buf := make([]byte, bufferSize)
return &buf
},
},
conns: make(map[string]*connContainer),
}
}
// Connect establishes a connection to the relay server. It blocks until the connection is established or an error occurs.
func (c *Client) Connect() error {
c.log.Infof("connecting to relay server: %s", c.connectionURL)
c.readLoopMutex.Lock()
defer c.readLoopMutex.Unlock()
c.mu.Lock()
defer c.mu.Unlock()
if c.serviceIsRunning {
return nil
}
err := c.connect()
if err != nil {
return err
}
c.serviceIsRunning = true
c.wgReadLoop.Add(1)
go c.readLoop(c.relayConn)
log.Infof("relay connection established with: %s", c.connectionURL)
return nil
}
// OpenConn create a new net.Conn for the destination peer ID. In case if the connection is in progress
// to the relay server, the function will block until the connection is established or timed out. Otherwise,
// it will return immediately.
// todo: what should happen if call with the same peerID with multiple times?
func (c *Client) OpenConn(dstPeerID string) (net.Conn, error) {
c.mu.Lock()
defer c.mu.Unlock()
if !c.serviceIsRunning {
return nil, fmt.Errorf("relay connection is not established")
}
hashedID, hashedStringID := messages.HashID(dstPeerID)
_, ok := c.conns[hashedStringID]
if ok {
return nil, ErrConnAlreadyExists
}
log.Infof("open connection to peer: %s", hashedStringID)
msgChannel := make(chan Msg, 2)
conn := NewConn(c, hashedID, hashedStringID, msgChannel)
c.conns[hashedStringID] = newConnContainer(conn, msgChannel)
return conn, nil
}
// ServerInstanceURL returns the address of the relay server. It could change after the close and reopen the connection.
func (c *Client) ServerInstanceURL() (string, error) {
c.muInstanceURL.Lock()
defer c.muInstanceURL.Unlock()
if c.instanceURL == "" {
return "", fmt.Errorf("relay connection is not established")
}
return c.instanceURL, nil
}
// SetOnDisconnectListener sets a function that will be called when the connection to the relay server is closed.
func (c *Client) SetOnDisconnectListener(fn func()) {
c.listenerMutex.Lock()
defer c.listenerMutex.Unlock()
c.onDisconnectListener = fn
}
// HasConns returns true if there are connections.
func (c *Client) HasConns() bool {
c.mu.Lock()
defer c.mu.Unlock()
return len(c.conns) > 0
}
// Close closes the connection to the relay server and all connections to other peers.
func (c *Client) Close() error {
return c.close(true)
}
func (c *Client) connect() error {
conn, err := ws.Dial(c.connectionURL)
if err != nil {
return err
}
c.relayConn = conn
err = c.handShake()
if err != nil {
cErr := conn.Close()
if cErr != nil {
log.Errorf("failed to close connection: %s", cErr)
}
c.relayConn = nil
return err
}
return nil
}
func (c *Client) handShake() error {
tb := c.authTokenStore.TokenBinary()
msg, err := messages.MarshalHelloMsg(c.hashedID, tb)
if err != nil {
log.Errorf("failed to marshal hello message: %s", err)
return err
}
_, err = c.relayConn.Write(msg)
if err != nil {
log.Errorf("failed to send hello message: %s", err)
return err
}
buf := make([]byte, messages.MaxHandshakeSize)
n, err := c.readWithTimeout(buf)
if err != nil {
log.Errorf("failed to read hello response: %s", err)
return err
}
msgType, err := messages.DetermineServerMsgType(buf[:n])
if err != nil {
log.Errorf("failed to determine message type: %s", err)
return err
}
if msgType != messages.MsgTypeHelloResponse {
log.Errorf("unexpected message type: %s", msgType)
return fmt.Errorf("unexpected message type")
}
ia, err := messages.UnmarshalHelloResponse(buf[:n])
if err != nil {
return err
}
c.muInstanceURL.Lock()
c.instanceURL = ia
c.muInstanceURL.Unlock()
return nil
}
func (c *Client) readLoop(relayConn net.Conn) {
internallyStoppedFlag := newInternalStopFlag()
hc := healthcheck.NewReceiver()
go c.listenForStopEvents(hc, relayConn, internallyStoppedFlag)
var (
errExit error
n int
)
for {
bufPtr := c.bufPool.Get().(*[]byte)
buf := *bufPtr
n, errExit = relayConn.Read(buf)
if errExit != nil {
c.mu.Lock()
if c.serviceIsRunning && !internallyStoppedFlag.isSet() {
c.log.Debugf("failed to read message from relay server: %s", errExit)
}
c.mu.Unlock()
break
}
msgType, err := messages.DetermineServerMsgType(buf[:n])
if err != nil {
c.log.Errorf("failed to determine message type: %s", err)
continue
}
if !c.handleMsg(msgType, buf[:n], bufPtr, hc, internallyStoppedFlag) {
break
}
}
hc.Stop()
c.muInstanceURL.Lock()
c.instanceURL = ""
c.muInstanceURL.Unlock()
c.notifyDisconnected()
c.wgReadLoop.Done()
_ = c.close(false)
}
func (c *Client) handleMsg(msgType messages.MsgType, buf []byte, bufPtr *[]byte, hc *healthcheck.Receiver, internallyStoppedFlag *internalStopFlag) (continueLoop bool) {
switch msgType {
case messages.MsgTypeHealthCheck:
c.handleHealthCheck(hc, internallyStoppedFlag)
c.bufPool.Put(bufPtr)
case messages.MsgTypeTransport:
return c.handleTransportMsg(buf, bufPtr, internallyStoppedFlag)
case messages.MsgTypeClose:
log.Debugf("relay connection close by server")
c.bufPool.Put(bufPtr)
return false
}
return true
}
func (c *Client) handleHealthCheck(hc *healthcheck.Receiver, internallyStoppedFlag *internalStopFlag) {
msg := messages.MarshalHealthcheck()
_, wErr := c.relayConn.Write(msg)
if wErr != nil {
if c.serviceIsRunning && !internallyStoppedFlag.isSet() {
c.log.Errorf("failed to send heartbeat: %s", wErr)
}
}
hc.Heartbeat()
}
func (c *Client) handleTransportMsg(buf []byte, bufPtr *[]byte, internallyStoppedFlag *internalStopFlag) bool {
peerID, payload, err := messages.UnmarshalTransportMsg(buf)
if err != nil {
if c.serviceIsRunning && !internallyStoppedFlag.isSet() {
c.log.Errorf("failed to parse transport message: %v", err)
}
c.bufPool.Put(bufPtr)
return true
}
stringID := messages.HashIDToString(peerID)
c.mu.Lock()
if !c.serviceIsRunning {
c.mu.Unlock()
c.bufPool.Put(bufPtr)
return false
}
container, ok := c.conns[stringID]
c.mu.Unlock()
if !ok {
c.log.Errorf("peer not found: %s", stringID)
c.bufPool.Put(bufPtr)
return true
}
msg := Msg{
bufPool: c.bufPool,
bufPtr: bufPtr,
Payload: payload,
}
container.writeMsg(msg)
return true
}
// todo check by reference too, the id is not enought because the id come from the outer conn
func (c *Client) writeTo(id string, dstID []byte, payload []byte) (int, error) {
c.mu.Lock()
// conn, ok := c.conns[id]
_, ok := c.conns[id]
c.mu.Unlock()
if !ok {
return 0, io.EOF
}
/*
if conn != clientRef {
return 0, io.EOF
}
*/
// todo: use buffer pool instead of create new transport msg.
msg, err := messages.MarshalTransportMsg(dstID, payload)
if err != nil {
log.Errorf("failed to marshal transport message: %s", err)
return 0, err
}
n, err := c.relayConn.Write(msg)
if err != nil {
log.Errorf("failed to write transport message: %s", err)
}
return n, err
}
func (c *Client) listenForStopEvents(hc *healthcheck.Receiver, conn net.Conn, internalStopFlag *internalStopFlag) {
for {
select {
case _, ok := <-hc.OnTimeout:
if !ok {
return
}
c.log.Errorf("health check timeout")
internalStopFlag.set()
_ = conn.Close() // ignore the err because the readLoop will handle it
return
case <-c.parentCtx.Done():
err := c.close(true)
if err != nil {
log.Errorf("failed to teardown connection: %s", err)
}
return
}
}
}
func (c *Client) closeAllConns() {
for _, container := range c.conns {
container.close()
}
c.conns = make(map[string]*connContainer)
}
// todo check by reference too, the id is not enought because the id come from the outer conn
func (c *Client) closeConn(id string) error {
c.mu.Lock()
defer c.mu.Unlock()
container, ok := c.conns[id]
if !ok {
return fmt.Errorf("connection already closed")
}
container.close()
delete(c.conns, id)
return nil
}
func (c *Client) close(gracefullyExit bool) error {
c.readLoopMutex.Lock()
defer c.readLoopMutex.Unlock()
c.mu.Lock()
var err error
if !c.serviceIsRunning {
c.mu.Unlock()
return nil
}
c.serviceIsRunning = false
c.closeAllConns()
if gracefullyExit {
c.writeCloseMsg()
err = c.relayConn.Close()
}
c.mu.Unlock()
c.wgReadLoop.Wait()
c.log.Infof("relay connection closed with: %s", c.connectionURL)
return err
}
func (c *Client) notifyDisconnected() {
c.listenerMutex.Lock()
defer c.listenerMutex.Unlock()
if c.onDisconnectListener == nil {
return
}
go c.onDisconnectListener()
}
func (c *Client) writeCloseMsg() {
msg := messages.MarshalCloseMsg()
_, err := c.relayConn.Write(msg)
if err != nil {
c.log.Errorf("failed to send close message: %s", err)
}
}
func (c *Client) readWithTimeout(buf []byte) (int, error) {
ctx, cancel := context.WithTimeout(c.parentCtx, serverResponseTimeout)
defer cancel()
readDone := make(chan struct{})
var (
n int
err error
)
go func() {
n, err = c.relayConn.Read(buf)
close(readDone)
}()
select {
case <-ctx.Done():
return 0, fmt.Errorf("read operation timed out")
case <-readDone:
return n, err
}
}

603
relay/client/client_test.go Normal file
View File

@@ -0,0 +1,603 @@
package client
import (
"context"
"net"
"os"
"testing"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/auth"
"github.com/netbirdio/netbird/relay/auth/hmac"
"github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/relay/server"
)
var (
av = &auth.AllowAllAuth{}
hmacTokenStore = &hmac.TokenStore{}
serverListenAddr = "127.0.0.1:1234"
serverURL = "rel://127.0.0.1:1234"
)
func TestMain(m *testing.M) {
_ = util.InitLog("trace", "console")
code := m.Run()
os.Exit(code)
}
func TestClient(t *testing.T) {
ctx := context.Background()
srv := server.NewServer(serverURL, false, av)
errChan := make(chan error, 1)
go func() {
listenCfg := server.ListenerConfig{Address: serverListenAddr}
err := srv.Listen(listenCfg)
if err != nil {
errChan <- err
}
}()
defer func() {
err := srv.Close()
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}()
// wait for server to start
if err := waitForServerToStart(errChan); err != nil {
t.Fatalf("failed to start server: %s", err)
}
t.Log("alice connecting to server")
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice")
err := clientAlice.Connect()
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
}
defer clientAlice.Close()
t.Log("placeholder connecting to server")
clientPlaceHolder := NewClient(ctx, serverURL, hmacTokenStore, "clientPlaceHolder")
err = clientPlaceHolder.Connect()
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
}
defer clientPlaceHolder.Close()
t.Log("Bob connecting to server")
clientBob := NewClient(ctx, serverURL, hmacTokenStore, "bob")
err = clientBob.Connect()
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
}
defer clientBob.Close()
t.Log("Alice open connection to Bob")
connAliceToBob, err := clientAlice.OpenConn("bob")
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
t.Log("Bob open connection to Alice")
connBobToAlice, err := clientBob.OpenConn("alice")
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
payload := "hello bob, I am alice"
_, err = connAliceToBob.Write([]byte(payload))
if err != nil {
t.Fatalf("failed to write to channel: %s", err)
}
log.Debugf("alice sent message to bob")
buf := make([]byte, 65535)
n, err := connBobToAlice.Read(buf)
if err != nil {
t.Fatalf("failed to read from channel: %s", err)
}
log.Debugf("on new message from alice to bob")
if payload != string(buf[:n]) {
t.Fatalf("expected %s, got %s", payload, string(buf[:n]))
}
}
func TestRegistration(t *testing.T) {
ctx := context.Background()
srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv := server.NewServer(serverURL, false, av)
errChan := make(chan error, 1)
go func() {
err := srv.Listen(srvCfg)
if err != nil {
errChan <- err
}
}()
// wait for server to start
if err := waitForServerToStart(errChan); err != nil {
t.Fatalf("failed to start server: %s", err)
}
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice")
err := clientAlice.Connect()
if err != nil {
_ = srv.Close()
t.Fatalf("failed to connect to server: %s", err)
}
err = clientAlice.Close()
if err != nil {
t.Errorf("failed to close conn: %s", err)
}
err = srv.Close()
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}
func TestRegistrationTimeout(t *testing.T) {
ctx := context.Background()
fakeUDPListener, err := net.ListenUDP("udp", &net.UDPAddr{
Port: 1234,
IP: net.ParseIP("0.0.0.0"),
})
if err != nil {
t.Fatalf("failed to bind UDP server: %s", err)
}
defer func(fakeUDPListener *net.UDPConn) {
_ = fakeUDPListener.Close()
}(fakeUDPListener)
fakeTCPListener, err := net.ListenTCP("tcp", &net.TCPAddr{
Port: 1234,
IP: net.ParseIP("0.0.0.0"),
})
if err != nil {
t.Fatalf("failed to bind TCP server: %s", err)
}
defer func(fakeTCPListener *net.TCPListener) {
_ = fakeTCPListener.Close()
}(fakeTCPListener)
clientAlice := NewClient(ctx, "127.0.0.1:1234", hmacTokenStore, "alice")
err = clientAlice.Connect()
if err == nil {
t.Errorf("failed to connect to server: %s", err)
}
log.Debugf("%s", err)
err = clientAlice.Close()
if err != nil {
t.Errorf("failed to close conn: %s", err)
}
}
func TestEcho(t *testing.T) {
ctx := context.Background()
idAlice := "alice"
idBob := "bob"
srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv := server.NewServer(serverURL, false, av)
errChan := make(chan error, 1)
go func() {
err := srv.Listen(srvCfg)
if err != nil {
errChan <- err
}
}()
defer func() {
err := srv.Close()
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}()
// wait for servers to start
if err := waitForServerToStart(errChan); err != nil {
t.Fatalf("failed to start server: %s", err)
}
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, idAlice)
err := clientAlice.Connect()
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
}
defer func() {
err := clientAlice.Close()
if err != nil {
t.Errorf("failed to close Alice client: %s", err)
}
}()
clientBob := NewClient(ctx, serverURL, hmacTokenStore, idBob)
err = clientBob.Connect()
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
}
defer func() {
err := clientBob.Close()
if err != nil {
t.Errorf("failed to close Bob client: %s", err)
}
}()
connAliceToBob, err := clientAlice.OpenConn(idBob)
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
connBobToAlice, err := clientBob.OpenConn(idAlice)
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
payload := "hello bob, I am alice"
_, err = connAliceToBob.Write([]byte(payload))
if err != nil {
t.Fatalf("failed to write to channel: %s", err)
}
buf := make([]byte, 65535)
n, err := connBobToAlice.Read(buf)
if err != nil {
t.Fatalf("failed to read from channel: %s", err)
}
_, err = connBobToAlice.Write(buf[:n])
if err != nil {
t.Fatalf("failed to write to channel: %s", err)
}
n, err = connAliceToBob.Read(buf)
if err != nil {
t.Fatalf("failed to read from channel: %s", err)
}
if payload != string(buf[:n]) {
t.Fatalf("expected %s, got %s", payload, string(buf[:n]))
}
}
func TestBindToUnavailabePeer(t *testing.T) {
ctx := context.Background()
srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv := server.NewServer(serverURL, false, av)
errChan := make(chan error, 1)
go func() {
err := srv.Listen(srvCfg)
if err != nil {
errChan <- err
}
}()
defer func() {
log.Infof("closing server")
err := srv.Close()
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}()
// wait for servers to start
if err := waitForServerToStart(errChan); err != nil {
t.Fatalf("failed to start server: %s", err)
}
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice")
err := clientAlice.Connect()
if err != nil {
t.Errorf("failed to connect to server: %s", err)
}
_, err = clientAlice.OpenConn("bob")
if err != nil {
t.Errorf("failed to bind channel: %s", err)
}
log.Infof("closing client")
err = clientAlice.Close()
if err != nil {
t.Errorf("failed to close client: %s", err)
}
}
func TestBindReconnect(t *testing.T) {
ctx := context.Background()
srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv := server.NewServer(serverURL, false, av)
errChan := make(chan error, 1)
go func() {
err := srv.Listen(srvCfg)
if err != nil {
errChan <- err
}
}()
defer func() {
log.Infof("closing server")
err := srv.Close()
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}()
// wait for servers to start
if err := waitForServerToStart(errChan); err != nil {
t.Fatalf("failed to start server: %s", err)
}
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice")
err := clientAlice.Connect()
if err != nil {
t.Errorf("failed to connect to server: %s", err)
}
_, err = clientAlice.OpenConn("bob")
if err != nil {
t.Errorf("failed to bind channel: %s", err)
}
clientBob := NewClient(ctx, serverURL, hmacTokenStore, "bob")
err = clientBob.Connect()
if err != nil {
t.Errorf("failed to connect to server: %s", err)
}
chBob, err := clientBob.OpenConn("alice")
if err != nil {
t.Errorf("failed to bind channel: %s", err)
}
log.Infof("closing client Alice")
err = clientAlice.Close()
if err != nil {
t.Errorf("failed to close client: %s", err)
}
clientAlice = NewClient(ctx, serverURL, hmacTokenStore, "alice")
err = clientAlice.Connect()
if err != nil {
t.Errorf("failed to connect to server: %s", err)
}
chAlice, err := clientAlice.OpenConn("bob")
if err != nil {
t.Errorf("failed to bind channel: %s", err)
}
testString := "hello alice, I am bob"
_, err = chBob.Write([]byte(testString))
if err != nil {
t.Errorf("failed to write to channel: %s", err)
}
buf := make([]byte, 65535)
n, err := chAlice.Read(buf)
if err != nil {
t.Errorf("failed to read from channel: %s", err)
}
if testString != string(buf[:n]) {
t.Errorf("expected %s, got %s", testString, string(buf[:n]))
}
log.Infof("closing client")
err = clientAlice.Close()
if err != nil {
t.Errorf("failed to close client: %s", err)
}
}
func TestCloseConn(t *testing.T) {
ctx := context.Background()
srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv := server.NewServer(serverURL, false, av)
errChan := make(chan error, 1)
go func() {
err := srv.Listen(srvCfg)
if err != nil {
errChan <- err
}
}()
defer func() {
log.Infof("closing server")
err := srv.Close()
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}()
// wait for servers to start
if err := waitForServerToStart(errChan); err != nil {
t.Fatalf("failed to start server: %s", err)
}
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice")
err := clientAlice.Connect()
if err != nil {
t.Errorf("failed to connect to server: %s", err)
}
conn, err := clientAlice.OpenConn("bob")
if err != nil {
t.Errorf("failed to bind channel: %s", err)
}
log.Infof("closing connection")
err = conn.Close()
if err != nil {
t.Errorf("failed to close connection: %s", err)
}
_, err = conn.Read(make([]byte, 1))
if err == nil {
t.Errorf("unexpected reading from closed connection")
}
_, err = conn.Write([]byte("hello"))
if err == nil {
t.Errorf("unexpected writing from closed connection")
}
}
func TestCloseRelayConn(t *testing.T) {
ctx := context.Background()
srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv := server.NewServer(serverURL, false, av)
errChan := make(chan error, 1)
go func() {
err := srv.Listen(srvCfg)
if err != nil {
errChan <- err
}
}()
defer func() {
err := srv.Close()
if err != nil {
log.Errorf("failed to close server: %s", err)
}
}()
// wait for servers to start
if err := waitForServerToStart(errChan); err != nil {
t.Fatalf("failed to start server: %s", err)
}
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice")
err := clientAlice.Connect()
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
}
conn, err := clientAlice.OpenConn("bob")
if err != nil {
t.Errorf("failed to bind channel: %s", err)
}
_ = clientAlice.relayConn.Close()
_, err = conn.Read(make([]byte, 1))
if err == nil {
t.Errorf("unexpected reading from closed connection")
}
_, err = clientAlice.OpenConn("bob")
if err == nil {
t.Errorf("unexpected opening connection to closed server")
}
}
func TestCloseByServer(t *testing.T) {
ctx := context.Background()
srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv1 := server.NewServer(serverURL, false, av)
errChan := make(chan error, 1)
go func() {
err := srv1.Listen(srvCfg)
if err != nil {
errChan <- err
}
}()
// wait for servers to start
if err := waitForServerToStart(errChan); err != nil {
t.Fatalf("failed to start server: %s", err)
}
idAlice := "alice"
log.Debugf("connect by alice")
relayClient := NewClient(ctx, serverURL, hmacTokenStore, idAlice)
err := relayClient.Connect()
if err != nil {
log.Fatalf("failed to connect to server: %s", err)
}
disconnected := make(chan struct{})
relayClient.SetOnDisconnectListener(func() {
log.Infof("client disconnected")
close(disconnected)
})
err = srv1.Close()
if err != nil {
t.Fatalf("failed to close server: %s", err)
}
select {
case <-disconnected:
case <-time.After(3 * time.Second):
log.Fatalf("timeout waiting for client to disconnect")
}
_, err = relayClient.OpenConn("bob")
if err == nil {
t.Errorf("unexpected opening connection to closed server")
}
}
func TestCloseByClient(t *testing.T) {
ctx := context.Background()
srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv := server.NewServer(serverURL, false, av)
errChan := make(chan error, 1)
go func() {
err := srv.Listen(srvCfg)
if err != nil {
errChan <- err
}
}()
// wait for servers to start
if err := waitForServerToStart(errChan); err != nil {
t.Fatalf("failed to start server: %s", err)
}
idAlice := "alice"
log.Debugf("connect by alice")
relayClient := NewClient(ctx, serverURL, hmacTokenStore, idAlice)
err := relayClient.Connect()
if err != nil {
log.Fatalf("failed to connect to server: %s", err)
}
err = relayClient.Close()
if err != nil {
t.Errorf("failed to close client: %s", err)
}
_, err = relayClient.OpenConn("bob")
if err == nil {
t.Errorf("unexpected opening connection to closed server")
}
err = srv.Close()
if err != nil {
t.Fatalf("failed to close server: %s", err)
}
}
func waitForServerToStart(errChan chan error) error {
select {
case err := <-errChan:
if err != nil {
return err
}
case <-time.After(300 * time.Millisecond):
return nil
}
return nil
}

67
relay/client/conn.go Normal file
View File

@@ -0,0 +1,67 @@
package client
import (
"io"
"net"
"time"
)
type Conn struct {
client *Client
dstID []byte
dstStringID string
messageChan chan Msg
}
func NewConn(client *Client, dstID []byte, dstStringID string, messageChan chan Msg) *Conn {
c := &Conn{
client: client,
dstID: dstID,
dstStringID: dstStringID,
messageChan: messageChan,
}
return c
}
func (c *Conn) Write(p []byte) (n int, err error) {
return c.client.writeTo(c.dstStringID, c.dstID, p)
}
func (c *Conn) Read(b []byte) (n int, err error) {
msg, ok := <-c.messageChan
if !ok {
return 0, io.EOF
}
n = copy(b, msg.Payload)
msg.Free()
return n, nil
}
func (c *Conn) Close() error {
return c.client.closeConn(c.dstStringID)
}
func (c *Conn) LocalAddr() net.Addr {
return c.client.relayConn.LocalAddr()
}
func (c *Conn) RemoteAddr() net.Addr {
return c.client.relayConn.RemoteAddr()
}
func (c *Conn) SetDeadline(t time.Time) error {
//TODO implement me
panic("SetDeadline is not implemented")
}
func (c *Conn) SetReadDeadline(t time.Time) error {
//TODO implement me
panic("SetReadDeadline is not implemented")
}
func (c *Conn) SetWriteDeadline(t time.Time) error {
//TODO implement me
panic("SetReadDeadline is not implemented")
}

View File

@@ -0,0 +1,52 @@
package quic
import (
"net"
"github.com/quic-go/quic-go"
log "github.com/sirupsen/logrus"
)
type Conn struct {
quic.Stream
qConn quic.Connection
}
func NewConn(stream quic.Stream, qConn quic.Connection) net.Conn {
return &Conn{
Stream: stream,
qConn: qConn,
}
}
func (q *Conn) Write(b []byte) (n int, err error) {
log.Debugf("writing: %d, %x\n", len(b), b)
n, err = q.Stream.Write(b)
if n != len(b) {
log.Errorf("failed to write out the full message")
}
return
}
func (q *Conn) Close() error {
err := q.Stream.Close()
if err != nil {
log.Errorf("failed to close stream: %s", err)
return err
}
err = q.qConn.CloseWithError(0, "")
if err != nil {
log.Errorf("failed to close connection: %s", err)
return err
}
return err
}
func (c *Conn) LocalAddr() net.Addr {
return c.qConn.LocalAddr()
}
func (c *Conn) RemoteAddr() net.Addr {
return c.qConn.RemoteAddr()
}

View File

@@ -0,0 +1,32 @@
package quic
import (
"context"
"crypto/tls"
"net"
"github.com/quic-go/quic-go"
log "github.com/sirupsen/logrus"
)
func Dial(address string) (net.Conn, error) {
tlsConf := &tls.Config{
InsecureSkipVerify: true,
NextProtos: []string{"quic-echo-example"},
}
qConn, err := quic.DialAddr(context.Background(), address, tlsConf, &quic.Config{
EnableDatagrams: true,
})
if err != nil {
log.Errorf("dial quic address %s failed: %s", address, err)
return nil, err
}
stream, err := qConn.OpenStreamSync(context.Background())
if err != nil {
return nil, err
}
conn := NewConn(stream, qConn)
return conn, nil
}

View File

@@ -0,0 +1,7 @@
package tcp
import "net"
func Dial(address string) (net.Conn, error) {
return net.Dial("tcp", address)
}

View File

@@ -0,0 +1,14 @@
package udp
import (
"net"
)
func Dial(address string) (net.Conn, error) {
udpAddr, err := net.ResolveUDPAddr("udp", address)
if err != nil {
return nil, err
}
return net.DialUDP("udp", nil, udpAddr)
}

View File

@@ -0,0 +1,12 @@
package ws
type WebsocketAddr struct {
}
func (a WebsocketAddr) Network() string {
return "websocket"
}
func (a WebsocketAddr) String() string {
return "websocket/unknown-addr"
}

View File

@@ -0,0 +1,64 @@
package ws
import (
"context"
"fmt"
"net"
"time"
"nhooyr.io/websocket"
)
type Conn struct {
ctx context.Context
*websocket.Conn
}
func NewConn(wsConn *websocket.Conn) net.Conn {
return &Conn{
ctx: context.Background(),
Conn: wsConn,
}
}
func (c *Conn) Read(b []byte) (n int, err error) {
t, ioReader, err := c.Conn.Reader(c.ctx)
if err != nil {
return 0, err
}
if t != websocket.MessageBinary {
return 0, fmt.Errorf("unexpected message type")
}
return ioReader.Read(b)
}
func (c *Conn) Write(b []byte) (n int, err error) {
err = c.Conn.Write(c.ctx, websocket.MessageBinary, b)
return len(b), err
}
func (c *Conn) RemoteAddr() net.Addr {
return WebsocketAddr{}
}
func (c *Conn) LocalAddr() net.Addr {
return WebsocketAddr{}
}
func (c *Conn) SetReadDeadline(t time.Time) error {
return fmt.Errorf("SetReadDeadline is not implemented")
}
func (c *Conn) SetWriteDeadline(t time.Time) error {
return fmt.Errorf("SetWriteDeadline is not implemented")
}
func (c *Conn) SetDeadline(t time.Time) error {
return fmt.Errorf("SetDeadline is not implemented")
}
func (c *Conn) Close() error {
return c.Conn.CloseNow()
}

View File

@@ -0,0 +1,59 @@
package ws
import (
"context"
"fmt"
"net"
"net/http"
"strings"
log "github.com/sirupsen/logrus"
"nhooyr.io/websocket"
nbnet "github.com/netbirdio/netbird/util/net"
)
func Dial(address string) (net.Conn, error) {
wsURL, err := prepareURL(address)
if err != nil {
return nil, err
}
opts := &websocket.DialOptions{
HTTPClient: httpClientNbDialer(),
}
wsConn, resp, err := websocket.Dial(context.Background(), wsURL, opts)
if err != nil {
log.Errorf("failed to dial to Relay server '%s': %s", wsURL, err)
return nil, err
}
if resp.Body != nil {
_ = resp.Body.Close()
}
conn := NewConn(wsConn)
return conn, nil
}
func prepareURL(address string) (string, error) {
if !strings.HasPrefix(address, "rel") {
return "", fmt.Errorf("unsupported scheme: %s", address)
}
return strings.Replace(address, "rel", "ws", 1), nil
}
func httpClientNbDialer() *http.Client {
customDialer := nbnet.NewDialer()
customTransport := &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return customDialer.DialContext(ctx, network, addr)
},
}
return &http.Client{
Transport: customTransport,
}
}

44
relay/client/guard.go Normal file
View File

@@ -0,0 +1,44 @@
package client
import (
"context"
"time"
log "github.com/sirupsen/logrus"
)
var (
reconnectingTimeout = 5 * time.Second
)
type Guard struct {
ctx context.Context
relayClient *Client
}
func NewGuard(context context.Context, relayClient *Client) *Guard {
g := &Guard{
ctx: context,
relayClient: relayClient,
}
return g
}
func (g *Guard) OnDisconnected() {
ticker := time.NewTicker(reconnectingTimeout)
defer ticker.Stop()
for {
select {
case <-ticker.C:
err := g.relayClient.Connect()
if err != nil {
log.Errorf("failed to reconnect to relay server: %s", err)
continue
}
return
case <-g.ctx.Done():
return
}
}
}

275
relay/client/manager.go Normal file
View File

@@ -0,0 +1,275 @@
package client
import (
"context"
"fmt"
"net"
"sync"
"time"
log "github.com/sirupsen/logrus"
relayAuth "github.com/netbirdio/netbird/relay/auth/hmac"
)
var (
relayCleanupInterval = 60 * time.Second
errRelayClientNotConnected = fmt.Errorf("relay client not connected")
)
// RelayTrack hold the relay clients for the foreign relay servers.
// With the mutex can ensure we can open new connection in case the relay connection has been established with
// the relay server.
type RelayTrack struct {
sync.RWMutex
relayClient *Client
}
func NewRelayTrack() *RelayTrack {
return &RelayTrack{}
}
// Manager is a manager for the relay client. It establish one persistent connection to the given relay server. In case
// of network error the manager will try to reconnect to the server.
// The manager also manage temproary relay connection. If a client wants to communicate with an another client on a
// different relay server, the manager will establish a new connection to the relay server. The connection with these
// relay servers will be closed if there is no active connection. Periodically the manager will check if there is any
// unused relay connection and close it.
type Manager struct {
ctx context.Context
serverURL string
peerID string
tokenStore *relayAuth.TokenStore
relayClient *Client
reconnectGuard *Guard
relayClients map[string]*RelayTrack
relayClientsMutex sync.RWMutex
onDisconnectedListeners map[string]map[*func()]struct{}
listenerLock sync.Mutex
}
func NewManager(ctx context.Context, serverURL string, peerID string) *Manager {
return &Manager{
ctx: ctx,
serverURL: serverURL,
peerID: peerID,
tokenStore: &relayAuth.TokenStore{},
relayClients: make(map[string]*RelayTrack),
onDisconnectedListeners: make(map[string]map[*func()]struct{}),
}
}
// Serve starts the manager. It will establish a connection to the relay server and start the relay cleanup loop.
func (m *Manager) Serve() error {
if m.relayClient != nil {
return fmt.Errorf("manager already serving")
}
m.relayClient = NewClient(m.ctx, m.serverURL, m.tokenStore, m.peerID)
err := m.relayClient.Connect()
if err != nil {
log.Errorf("failed to connect to relay server: %s", err)
return err
}
m.reconnectGuard = NewGuard(m.ctx, m.relayClient)
m.relayClient.SetOnDisconnectListener(func() {
m.onServerDisconnected(m.serverURL)
})
m.startCleanupLoop()
return nil
}
// OpenConn opens a connection to the given peer key. If the peer is on the same relay server, the connection will be
// established via the relay server. If the peer is on a different relay server, the manager will establish a new
// connection to the relay server.
func (m *Manager) OpenConn(serverAddress, peerKey string, onClosedListener func()) (net.Conn, error) {
if m.relayClient == nil {
return nil, errRelayClientNotConnected
}
foreign, err := m.isForeignServer(serverAddress)
if err != nil {
return nil, err
}
var (
netConn net.Conn
)
if !foreign {
log.Debugf("open peer connection via permanent server: %s", peerKey)
netConn, err = m.relayClient.OpenConn(peerKey)
} else {
log.Debugf("open peer connection via foreign server: %s", serverAddress)
netConn, err = m.openConnVia(serverAddress, peerKey)
}
if err != nil {
return nil, err
}
if onClosedListener != nil {
m.addListener(serverAddress, onClosedListener)
}
return netConn, err
}
// RelayInstanceAddress returns the address of the permanent relay server. It could change if the network connection is lost.
// This address will be sent to the target peer to choose the common relay server for the communication.
func (m *Manager) RelayInstanceAddress() (string, error) {
if m.relayClient == nil {
return "", errRelayClientNotConnected
}
return m.relayClient.ServerInstanceURL()
}
// ServerURL returns the address of the permanent relay server.
func (m *Manager) ServerURL() string {
return m.serverURL
}
func (m *Manager) HasRelayAddress() bool {
return m.serverURL != ""
}
func (m *Manager) UpdateToken(token *relayAuth.Token) {
m.tokenStore.UpdateToken(token)
}
func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) {
// check if already has a connection to the desired relay server
m.relayClientsMutex.RLock()
rt, ok := m.relayClients[serverAddress]
if ok {
rt.RLock()
m.relayClientsMutex.RUnlock()
defer rt.RUnlock()
return rt.relayClient.OpenConn(peerKey)
}
m.relayClientsMutex.RUnlock()
// if not, establish a new connection but check it again (because changed the lock type) before starting the
// connection
m.relayClientsMutex.Lock()
rt, ok = m.relayClients[serverAddress]
if ok {
rt.RLock()
m.relayClientsMutex.Unlock()
defer rt.RUnlock()
return rt.relayClient.OpenConn(peerKey)
}
// create a new relay client and store it in the relayClients map
rt = NewRelayTrack()
rt.Lock()
m.relayClients[serverAddress] = rt
m.relayClientsMutex.Unlock()
relayClient := NewClient(m.ctx, serverAddress, m.tokenStore, m.peerID)
err := relayClient.Connect()
if err != nil {
rt.Unlock()
m.relayClientsMutex.Lock()
delete(m.relayClients, serverAddress)
m.relayClientsMutex.Unlock()
return nil, err
}
// if connection closed then delete the relay client from the list
relayClient.SetOnDisconnectListener(func() {
m.onServerDisconnected(serverAddress)
})
rt.relayClient = relayClient
rt.Unlock()
conn, err := relayClient.OpenConn(peerKey)
if err != nil {
return nil, err
}
return conn, nil
}
func (m *Manager) onServerDisconnected(serverAddress string) {
if serverAddress == m.serverURL {
m.reconnectGuard.OnDisconnected()
}
m.notifyOnDisconnectListeners(serverAddress)
}
func (m *Manager) isForeignServer(address string) (bool, error) {
rAddr, err := m.relayClient.ServerInstanceURL()
if err != nil {
return false, fmt.Errorf("relay client not connected")
}
return rAddr != address, nil
}
func (m *Manager) startCleanupLoop() {
if m.ctx.Err() != nil {
return
}
ticker := time.NewTicker(relayCleanupInterval)
go func() {
defer ticker.Stop()
for {
select {
case <-m.ctx.Done():
return
case <-ticker.C:
m.cleanUpUnusedRelays()
}
}
}()
}
func (m *Manager) cleanUpUnusedRelays() {
m.relayClientsMutex.Lock()
defer m.relayClientsMutex.Unlock()
for addr, rt := range m.relayClients {
rt.Lock()
if rt.relayClient.HasConns() {
rt.Unlock()
continue
}
rt.relayClient.SetOnDisconnectListener(nil)
go func() {
_ = rt.relayClient.Close()
}()
log.Debugf("clean up unused relay server connection: %s", addr)
delete(m.relayClients, addr)
rt.Unlock()
}
}
func (m *Manager) addListener(serverAddress string, onClosedListener func()) {
m.listenerLock.Lock()
l, ok := m.onDisconnectedListeners[serverAddress]
if !ok {
l = make(map[*func()]struct{})
}
l[&onClosedListener] = struct{}{}
m.onDisconnectedListeners[serverAddress] = l
m.listenerLock.Unlock()
}
func (m *Manager) notifyOnDisconnectListeners(serverAddress string) {
m.listenerLock.Lock()
l, ok := m.onDisconnectedListeners[serverAddress]
if !ok {
return
}
for f := range l {
go (*f)()
}
delete(m.onDisconnectedListeners, serverAddress)
m.listenerLock.Unlock()
}

View File

@@ -0,0 +1,337 @@
package client
import (
"context"
"testing"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/server"
)
func TestForeignConn(t *testing.T) {
ctx := context.Background()
srvCfg1 := server.ListenerConfig{
Address: "localhost:1234",
}
srv1 := server.NewServer(srvCfg1.Address, false, av)
errChan := make(chan error, 1)
go func() {
err := srv1.Listen(srvCfg1)
if err != nil {
errChan <- err
}
}()
defer func() {
err := srv1.Close()
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}()
if err := waitForServerToStart(errChan); err != nil {
t.Fatalf("failed to start server: %s", err)
}
srvCfg2 := server.ListenerConfig{
Address: "localhost:2234",
}
srv2 := server.NewServer(srvCfg2.Address, false, av)
errChan2 := make(chan error, 1)
go func() {
err := srv2.Listen(srvCfg2)
if err != nil {
errChan2 <- err
}
}()
defer func() {
err := srv2.Close()
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}()
if err := waitForServerToStart(errChan2); err != nil {
t.Fatalf("failed to start server: %s", err)
}
idAlice := "alice"
log.Debugf("connect by alice")
mCtx, cancel := context.WithCancel(ctx)
defer cancel()
clientAlice := NewManager(mCtx, toURL(srvCfg1), idAlice)
err := clientAlice.Serve()
if err != nil {
t.Fatalf("failed to serve manager: %s", err)
}
idBob := "bob"
log.Debugf("connect by bob")
clientBob := NewManager(mCtx, toURL(srvCfg2), idBob)
err = clientBob.Serve()
if err != nil {
t.Fatalf("failed to serve manager: %s", err)
}
bobsSrvAddr, err := clientBob.RelayInstanceAddress()
if err != nil {
t.Fatalf("failed to get relay address: %s", err)
}
connAliceToBob, err := clientAlice.OpenConn(bobsSrvAddr, idBob, nil)
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
connBobToAlice, err := clientBob.OpenConn(bobsSrvAddr, idAlice, nil)
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
payload := "hello bob, I am alice"
_, err = connAliceToBob.Write([]byte(payload))
if err != nil {
t.Fatalf("failed to write to channel: %s", err)
}
buf := make([]byte, 65535)
n, err := connBobToAlice.Read(buf)
if err != nil {
t.Fatalf("failed to read from channel: %s", err)
}
_, err = connBobToAlice.Write(buf[:n])
if err != nil {
t.Fatalf("failed to write to channel: %s", err)
}
n, err = connAliceToBob.Read(buf)
if err != nil {
t.Fatalf("failed to read from channel: %s", err)
}
if payload != string(buf[:n]) {
t.Fatalf("expected %s, got %s", payload, string(buf[:n]))
}
}
func TestForeginConnClose(t *testing.T) {
ctx := context.Background()
srvCfg1 := server.ListenerConfig{
Address: "localhost:1234",
}
srv1 := server.NewServer(srvCfg1.Address, false, av)
errChan := make(chan error, 1)
go func() {
err := srv1.Listen(srvCfg1)
if err != nil {
errChan <- err
}
}()
defer func() {
err := srv1.Close()
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}()
if err := waitForServerToStart(errChan); err != nil {
t.Fatalf("failed to start server: %s", err)
}
srvCfg2 := server.ListenerConfig{
Address: "localhost:2234",
}
srv2 := server.NewServer(srvCfg2.Address, false, av)
errChan2 := make(chan error, 1)
go func() {
err := srv2.Listen(srvCfg2)
if err != nil {
errChan2 <- err
}
}()
defer func() {
err := srv2.Close()
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}()
if err := waitForServerToStart(errChan2); err != nil {
t.Fatalf("failed to start server: %s", err)
}
idAlice := "alice"
log.Debugf("connect by alice")
mCtx, cancel := context.WithCancel(ctx)
defer cancel()
mgr := NewManager(mCtx, toURL(srvCfg1), idAlice)
err := mgr.Serve()
if err != nil {
t.Fatalf("failed to serve manager: %s", err)
}
conn, err := mgr.OpenConn(toURL(srvCfg2), "anotherpeer", nil)
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
err = conn.Close()
if err != nil {
t.Fatalf("failed to close connection: %s", err)
}
}
func TestForeginAutoClose(t *testing.T) {
ctx := context.Background()
relayCleanupInterval = 1 * time.Second
srvCfg1 := server.ListenerConfig{
Address: "localhost:1234",
}
srv1 := server.NewServer(srvCfg1.Address, false, av)
errChan := make(chan error, 1)
go func() {
t.Log("binding server 1.")
err := srv1.Listen(srvCfg1)
if err != nil {
errChan <- err
}
}()
defer func() {
t.Logf("closing server 1.")
err := srv1.Close()
if err != nil {
t.Errorf("failed to close server: %s", err)
}
t.Logf("server 1. closed")
}()
if err := waitForServerToStart(errChan); err != nil {
t.Fatalf("failed to start server: %s", err)
}
srvCfg2 := server.ListenerConfig{
Address: "localhost:2234",
}
srv2 := server.NewServer(srvCfg2.Address, false, av)
errChan2 := make(chan error, 1)
go func() {
t.Log("binding server 2.")
err := srv2.Listen(srvCfg2)
if err != nil {
errChan2 <- err
}
}()
defer func() {
t.Logf("closing server 2.")
err := srv2.Close()
if err != nil {
t.Errorf("failed to close server: %s", err)
}
t.Logf("server 2 closed.")
}()
if err := waitForServerToStart(errChan2); err != nil {
t.Fatalf("failed to start server: %s", err)
}
idAlice := "alice"
t.Log("connect to server 1.")
mCtx, cancel := context.WithCancel(ctx)
defer cancel()
mgr := NewManager(mCtx, toURL(srvCfg1), idAlice)
err := mgr.Serve()
if err != nil {
t.Fatalf("failed to serve manager: %s", err)
}
t.Log("open connection to another peer")
conn, err := mgr.OpenConn(toURL(srvCfg2), "anotherpeer", nil)
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
t.Log("close conn")
err = conn.Close()
if err != nil {
t.Fatalf("failed to close connection: %s", err)
}
t.Logf("waiting for relay cleanup: %s", relayCleanupInterval+1*time.Second)
time.Sleep(relayCleanupInterval + 1*time.Second)
if len(mgr.relayClients) != 0 {
t.Errorf("expected 0, got %d", len(mgr.relayClients))
}
t.Logf("closing manager")
}
func TestAutoReconnect(t *testing.T) {
ctx := context.Background()
reconnectingTimeout = 2 * time.Second
srvCfg := server.ListenerConfig{
Address: "localhost:1234",
}
srv := server.NewServer(srvCfg.Address, false, av)
errChan := make(chan error, 1)
go func() {
err := srv.Listen(srvCfg)
if err != nil {
errChan <- err
}
}()
defer func() {
err := srv.Close()
if err != nil {
log.Errorf("failed to close server: %s", err)
}
}()
if err := waitForServerToStart(errChan); err != nil {
t.Fatalf("failed to start server: %s", err)
}
mCtx, cancel := context.WithCancel(ctx)
defer cancel()
clientAlice := NewManager(mCtx, toURL(srvCfg), "alice")
err := clientAlice.Serve()
if err != nil {
t.Fatalf("failed to serve manager: %s", err)
}
ra, err := clientAlice.RelayInstanceAddress()
if err != nil {
t.Errorf("failed to get relay address: %s", err)
}
conn, err := clientAlice.OpenConn(ra, "bob", nil)
if err != nil {
t.Errorf("failed to bind channel: %s", err)
}
t.Log("closing client relay connection")
// todo figure out moc server
_ = clientAlice.relayClient.relayConn.Close()
t.Log("start test reading")
_, err = conn.Read(make([]byte, 1))
if err == nil {
t.Errorf("unexpected reading from closed connection")
}
log.Infof("waiting for reconnection")
time.Sleep(reconnectingTimeout + 1*time.Second)
log.Infof("reopent the connection")
_, err = clientAlice.OpenConn(ra, "bob", nil)
if err != nil {
t.Errorf("failed to open channel: %s", err)
}
}
func toURL(address server.ListenerConfig) string {
return "rel://" + address.Address
}

128
relay/cmd/main.go Normal file
View File

@@ -0,0 +1,128 @@
package main
import (
"crypto/tls"
"fmt"
"os"
"os/signal"
"syscall"
"time"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/encryption"
auth "github.com/netbirdio/netbird/relay/auth/hmac"
"github.com/netbirdio/netbird/relay/server"
"github.com/netbirdio/netbird/util"
)
var (
listenAddress string
// in HA every peer connect to a common domain, the instance domain has been distributed during the p2p connection
// it is a domain:port or ip:port
exposedAddress string
letsencryptDataDir string
letsencryptDomains []string
tlsCertFile string
tlsKeyFile string
authSecret string
rootCmd = &cobra.Command{
Use: "relay",
Short: "Relay service",
Long: "Relay service for Netbird agents",
Run: execute,
}
)
func init() {
_ = util.InitLog("trace", "console")
rootCmd.PersistentFlags().StringVarP(&listenAddress, "listen-address", "l", ":443", "listen address")
rootCmd.PersistentFlags().StringVarP(&exposedAddress, "exposed-address", "e", "", "instance domain address (or ip) and port, it will be distributes between peers")
rootCmd.PersistentFlags().StringVarP(&letsencryptDataDir, "letsencrypt-data-dir", "d", "", "a directory to store Let's Encrypt data. Required if Let's Encrypt is enabled.")
rootCmd.PersistentFlags().StringArrayVarP(&letsencryptDomains, "letsencrypt-domains", "a", nil, "list of domains to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS")
rootCmd.PersistentFlags().StringVarP(&tlsCertFile, "tls-cert-file", "c", "", "")
rootCmd.PersistentFlags().StringVarP(&tlsKeyFile, "tls-key-file", "k", "", "")
rootCmd.PersistentFlags().StringVarP(&authSecret, "auth-secret", "s", "", "log level")
}
func waitForExitSignal() {
osSigs := make(chan os.Signal, 1)
signal.Notify(osSigs, syscall.SIGINT, syscall.SIGTERM)
<-osSigs
}
func execute(cmd *cobra.Command, args []string) {
if exposedAddress == "" {
log.Errorf("exposed address is required")
os.Exit(1)
}
if authSecret == "" {
log.Errorf("auth secret is required")
os.Exit(1)
}
srvListenerCfg := server.ListenerConfig{
Address: listenAddress,
}
if hasLetsEncrypt() {
tlsCfg, err := setupTLSCertManager()
if err != nil {
log.Errorf("%s", err)
os.Exit(1)
}
srvListenerCfg.TLSConfig = tlsCfg
} else if hasCertConfig() {
tlsCfg, err := encryption.LoadTLSConfig(tlsCertFile, tlsKeyFile)
if err != nil {
log.Errorf("%s", err)
os.Exit(1)
}
srvListenerCfg.TLSConfig = tlsCfg
}
tlsSupport := srvListenerCfg.TLSConfig != nil
authenticator := auth.NewTimedHMACValidator(authSecret, 24*time.Hour)
srv := server.NewServer(exposedAddress, tlsSupport, authenticator)
log.Infof("server will be available on: %s", srv.InstanceURL())
err := srv.Listen(srvListenerCfg)
if err != nil {
log.Errorf("failed to bind server: %s", err)
os.Exit(1)
}
waitForExitSignal()
err = srv.Close()
if err != nil {
log.Errorf("failed to close server: %s", err)
os.Exit(1)
}
}
func hasCertConfig() bool {
return tlsCertFile != "" && tlsKeyFile != ""
}
func hasLetsEncrypt() bool {
return letsencryptDataDir != "" && letsencryptDomains != nil && len(letsencryptDomains) > 0
}
func setupTLSCertManager() (*tls.Config, error) {
certManager, err := encryption.CreateCertManager(letsencryptDataDir, letsencryptDomains...)
if err != nil {
return nil, fmt.Errorf("failed creating LetsEncrypt cert manager: %v", err)
}
return certManager.TLSConfig(), nil
}
func main() {
err := rootCmd.Execute()
if err != nil {
os.Exit(1)
}
}

View File

@@ -0,0 +1,83 @@
package healthcheck
import (
"context"
"time"
)
var (
heartbeatTimeout = healthCheckInterval + 3*time.Second
)
// Receiver is a healthcheck receiver
// It will listen for heartbeat and check if the heartbeat is not received in a certain time
// If the heartbeat is not received in a certain time, it will send a timeout signal and stop to work
// It will also stop if the context is canceled
// The heartbeat timeout is a bit longer than the sender's healthcheck interval
type Receiver struct {
OnTimeout chan struct{}
ctx context.Context
ctxCancel context.CancelFunc
heartbeat chan struct{}
alive bool
}
// NewReceiver creates a new healthcheck receiver and start the timer in the background
func NewReceiver() *Receiver {
ctx, ctxCancel := context.WithCancel(context.Background())
r := &Receiver{
OnTimeout: make(chan struct{}, 1),
ctx: ctx,
ctxCancel: ctxCancel,
heartbeat: make(chan struct{}, 1),
}
go r.waitForHealthcheck()
return r
}
// Heartbeat acknowledge the heartbeat has been received
func (r *Receiver) Heartbeat() {
select {
case r.heartbeat <- struct{}{}:
default:
}
}
// Stop check the timeout and do not send new notifications
func (r *Receiver) Stop() {
r.ctxCancel()
}
func (r *Receiver) waitForHealthcheck() {
ticker := time.NewTicker(heartbeatTimeout)
defer ticker.Stop()
defer r.ctxCancel()
defer close(r.OnTimeout)
for {
select {
case <-r.heartbeat:
r.alive = true
case <-ticker.C:
if r.alive {
r.alive = false
continue
}
r.notifyTimeout()
return
case <-r.ctx.Done():
return
}
}
}
func (r *Receiver) notifyTimeout() {
select {
case r.OnTimeout <- struct{}{}:
default:
}
}

View File

@@ -0,0 +1,42 @@
package healthcheck
import (
"testing"
"time"
)
func TestNewReceiver(t *testing.T) {
heartbeatTimeout = 5 * time.Second
r := NewReceiver()
select {
case <-r.OnTimeout:
t.Error("unexpected timeout")
case <-time.After(1 * time.Second):
}
}
func TestNewReceiverNotReceive(t *testing.T) {
heartbeatTimeout = 1 * time.Second
r := NewReceiver()
select {
case <-r.OnTimeout:
case <-time.After(2 * time.Second):
t.Error("timeout not received")
}
}
func TestNewReceiverAck(t *testing.T) {
heartbeatTimeout = 2 * time.Second
r := NewReceiver()
r.Heartbeat()
select {
case <-r.OnTimeout:
t.Error("unexpected timeout")
case <-time.After(3 * time.Second):
}
}

View File

@@ -0,0 +1,68 @@
package healthcheck
import (
"context"
"time"
)
var (
healthCheckInterval = 25 * time.Second
healthCheckTimeout = 5 * time.Second
)
// Sender is a healthcheck sender
// It will send healthcheck signal to the receiver
// If the receiver does not receive the signal in a certain time, it will send a timeout signal and stop to work
// It will also stop if the context is canceled
type Sender struct {
HealthCheck chan struct{}
Timeout chan struct{}
ctx context.Context
ack chan struct{}
}
// NewSender creates a new healthcheck sender
func NewSender(ctx context.Context) *Sender {
hc := &Sender{
HealthCheck: make(chan struct{}, 1),
Timeout: make(chan struct{}, 1),
ctx: ctx,
ack: make(chan struct{}, 1),
}
go hc.healthCheck()
return hc
}
func (hc *Sender) OnHCResponse() {
select {
case hc.ack <- struct{}{}:
default:
}
}
func (hc *Sender) healthCheck() {
ticker := time.NewTicker(healthCheckInterval)
defer ticker.Stop()
timeoutTimer := time.NewTimer(healthCheckInterval + healthCheckTimeout)
defer timeoutTimer.Stop()
defer close(hc.HealthCheck)
defer close(hc.Timeout)
for {
select {
case <-ticker.C:
hc.HealthCheck <- struct{}{}
case <-timeoutTimer.C:
hc.Timeout <- struct{}{}
return
case <-hc.ack:
timeoutTimer.Stop()
case <-hc.ctx.Done():
return
}
}
}

View File

@@ -0,0 +1,64 @@
package healthcheck
import (
"context"
"testing"
"time"
)
func TestMain(m *testing.M) {
// override the health check interval to speed up the test
healthCheckInterval = 1 * time.Second
healthCheckTimeout = 100 * time.Millisecond
m.Run()
}
func TestNewHealthPeriod(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
hc := NewSender(ctx)
iterations := 0
for i := 0; i < 3; i++ {
select {
case <-hc.HealthCheck:
iterations++
hc.OnHCResponse()
case <-hc.Timeout:
t.Fatalf("health check is timed out")
case <-time.After(healthCheckInterval + 100*time.Millisecond):
t.Fatalf("health check not received")
}
}
}
func TestNewHealthFailed(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
hc := NewSender(ctx)
select {
case <-hc.Timeout:
case <-time.After(healthCheckInterval + healthCheckTimeout + 100*time.Millisecond):
t.Fatalf("health check is not timed out")
}
}
func TestNewHealthcheckStop(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
hc := NewSender(ctx)
time.Sleep(300 * time.Millisecond)
cancel()
select {
case <-hc.HealthCheck:
t.Fatalf("health check on received")
case <-hc.Timeout:
t.Fatalf("health check timedout")
case <-ctx.Done():
// expected
case <-time.After(1 * time.Second):
t.Fatalf("is not exited")
}
}

29
relay/messages/id.go Normal file
View File

@@ -0,0 +1,29 @@
package messages
import (
"crypto/sha256"
"encoding/base64"
"fmt"
)
const (
prefixLength = 4
IDSize = prefixLength + sha256.Size
)
var (
prefix = []byte("sha-") // 4 bytes
)
func HashID(peerID string) ([]byte, string) {
idHash := sha256.Sum256([]byte(peerID))
idHashString := string(prefix) + base64.StdEncoding.EncodeToString(idHash[:])
var prefixedHash []byte
prefixedHash = append(prefixedHash, prefix...)
prefixedHash = append(prefixedHash, idHash[:]...)
return prefixedHash, idHashString
}
func HashIDToString(idHash []byte) string {
return fmt.Sprintf("%s%s", idHash[:prefixLength], base64.StdEncoding.EncodeToString(idHash[prefixLength:]))
}

13
relay/messages/id_test.go Normal file
View File

@@ -0,0 +1,13 @@
package messages
import (
"testing"
)
func TestHashID(t *testing.T) {
hashedID, hashedStringId := HashID("alice")
enc := HashIDToString(hashedID)
if enc != hashedStringId {
t.Errorf("expected %s, got %s", hashedStringId, enc)
}
}

198
relay/messages/message.go Normal file
View File

@@ -0,0 +1,198 @@
package messages
import (
"bytes"
"encoding/gob"
"fmt"
log "github.com/sirupsen/logrus"
)
const (
MsgTypeHello MsgType = 0
MsgTypeHelloResponse MsgType = 1
MsgTypeTransport MsgType = 2
MsgTypeClose MsgType = 3
MsgTypeHealthCheck MsgType = 4
sizeOfMsgType = 1
sizeOfMagicBye = 4
headerSizeTransport = sizeOfMsgType + IDSize // 1 byte for msg type, IDSize for peerID
headerSizeHello = sizeOfMsgType + sizeOfMagicBye + IDSize // 1 byte for msg type, 4 byte for magic header, IDSize for peerID
MaxHandshakeSize = 8192
)
var (
ErrInvalidMessageLength = fmt.Errorf("invalid message length")
magicHeader = []byte{0x21, 0x12, 0xA4, 0x42}
healthCheckMsg = []byte{byte(MsgTypeHealthCheck)}
)
type MsgType byte
func (m MsgType) String() string {
switch m {
case MsgTypeHello:
return "hello"
case MsgTypeHelloResponse:
return "hello response"
case MsgTypeTransport:
return "transport"
case MsgTypeClose:
return "close"
default:
return "unknown"
}
}
type HelloResponse struct {
InstanceAddress string
}
func DetermineClientMsgType(msg []byte) (MsgType, error) {
msgType := MsgType(msg[0])
switch msgType {
case MsgTypeHello:
return msgType, nil
case MsgTypeTransport:
return msgType, nil
case MsgTypeClose:
return msgType, nil
case MsgTypeHealthCheck:
return msgType, nil
default:
return 0, fmt.Errorf("invalid msg type, len: %d", len(msg))
}
}
func DetermineServerMsgType(msg []byte) (MsgType, error) {
msgType := MsgType(msg[0])
switch msgType {
case MsgTypeHelloResponse:
return msgType, nil
case MsgTypeTransport:
return msgType, nil
case MsgTypeClose:
return msgType, nil
case MsgTypeHealthCheck:
return msgType, nil
default:
return 0, fmt.Errorf("invalid msg type (len: %d)", len(msg))
}
}
// MarshalHelloMsg initial hello message
func MarshalHelloMsg(peerID []byte, additions []byte) ([]byte, error) {
if len(peerID) != IDSize {
return nil, fmt.Errorf("invalid peerID length: %d", len(peerID))
}
// 5 = 1 byte for msg type, 4 byte for magic header
msg := make([]byte, 5, headerSizeHello+len(additions))
msg[0] = byte(MsgTypeHello)
copy(msg[1:5], magicHeader)
msg = append(msg, peerID...)
msg = append(msg, additions...)
return msg, nil
}
func UnmarshalHelloMsg(msg []byte) ([]byte, []byte, error) {
if len(msg) < headerSizeHello {
return nil, nil, fmt.Errorf("invalid 'hello' messge")
}
if !bytes.Equal(msg[1:5], magicHeader) {
return nil, nil, fmt.Errorf("invalid magic header")
}
return msg[5 : 5+IDSize], msg[headerSizeHello:], nil
}
func MarshalHelloResponse(DomainAddress string) ([]byte, error) {
payload := HelloResponse{
InstanceAddress: DomainAddress,
}
buf := new(bytes.Buffer)
enc := gob.NewEncoder(buf)
err := enc.Encode(payload)
if err != nil {
log.Errorf("failed to gob encode hello response: %s", err)
return nil, err
}
msg := make([]byte, 1, 1+buf.Len())
msg[0] = byte(MsgTypeHelloResponse)
msg = append(msg, buf.Bytes()...)
return msg, nil
}
func UnmarshalHelloResponse(msg []byte) (string, error) {
if len(msg) < 2 {
return "", fmt.Errorf("invalid 'hello response' message")
}
payload := HelloResponse{}
buf := bytes.NewBuffer(msg[1:])
dec := gob.NewDecoder(buf)
err := dec.Decode(&payload)
if err != nil {
log.Errorf("failed to gob decode hello response: %s", err)
return "", err
}
return payload.InstanceAddress, nil
}
// Close message
func MarshalCloseMsg() []byte {
msg := make([]byte, 1)
msg[0] = byte(MsgTypeClose)
return healthCheckMsg
}
// Transport message
func MarshalTransportMsg(peerID []byte, payload []byte) ([]byte, error) {
if len(peerID) != IDSize {
return nil, fmt.Errorf("invalid peerID length: %d", len(peerID))
}
msg := make([]byte, headerSizeTransport, headerSizeTransport+len(payload))
msg[0] = byte(MsgTypeTransport)
copy(msg[1:], peerID)
msg = append(msg, payload...)
return msg, nil
}
func UnmarshalTransportMsg(buf []byte) ([]byte, []byte, error) {
if len(buf) < headerSizeTransport {
return nil, nil, ErrInvalidMessageLength
}
return buf[1:headerSizeTransport], buf[headerSizeTransport:], nil
}
func UnmarshalTransportID(buf []byte) ([]byte, error) {
if len(buf) < headerSizeTransport {
log.Debugf("invalid message length: %d, expected: %d, %x", len(buf), headerSizeTransport, buf)
return nil, ErrInvalidMessageLength
}
return buf[1:headerSizeTransport], nil
}
func UpdateTransportMsg(msg []byte, peerID []byte) error {
if len(msg) < 1+len(peerID) {
return ErrInvalidMessageLength
}
copy(msg[1:], peerID)
return nil
}
// health check message
func MarshalHealthcheck() []byte {
return healthCheckMsg
}

View File

@@ -0,0 +1,43 @@
package messages
import (
"testing"
)
func TestMarshalHelloMsg(t *testing.T) {
peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
bHello, err := MarshalHelloMsg(peerID, nil)
if err != nil {
t.Fatalf("error: %v", err)
}
receivedPeerID, _, err := UnmarshalHelloMsg(bHello)
if err != nil {
t.Fatalf("error: %v", err)
}
if string(receivedPeerID) != string(peerID) {
t.Errorf("expected %s, got %s", peerID, receivedPeerID)
}
}
func TestMarshalTransportMsg(t *testing.T) {
peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
payload := []byte("payload")
msg, err := MarshalTransportMsg(peerID, payload)
if err != nil {
t.Fatalf("error: %v", err)
}
id, respPayload, err := UnmarshalTransportMsg(msg)
if err != nil {
t.Fatalf("error: %v", err)
}
if string(id) != string(peerID) {
t.Errorf("expected %s, got %s", peerID, id)
}
if string(respPayload) != string(payload) {
t.Errorf("expected %s, got %s", payload, respPayload)
}
}

View File

@@ -0,0 +1,8 @@
package listener
import "net"
type Listener interface {
Listen(func(conn net.Conn)) error
Close() error
}

View File

@@ -0,0 +1,36 @@
package quic
import (
"net"
"github.com/quic-go/quic-go"
log "github.com/sirupsen/logrus"
)
type Conn struct {
quic.Stream
qConn quic.Connection
}
func NewConn(stream quic.Stream, qConn quic.Connection) net.Conn {
return &Conn{
Stream: stream,
qConn: qConn,
}
}
func (q Conn) Write(b []byte) (n int, err error) {
n, err = q.Stream.Write(b)
if n != len(b) {
log.Errorf("failed to write out the full message")
}
return
}
func (q Conn) LocalAddr() net.Addr {
return q.qConn.LocalAddr()
}
func (q Conn) RemoteAddr() net.Addr {
return q.qConn.RemoteAddr()
}

View File

@@ -0,0 +1,110 @@
package quic
import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"math/big"
"net"
"sync"
"github.com/quic-go/quic-go"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/server/listener"
)
type Listener struct {
address string
listener *quic.Listener
quit chan struct{}
wg sync.WaitGroup
}
func NewListener(address string) listener.Listener {
return &Listener{
address: address,
}
}
func (l *Listener) Listen(onAcceptFn func(conn net.Conn)) error {
ql, err := quic.ListenAddr(l.address, generateTLSConfig(), &quic.Config{
EnableDatagrams: true,
})
if err != nil {
return err
}
l.listener = ql
l.quit = make(chan struct{})
log.Infof("quic server is listening on address: %s", l.address)
l.wg.Add(1)
go l.acceptLoop(onAcceptFn)
<-l.quit
return nil
}
func (l *Listener) Close() error {
close(l.quit)
err := l.listener.Close()
l.wg.Wait()
return err
}
func (l *Listener) acceptLoop(acceptFn func(conn net.Conn)) {
defer l.wg.Done()
for {
qConn, err := l.listener.Accept(context.Background())
if err != nil {
select {
case <-l.quit:
return
default:
log.Errorf("failed to accept connection: %s", err)
continue
}
}
log.Infof("new connection from: %s", qConn.RemoteAddr())
stream, err := qConn.AcceptStream(context.Background())
if err != nil {
log.Errorf("failed to open stream: %s", err)
continue
}
conn := NewConn(stream, qConn)
go acceptFn(conn)
}
}
// Setup a bare-bones TLS config for the server
func generateTLSConfig() *tls.Config {
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
panic(err)
}
template := x509.Certificate{SerialNumber: big.NewInt(1)}
certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key)
if err != nil {
panic(err)
}
keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)})
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
tlsCert, err := tls.X509KeyPair(certPEM, keyPEM)
if err != nil {
panic(err)
}
return &tls.Config{
Certificates: []tls.Certificate{tlsCert},
NextProtos: []string{"quic-echo-example"},
}
}

View File

@@ -0,0 +1,80 @@
package tcp
import (
"net"
"sync"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/server/listener"
)
// Listener
// Is it just demo code. It does not work in real life environment because the TCP is a streaming protocol, and
// it does not handle framing.
type Listener struct {
address string
onAcceptFn func(conn net.Conn)
wg sync.WaitGroup
quit chan struct{}
listener net.Listener
lock sync.Mutex
}
func NewListener(address string) listener.Listener {
return &Listener{
address: address,
}
}
func (l *Listener) Listen(onAcceptFn func(conn net.Conn)) error {
l.lock.Lock()
l.onAcceptFn = onAcceptFn
l.quit = make(chan struct{})
li, err := net.Listen("tcp", l.address)
if err != nil {
log.Errorf("failed to listen on address: %s, %s", l.address, err)
l.lock.Unlock()
return err
}
log.Debugf("TCP server is listening on address: %s", l.address)
l.listener = li
l.wg.Add(1)
go l.acceptLoop()
l.lock.Unlock()
<-l.quit
return nil
}
// Close todo: prevent multiple call (do not close two times the channel)
func (l *Listener) Close() error {
l.lock.Lock()
defer l.lock.Unlock()
close(l.quit)
err := l.listener.Close()
l.wg.Wait()
return err
}
func (l *Listener) acceptLoop() {
defer l.wg.Done()
for {
conn, err := l.listener.Accept()
if err != nil {
select {
case <-l.quit:
return
default:
log.Errorf("failed to accept connection: %s", err)
continue
}
}
go l.onAcceptFn(conn)
}
}

View File

@@ -0,0 +1,68 @@
package udp
import (
"io"
"net"
"time"
)
type Conn struct {
*net.UDPConn
addr *net.UDPAddr
msgChannel chan []byte
}
func NewConn(conn *net.UDPConn, addr *net.UDPAddr) *Conn {
return &Conn{
UDPConn: conn,
addr: addr,
msgChannel: make(chan []byte),
}
}
func (u *Conn) Read(b []byte) (n int, err error) {
msg, ok := <-u.msgChannel
if !ok {
return 0, io.EOF
}
n = copy(b, msg)
return n, nil
}
func (u *Conn) Write(b []byte) (n int, err error) {
return u.UDPConn.WriteTo(b, u.addr)
}
func (u *Conn) Close() error {
//TODO implement me
//panic("implement me")
return nil
}
func (u *Conn) LocalAddr() net.Addr {
return u.UDPConn.LocalAddr()
}
func (u *Conn) RemoteAddr() net.Addr {
return u.addr
}
func (u *Conn) SetDeadline(t time.Time) error {
//TODO implement me
panic("implement me")
}
func (u *Conn) SetReadDeadline(t time.Time) error {
//TODO implement me
panic("implement me")
}
func (u *Conn) SetWriteDeadline(t time.Time) error {
//TODO implement me
panic("implement me")
}
func (u *Conn) onNewMsg(b []byte) {
u.msgChannel <- b
}

View File

@@ -0,0 +1,104 @@
package udp
import (
"net"
"sync"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/server/listener"
)
type Listener struct {
address string
conns map[string]*Conn
onAcceptFn func(conn net.Conn)
listener *net.UDPConn
wg sync.WaitGroup
quit chan struct{}
lock sync.Mutex
}
func NewListener(address string) listener.Listener {
return &Listener{
address: address,
conns: make(map[string]*Conn),
}
}
func (l *Listener) Listen(onAcceptFn func(conn net.Conn)) error {
l.lock.Lock()
l.onAcceptFn = onAcceptFn
l.quit = make(chan struct{})
addr, err := net.ResolveUDPAddr("udp", l.address)
if err != nil {
log.Errorf("invalid listen address '%s': %s", l.address, err)
l.lock.Unlock()
return err
}
li, err := net.ListenUDP("udp", addr)
if err != nil {
log.Fatalf("%s", err)
l.lock.Unlock()
return err
}
log.Debugf("udp server is listening on address: %s", addr.String())
l.listener = li
l.wg.Add(1)
go l.readLoop()
l.lock.Unlock()
<-l.quit
return nil
}
func (l *Listener) Close() error {
l.lock.Lock()
defer l.lock.Unlock()
if l.listener == nil {
return nil
}
log.Infof("closing UDP listener")
close(l.quit)
err := l.listener.Close()
l.wg.Wait()
l.listener = nil
return err
}
func (l *Listener) readLoop() {
defer l.wg.Done()
for {
buf := make([]byte, 1500)
n, addr, err := l.listener.ReadFromUDP(buf)
if err != nil {
select {
case <-l.quit:
return
default:
log.Errorf("failed to accept connection: %s", err)
continue
}
}
pConn, ok := l.conns[addr.String()]
if ok {
pConn.onNewMsg(buf[:n])
continue
}
pConn = NewConn(l.listener, addr)
log.Infof("new connection from: %s", pConn.RemoteAddr())
l.conns[addr.String()] = pConn
go l.onAcceptFn(pConn)
pConn.onNewMsg(buf[:n])
}
}

View File

@@ -0,0 +1,114 @@
package ws
import (
"context"
"errors"
"fmt"
"io"
"net"
"sync"
"time"
log "github.com/sirupsen/logrus"
"nhooyr.io/websocket"
)
const (
writeTimeout = 10 * time.Second
)
type Conn struct {
*websocket.Conn
lAddr *net.TCPAddr
rAddr *net.TCPAddr
closed bool
closedMu sync.Mutex
ctx context.Context
}
func NewConn(wsConn *websocket.Conn, lAddr, rAddr *net.TCPAddr) *Conn {
return &Conn{
Conn: wsConn,
lAddr: lAddr,
rAddr: rAddr,
ctx: context.Background(),
}
}
func (c *Conn) Read(b []byte) (n int, err error) {
t, r, err := c.Reader(c.ctx)
if err != nil {
return 0, c.ioErrHandling(err)
}
if t != websocket.MessageBinary {
log.Errorf("unexpected message type: %d", t)
return 0, fmt.Errorf("unexpected message type")
}
n, err = r.Read(b)
if err != nil {
return 0, c.ioErrHandling(err)
}
return n, err
}
// Write writes a binary message with the given payload.
// It does not block until fill the internal buffer.
// If the buffer filled up, wait until the buffer is drained or timeout.
func (c *Conn) Write(b []byte) (int, error) {
ctx, ctxCancel := context.WithTimeout(c.ctx, writeTimeout)
defer ctxCancel()
err := c.Conn.Write(ctx, websocket.MessageBinary, b)
return len(b), err
}
func (c *Conn) LocalAddr() net.Addr {
return c.lAddr
}
func (c *Conn) RemoteAddr() net.Addr {
return c.rAddr
}
func (c *Conn) SetReadDeadline(t time.Time) error {
return fmt.Errorf("SetReadDeadline is not implemented")
}
func (c *Conn) SetWriteDeadline(t time.Time) error {
return fmt.Errorf("SetWriteDeadline is not implemented")
}
func (c *Conn) SetDeadline(t time.Time) error {
return fmt.Errorf("SetDeadline is not implemented")
}
func (c *Conn) Close() error {
c.closedMu.Lock()
c.closed = true
c.closedMu.Unlock()
return c.Conn.CloseNow()
}
func (c *Conn) isClosed() bool {
c.closedMu.Lock()
defer c.closedMu.Unlock()
return c.closed
}
func (c *Conn) ioErrHandling(err error) error {
if c.isClosed() {
return io.EOF
}
var wErr *websocket.CloseError
if !errors.As(err, &wErr) {
return err
}
if wErr.Code == websocket.StatusNormalClosure {
return io.EOF
}
return err
}

View File

@@ -0,0 +1,88 @@
package ws
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"net/http"
"time"
log "github.com/sirupsen/logrus"
"nhooyr.io/websocket"
)
type Listener struct {
// Address is the address to listen on.
Address string
// TLSConfig is the TLS configuration for the server.
TLSConfig *tls.Config
server *http.Server
acceptFn func(conn net.Conn)
}
func (l *Listener) Listen(acceptFn func(conn net.Conn)) error {
l.acceptFn = acceptFn
mux := http.NewServeMux()
mux.HandleFunc("/", l.onAccept)
l.server = &http.Server{
Addr: l.Address,
Handler: mux,
TLSConfig: l.TLSConfig,
}
log.Infof("WS server is listening on address: %s", l.Address)
var err error
if l.TLSConfig != nil {
err = l.server.ListenAndServeTLS("", "")
} else {
err = l.server.ListenAndServe()
}
if errors.Is(err, http.ErrServerClosed) {
return nil
}
return err
}
func (l *Listener) Close() error {
if l.server == nil {
return nil
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
log.Infof("stop WS listener")
if err := l.server.Shutdown(ctx); err != nil {
return fmt.Errorf("server shutdown failed: %v", err)
}
log.Infof("WS listener stopped")
return nil
}
func (l *Listener) onAccept(w http.ResponseWriter, r *http.Request) {
wsConn, err := websocket.Accept(w, r, nil)
if err != nil {
log.Errorf("failed to accept ws connection: %s", err)
return
}
rAddr, err := net.ResolveTCPAddr("tcp", r.RemoteAddr)
if err != nil {
_ = wsConn.Close(websocket.StatusInternalError, "internal error")
return
}
lAddr, err := net.ResolveTCPAddr("tcp", l.server.Addr)
if err != nil {
_ = wsConn.Close(websocket.StatusInternalError, "internal error")
return
}
conn := NewConn(wsConn, lAddr, rAddr)
l.acceptFn(conn)
}

167
relay/server/peer.go Normal file
View File

@@ -0,0 +1,167 @@
package server
import (
"context"
"fmt"
"io"
"net"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/healthcheck"
"github.com/netbirdio/netbird/relay/messages"
)
const (
bufferSize = 8820
)
type Peer struct {
log *log.Entry
idS string
idB []byte
conn net.Conn
connMu sync.RWMutex
store *Store
}
func NewPeer(id []byte, conn net.Conn, store *Store) *Peer {
stringID := messages.HashIDToString(id)
return &Peer{
log: log.WithField("peer_id", stringID),
idS: stringID,
idB: id,
conn: conn,
store: store,
}
}
func (p *Peer) Work() {
ctx, cancel := context.WithCancel(context.Background())
hc := healthcheck.NewSender(ctx)
go p.healthcheck(ctx, hc)
defer cancel()
buf := make([]byte, bufferSize)
for {
n, err := p.conn.Read(buf)
if err != nil {
if err != io.EOF {
p.log.Errorf("failed to read message: %s", err)
}
return
}
msg := buf[:n]
msgType, err := messages.DetermineClientMsgType(msg)
if err != nil {
p.log.Errorf("failed to determine message type: %s", err)
return
}
switch msgType {
case messages.MsgTypeHealthCheck:
hc.OnHCResponse()
case messages.MsgTypeTransport:
p.handleTransportMsg(msg)
case messages.MsgTypeClose:
p.log.Infof("peer exited gracefully")
_ = p.conn.Close()
return
}
}
}
// Write writes data to the connection
// it has been called by the remote peer
func (p *Peer) Write(b []byte) (int, error) {
p.connMu.RLock()
defer p.connMu.RUnlock()
return p.conn.Write(b)
}
func (p *Peer) CloseGracefully(ctx context.Context) {
p.connMu.Lock()
_, err := p.writeWithTimeout(ctx, messages.MarshalCloseMsg())
if err != nil {
log.Errorf("failed to send close message to peer: %s", p.String())
}
err = p.conn.Close()
if err != nil {
log.Errorf("failed to close connection to peer: %s", err)
}
defer p.connMu.Unlock()
}
func (p *Peer) String() string {
return p.idS
}
func (p *Peer) writeWithTimeout(ctx context.Context, buf []byte) (int, error) {
ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
defer cancel()
writeDone := make(chan struct{})
var (
n int
err error
)
go func() {
_, err = p.conn.Write(buf)
close(writeDone)
}()
select {
case <-ctx.Done():
return 0, fmt.Errorf("write operation timed out")
case <-writeDone:
return n, err
}
}
func (p *Peer) healthcheck(ctx context.Context, hc *healthcheck.Sender) {
for {
select {
case <-hc.HealthCheck:
_, err := p.Write(messages.MarshalHealthcheck())
if err != nil {
p.log.Errorf("failed to send healthcheck message: %s", err)
return
}
case <-hc.Timeout:
p.log.Errorf("peer healthcheck timeout")
_ = p.conn.Close()
return
case <-ctx.Done():
return
}
}
}
func (p *Peer) handleTransportMsg(msg []byte) {
peerID, err := messages.UnmarshalTransportID(msg)
if err != nil {
p.log.Errorf("failed to unmarshal transport message: %s", err)
return
}
stringPeerID := messages.HashIDToString(peerID)
dp, ok := p.store.Peer(stringPeerID)
if !ok {
p.log.Errorf("peer not found: %s", stringPeerID)
return
}
err = messages.UpdateTransportMsg(msg, p.idB)
if err != nil {
p.log.Errorf("failed to update transport message: %s", err)
return
}
_, err = dp.Write(msg)
if err != nil {
p.log.Errorf("failed to write transport message to: %s", dp.String())
}
}

123
relay/server/relay.go Normal file
View File

@@ -0,0 +1,123 @@
package server
import (
"context"
"fmt"
"net"
"sync"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/auth"
"github.com/netbirdio/netbird/relay/messages"
)
type Relay struct {
validator auth.Validator
store *Store
instaceURL string
closed bool
closeMu sync.RWMutex
}
func NewRelay(exposedAddress string, tlsSupport bool, validator auth.Validator) *Relay {
r := &Relay{
validator: validator,
store: NewStore(),
}
if tlsSupport {
r.instaceURL = fmt.Sprintf("rels://%s", exposedAddress)
} else {
r.instaceURL = fmt.Sprintf("rel://%s", exposedAddress)
}
return r
}
func (r *Relay) Accept(conn net.Conn) {
r.closeMu.RLock()
defer r.closeMu.RUnlock()
if r.closed {
return
}
peerID, err := r.handShake(conn)
if err != nil {
log.Errorf("failed to handshake with %s: %s", conn.RemoteAddr(), err)
cErr := conn.Close()
if cErr != nil {
log.Errorf("failed to close connection, %s: %s", conn.RemoteAddr(), cErr)
}
return
}
peer := NewPeer(peerID, conn, r.store)
peer.log.Infof("peer connected from: %s", conn.RemoteAddr())
r.store.AddPeer(peer)
go func() {
peer.Work()
r.store.DeletePeer(peer)
peer.log.Debugf("relay connection closed")
}()
}
func (r *Relay) Close(ctx context.Context) {
log.Infof("close connection with all peers")
r.closeMu.Lock()
wg := sync.WaitGroup{}
peers := r.store.Peers()
for _, peer := range peers {
wg.Add(1)
go func(p *Peer) {
p.CloseGracefully(ctx)
wg.Done()
}(peer)
}
wg.Wait()
r.closeMu.Unlock()
}
func (r *Relay) handShake(conn net.Conn) ([]byte, error) {
buf := make([]byte, messages.MaxHandshakeSize)
n, err := conn.Read(buf)
if err != nil {
log.Errorf("failed to read message: %s", err)
return nil, err
}
msgType, err := messages.DetermineClientMsgType(buf[:n])
if err != nil {
return nil, err
}
if msgType != messages.MsgTypeHello {
tErr := fmt.Errorf("invalid message type")
log.Errorf("failed to handshake: %s", tErr)
return nil, tErr
}
peerID, authPayload, err := messages.UnmarshalHelloMsg(buf[:n])
if err != nil {
log.Errorf("failed to handshake: %s", err)
return nil, err
}
if err := r.validator.Validate(authPayload); err != nil {
log.Errorf("failed to authenticate connection: %s", err)
return nil, err
}
msg, _ := messages.MarshalHelloResponse(r.instaceURL)
_, err = conn.Write(msg)
if err != nil {
return nil, err
}
return peerID, nil
}
func (r *Relay) InstanceURL() string {
return r.instaceURL
}

94
relay/server/server.go Normal file
View File

@@ -0,0 +1,94 @@
package server
import (
"context"
"crypto/tls"
"errors"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/auth"
"github.com/netbirdio/netbird/relay/server/listener"
"github.com/netbirdio/netbird/relay/server/listener/udp"
"github.com/netbirdio/netbird/relay/server/listener/ws"
)
type ListenerConfig struct {
Address string
TLSConfig *tls.Config
}
type Server struct {
relay *Relay
uDPListener listener.Listener
wSListener listener.Listener
}
func NewServer(exposedAddress string, tlsSupport bool, authValidator auth.Validator) *Server {
return &Server{
relay: NewRelay(
exposedAddress,
tlsSupport,
authValidator,
),
}
}
func (r *Server) Listen(cfg ListenerConfig) error {
wg := sync.WaitGroup{}
wg.Add(2)
r.wSListener = &ws.Listener{
Address: cfg.Address,
TLSConfig: cfg.TLSConfig,
}
var wslErr error
go func() {
defer wg.Done()
wslErr = r.wSListener.Listen(r.relay.Accept)
if wslErr != nil {
log.Errorf("failed to bind ws server: %s", wslErr)
}
}()
r.uDPListener = udp.NewListener(cfg.Address)
var udpLErr error
go func() {
defer wg.Done()
udpLErr = r.uDPListener.Listen(r.relay.Accept)
if udpLErr != nil {
log.Errorf("failed to bind ws server: %s", udpLErr)
}
}()
err := errors.Join(wslErr, udpLErr)
return err
}
func (r *Server) Close() error {
var wErr error
// stop service new connections
if r.wSListener != nil {
wErr = r.wSListener.Close()
}
var uErr error
if r.uDPListener != nil {
uErr = r.uDPListener.Close()
}
// close accepted connections gracefully
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
r.relay.Close(ctx)
err := errors.Join(wErr, uErr)
return err
}
func (r *Server) InstanceURL() string {
return r.relay.instaceURL
}

47
relay/server/store.go Normal file
View File

@@ -0,0 +1,47 @@
package server
import (
"sync"
)
type Store struct {
peers map[string]*Peer // consider to use [32]byte as key. The Peer(id string) would be faster
peersLock sync.RWMutex
}
func NewStore() *Store {
return &Store{
peers: make(map[string]*Peer),
}
}
func (s *Store) AddPeer(peer *Peer) {
s.peersLock.Lock()
defer s.peersLock.Unlock()
s.peers[peer.String()] = peer
}
func (s *Store) DeletePeer(peer *Peer) {
s.peersLock.Lock()
defer s.peersLock.Unlock()
delete(s.peers, peer.String())
}
func (s *Store) Peer(id string) (*Peer, bool) {
s.peersLock.RLock()
defer s.peersLock.RUnlock()
p, ok := s.peers[id]
return p, ok
}
func (s *Store) Peers() []*Peer {
s.peersLock.RLock()
defer s.peersLock.RUnlock()
peers := make([]*Peer, 0, len(s.peers))
for _, p := range s.peers {
peers = append(peers, p)
}
return peers
}

View File

@@ -51,11 +51,10 @@ func UnMarshalCredential(msg *proto.Message) (*Credential, error) {
}
// MarshalCredential marshal a Credential instance and returns a Message object
func MarshalCredential(myKey wgtypes.Key, myPort int, remoteKey wgtypes.Key, credential *Credential, t proto.Body_Type,
rosenpassPubKey []byte, rosenpassAddr string) (*proto.Message, error) {
func MarshalCredential(myKey wgtypes.Key, myPort int, remoteKey string, credential *Credential, t proto.Body_Type, rosenpassPubKey []byte, rosenpassAddr string, relaySrvAddress string) (*proto.Message, error) {
return &proto.Message{
Key: myKey.PublicKey().String(),
RemoteKey: remoteKey.String(),
RemoteKey: remoteKey,
Body: &proto.Body{
Type: t,
Payload: fmt.Sprintf("%s:%s", credential.UFrag, credential.Pwd),
@@ -65,6 +64,7 @@ func MarshalCredential(myKey wgtypes.Key, myPort int, remoteKey wgtypes.Key, cre
RosenpassPubKey: rosenpassPubKey,
RosenpassServerAddr: rosenpassAddr,
},
RelayServerAddress: relaySrvAddress,
},
}, nil
}

View File

@@ -1,15 +1,15 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.26.0
// protoc v3.12.4
// protoc v3.21.12
// source: signalexchange.proto
package proto
import (
_ "github.com/golang/protobuf/protoc-gen-go/descriptor"
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
_ "google.golang.org/protobuf/types/descriptorpb"
reflect "reflect"
sync "sync"
)
@@ -225,6 +225,8 @@ type Body struct {
FeaturesSupported []uint32 `protobuf:"varint,6,rep,packed,name=featuresSupported,proto3" json:"featuresSupported,omitempty"`
// RosenpassConfig is a Rosenpass config of the remote peer our peer tries to connect to
RosenpassConfig *RosenpassConfig `protobuf:"bytes,7,opt,name=rosenpassConfig,proto3" json:"rosenpassConfig,omitempty"`
// relayServerAddress is an IP:port of the relay server
RelayServerAddress string `protobuf:"bytes,8,opt,name=relayServerAddress,proto3" json:"relayServerAddress,omitempty"`
}
func (x *Body) Reset() {
@@ -308,6 +310,13 @@ func (x *Body) GetRosenpassConfig() *RosenpassConfig {
return nil
}
func (x *Body) GetRelayServerAddress() string {
if x != nil {
return x.RelayServerAddress
}
return ""
}
// Mode indicates a connection mode
type Mode struct {
state protoimpl.MessageState
@@ -431,7 +440,7 @@ var file_signalexchange_proto_rawDesc = []byte{
0x52, 0x09, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x4b, 0x65, 0x79, 0x12, 0x28, 0x0a, 0x04, 0x62,
0x6f, 0x64, 0x79, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x73, 0x69, 0x67, 0x6e,
0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x42, 0x6f, 0x64, 0x79, 0x52,
0x04, 0x62, 0x6f, 0x64, 0x79, 0x22, 0xf6, 0x02, 0x0a, 0x04, 0x42, 0x6f, 0x64, 0x79, 0x12, 0x2d,
0x04, 0x62, 0x6f, 0x64, 0x79, 0x22, 0xa6, 0x03, 0x0a, 0x04, 0x42, 0x6f, 0x64, 0x79, 0x12, 0x2d,
0x0a, 0x04, 0x74, 0x79, 0x70, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x19, 0x2e, 0x73,
0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x42, 0x6f,
0x64, 0x79, 0x2e, 0x54, 0x79, 0x70, 0x65, 0x52, 0x04, 0x74, 0x79, 0x70, 0x65, 0x12, 0x18, 0x0a,
@@ -451,7 +460,10 @@ var file_signalexchange_proto_rawDesc = []byte{
0x20, 0x01, 0x28, 0x0b, 0x32, 0x1f, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63,
0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x52, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x43,
0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73,
0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0x36, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x09,
0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x2e, 0x0a, 0x12, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x53,
0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x08, 0x20, 0x01,
0x28, 0x09, 0x52, 0x12, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41,
0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x22, 0x36, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x09,
0x0a, 0x05, 0x4f, 0x46, 0x46, 0x45, 0x52, 0x10, 0x00, 0x12, 0x0a, 0x0a, 0x06, 0x41, 0x4e, 0x53,
0x57, 0x45, 0x52, 0x10, 0x01, 0x12, 0x0d, 0x0a, 0x09, 0x43, 0x41, 0x4e, 0x44, 0x49, 0x44, 0x41,
0x54, 0x45, 0x10, 0x02, 0x12, 0x08, 0x0a, 0x04, 0x4d, 0x4f, 0x44, 0x45, 0x10, 0x04, 0x22, 0x2e,

View File

@@ -60,6 +60,9 @@ message Body {
// RosenpassConfig is a Rosenpass config of the remote peer our peer tries to connect to
RosenpassConfig rosenpassConfig = 7;
// relayServerAddress is url of the relay server
string relayServerAddress = 8;
}
// Mode indicates a connection mode