[misc] Separate shared code dependencies (#4288)

* Separate shared code dependencies

* Fix import

* Test respective shared code

* Update openapi ref

* Fix test

* Fix test path
This commit is contained in:
Viktor Liu
2025-08-05 18:34:41 +02:00
committed by GitHub
parent 1d5e871bdf
commit abd152ee5a
150 changed files with 252 additions and 191 deletions

View File

@@ -1,17 +0,0 @@
/*
The `healthcheck` package is responsible for managing the health checks between the client and the relay server. It
ensures that the connection between the client and the server are alive and functioning properly.
The `Sender` struct is responsible for sending health check signals to the receiver. The receiver listens for these
signals and sends a new signal back to the sender to acknowledge that the signal has been received. If the sender does
not receive an acknowledgment signal within a certain time frame, it will send a timeout signal via timeout channel
and stop working.
The `Receiver` struct is responsible for receiving the health check signals from the sender. If the receiver does not
receive a signal within a certain time frame, it will send a timeout signal via the OnTimeout channel and stop working.
In the Relay usage the signal is sent to the peer in message type Healthcheck. In case of timeout the connection is
closed and the peer is removed from the relay.
*/
package healthcheck

View File

@@ -1,94 +0,0 @@
package healthcheck
import (
"context"
"time"
log "github.com/sirupsen/logrus"
)
var (
heartbeatTimeout = healthCheckInterval + 10*time.Second
)
// Receiver is a healthcheck receiver
// It will listen for heartbeat and check if the heartbeat is not received in a certain time
// If the heartbeat is not received in a certain time, it will send a timeout signal and stop to work
// The heartbeat timeout is a bit longer than the sender's healthcheck interval
type Receiver struct {
OnTimeout chan struct{}
log *log.Entry
ctx context.Context
ctxCancel context.CancelFunc
heartbeat chan struct{}
alive bool
attemptThreshold int
}
// NewReceiver creates a new healthcheck receiver and start the timer in the background
func NewReceiver(log *log.Entry) *Receiver {
ctx, ctxCancel := context.WithCancel(context.Background())
r := &Receiver{
OnTimeout: make(chan struct{}, 1),
log: log,
ctx: ctx,
ctxCancel: ctxCancel,
heartbeat: make(chan struct{}, 1),
attemptThreshold: getAttemptThresholdFromEnv(),
}
go r.waitForHealthcheck()
return r
}
// Heartbeat acknowledge the heartbeat has been received
func (r *Receiver) Heartbeat() {
select {
case r.heartbeat <- struct{}{}:
default:
}
}
// Stop check the timeout and do not send new notifications
func (r *Receiver) Stop() {
r.ctxCancel()
}
func (r *Receiver) waitForHealthcheck() {
ticker := time.NewTicker(heartbeatTimeout)
defer ticker.Stop()
defer r.ctxCancel()
defer close(r.OnTimeout)
failureCounter := 0
for {
select {
case <-r.heartbeat:
r.alive = true
failureCounter = 0
case <-ticker.C:
if r.alive {
r.alive = false
continue
}
failureCounter++
if failureCounter < r.attemptThreshold {
r.log.Warnf("healthcheck failed, attempt %d", failureCounter)
continue
}
r.notifyTimeout()
return
case <-r.ctx.Done():
return
}
}
}
func (r *Receiver) notifyTimeout() {
select {
case r.OnTimeout <- struct{}{}:
default:
}
}

View File

@@ -1,140 +0,0 @@
package healthcheck
import (
"context"
"fmt"
"os"
"sync"
"testing"
"time"
log "github.com/sirupsen/logrus"
)
// Mutex to protect global variable access in tests
var testMutex sync.Mutex
func TestNewReceiver(t *testing.T) {
testMutex.Lock()
originalTimeout := heartbeatTimeout
heartbeatTimeout = 5 * time.Second
testMutex.Unlock()
defer func() {
testMutex.Lock()
heartbeatTimeout = originalTimeout
testMutex.Unlock()
}()
r := NewReceiver(log.WithContext(context.Background()))
defer r.Stop()
select {
case <-r.OnTimeout:
t.Error("unexpected timeout")
case <-time.After(1 * time.Second):
// Test passes if no timeout received
}
}
func TestNewReceiverNotReceive(t *testing.T) {
testMutex.Lock()
originalTimeout := heartbeatTimeout
heartbeatTimeout = 1 * time.Second
testMutex.Unlock()
defer func() {
testMutex.Lock()
heartbeatTimeout = originalTimeout
testMutex.Unlock()
}()
r := NewReceiver(log.WithContext(context.Background()))
defer r.Stop()
select {
case <-r.OnTimeout:
// Test passes if timeout is received
case <-time.After(2 * time.Second):
t.Error("timeout not received")
}
}
func TestNewReceiverAck(t *testing.T) {
testMutex.Lock()
originalTimeout := heartbeatTimeout
heartbeatTimeout = 2 * time.Second
testMutex.Unlock()
defer func() {
testMutex.Lock()
heartbeatTimeout = originalTimeout
testMutex.Unlock()
}()
r := NewReceiver(log.WithContext(context.Background()))
defer r.Stop()
r.Heartbeat()
select {
case <-r.OnTimeout:
t.Error("unexpected timeout")
case <-time.After(3 * time.Second):
}
}
func TestReceiverHealthCheckAttemptThreshold(t *testing.T) {
testsCases := []struct {
name string
threshold int
resetCounterOnce bool
}{
{"Default attempt threshold", defaultAttemptThreshold, false},
{"Custom attempt threshold", 3, false},
{"Should reset threshold once", 2, true},
}
for _, tc := range testsCases {
t.Run(tc.name, func(t *testing.T) {
testMutex.Lock()
originalInterval := healthCheckInterval
originalTimeout := heartbeatTimeout
healthCheckInterval = 1 * time.Second
heartbeatTimeout = healthCheckInterval + 500*time.Millisecond
testMutex.Unlock()
defer func() {
testMutex.Lock()
healthCheckInterval = originalInterval
heartbeatTimeout = originalTimeout
testMutex.Unlock()
}()
//nolint:tenv
os.Setenv(defaultAttemptThresholdEnv, fmt.Sprintf("%d", tc.threshold))
defer os.Unsetenv(defaultAttemptThresholdEnv)
receiver := NewReceiver(log.WithField("test_name", tc.name))
testTimeout := heartbeatTimeout*time.Duration(tc.threshold) + healthCheckInterval
if tc.resetCounterOnce {
receiver.Heartbeat()
t.Logf("reset counter once")
}
select {
case <-receiver.OnTimeout:
if tc.resetCounterOnce {
t.Fatalf("should not have timed out before %s", testTimeout)
}
case <-time.After(testTimeout):
if tc.resetCounterOnce {
return
}
t.Fatalf("should have timed out before %s", testTimeout)
}
})
}
}

View File

@@ -1,110 +0,0 @@
package healthcheck
import (
"context"
"os"
"strconv"
"time"
log "github.com/sirupsen/logrus"
)
const (
defaultAttemptThreshold = 1
defaultAttemptThresholdEnv = "NB_RELAY_HC_ATTEMPT_THRESHOLD"
)
var (
healthCheckInterval = 25 * time.Second
healthCheckTimeout = 20 * time.Second
)
// Sender is a healthcheck sender
// It will send healthcheck signal to the receiver
// If the receiver does not receive the signal in a certain time, it will send a timeout signal and stop to work
// It will also stop if the context is canceled
type Sender struct {
log *log.Entry
// HealthCheck is a channel to send health check signal to the peer
HealthCheck chan struct{}
// Timeout is a channel to the health check signal is not received in a certain time
Timeout chan struct{}
ack chan struct{}
alive bool
attemptThreshold int
}
// NewSender creates a new healthcheck sender
func NewSender(log *log.Entry) *Sender {
hc := &Sender{
log: log,
HealthCheck: make(chan struct{}, 1),
Timeout: make(chan struct{}, 1),
ack: make(chan struct{}, 1),
attemptThreshold: getAttemptThresholdFromEnv(),
}
return hc
}
// OnHCResponse sends an acknowledgment signal to the sender
func (hc *Sender) OnHCResponse() {
select {
case hc.ack <- struct{}{}:
default:
}
}
func (hc *Sender) StartHealthCheck(ctx context.Context) {
ticker := time.NewTicker(healthCheckInterval)
defer ticker.Stop()
timeoutTicker := time.NewTicker(hc.getTimeoutTime())
defer timeoutTicker.Stop()
defer close(hc.HealthCheck)
defer close(hc.Timeout)
failureCounter := 0
for {
select {
case <-ticker.C:
hc.HealthCheck <- struct{}{}
case <-timeoutTicker.C:
if hc.alive {
hc.alive = false
continue
}
failureCounter++
if failureCounter < hc.attemptThreshold {
hc.log.Warnf("Health check failed attempt %d.", failureCounter)
continue
}
hc.Timeout <- struct{}{}
return
case <-hc.ack:
failureCounter = 0
hc.alive = true
case <-ctx.Done():
return
}
}
}
func (hc *Sender) getTimeoutTime() time.Duration {
return healthCheckInterval + healthCheckTimeout
}
func getAttemptThresholdFromEnv() int {
if attemptThreshold := os.Getenv(defaultAttemptThresholdEnv); attemptThreshold != "" {
threshold, err := strconv.ParseInt(attemptThreshold, 10, 64)
if err != nil {
log.Errorf("Failed to parse attempt threshold from environment variable \"%s\" should be an integer. Using default value", attemptThreshold)
return defaultAttemptThreshold
}
return int(threshold)
}
return defaultAttemptThreshold
}

View File

@@ -1,213 +0,0 @@
package healthcheck
import (
"context"
"fmt"
"os"
"testing"
"time"
log "github.com/sirupsen/logrus"
)
func TestMain(m *testing.M) {
// override the health check interval to speed up the test
healthCheckInterval = 2 * time.Second
healthCheckTimeout = 100 * time.Millisecond
code := m.Run()
os.Exit(code)
}
func TestNewHealthPeriod(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
hc := NewSender(log.WithContext(ctx))
go hc.StartHealthCheck(ctx)
iterations := 0
for i := 0; i < 3; i++ {
select {
case <-hc.HealthCheck:
iterations++
hc.OnHCResponse()
case <-hc.Timeout:
t.Fatalf("health check is timed out")
case <-time.After(healthCheckInterval + 100*time.Millisecond):
t.Fatalf("health check not received")
}
}
}
func TestNewHealthFailed(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
hc := NewSender(log.WithContext(ctx))
go hc.StartHealthCheck(ctx)
select {
case <-hc.Timeout:
case <-time.After(healthCheckInterval + healthCheckTimeout + 100*time.Millisecond):
t.Fatalf("health check is not timed out")
}
}
func TestNewHealthcheckStop(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
hc := NewSender(log.WithContext(ctx))
go hc.StartHealthCheck(ctx)
time.Sleep(100 * time.Millisecond)
cancel()
select {
case _, ok := <-hc.HealthCheck:
if ok {
t.Fatalf("health check on received")
}
case _, ok := <-hc.Timeout:
if ok {
t.Fatalf("health check on received")
}
case <-ctx.Done():
// expected
case <-time.After(10 * time.Second):
t.Fatalf("is not exited")
}
}
func TestTimeoutReset(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
hc := NewSender(log.WithContext(ctx))
go hc.StartHealthCheck(ctx)
iterations := 0
for i := 0; i < 3; i++ {
select {
case <-hc.HealthCheck:
iterations++
hc.OnHCResponse()
case <-hc.Timeout:
t.Fatalf("health check is timed out")
case <-time.After(healthCheckInterval + 100*time.Millisecond):
t.Fatalf("health check not received")
}
}
select {
case <-hc.HealthCheck:
case <-hc.Timeout:
// expected
case <-ctx.Done():
t.Fatalf("context is done")
case <-time.After(10 * time.Second):
t.Fatalf("is not exited")
}
}
func TestSenderHealthCheckAttemptThreshold(t *testing.T) {
testsCases := []struct {
name string
threshold int
resetCounterOnce bool
}{
{"Default attempt threshold", defaultAttemptThreshold, false},
{"Custom attempt threshold", 3, false},
{"Should reset threshold once", 2, true},
}
for _, tc := range testsCases {
t.Run(tc.name, func(t *testing.T) {
originalInterval := healthCheckInterval
originalTimeout := healthCheckTimeout
healthCheckInterval = 1 * time.Second
healthCheckTimeout = 500 * time.Millisecond
//nolint:tenv
os.Setenv(defaultAttemptThresholdEnv, fmt.Sprintf("%d", tc.threshold))
defer os.Unsetenv(defaultAttemptThresholdEnv)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
sender := NewSender(log.WithField("test_name", tc.name))
senderExit := make(chan struct{})
go func() {
sender.StartHealthCheck(ctx)
close(senderExit)
}()
go func() {
responded := false
for {
select {
case <-ctx.Done():
return
case _, ok := <-sender.HealthCheck:
if !ok {
return
}
if tc.resetCounterOnce && !responded {
responded = true
sender.OnHCResponse()
}
}
}
}()
testTimeout := sender.getTimeoutTime()*time.Duration(tc.threshold) + healthCheckInterval
select {
case <-sender.Timeout:
if tc.resetCounterOnce {
t.Errorf("should not have timed out before %s", testTimeout)
}
case <-time.After(testTimeout):
if tc.resetCounterOnce {
return
}
t.Errorf("should have timed out before %s", testTimeout)
}
cancel()
select {
case <-senderExit:
case <-time.After(2 * time.Second):
t.Fatalf("sender did not exit in time")
}
healthCheckInterval = originalInterval
healthCheckTimeout = originalTimeout
})
}
}
//nolint:tenv
func TestGetAttemptThresholdFromEnv(t *testing.T) {
tests := []struct {
name string
envValue string
expected int
}{
{"Default attempt threshold when env is not set", "", defaultAttemptThreshold},
{"Custom attempt threshold when env is set to a valid integer", "3", 3},
{"Default attempt threshold when env is set to an invalid value", "invalid", defaultAttemptThreshold},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.envValue == "" {
os.Unsetenv(defaultAttemptThresholdEnv)
} else {
os.Setenv(defaultAttemptThresholdEnv, tt.envValue)
}
result := getAttemptThresholdFromEnv()
if result != tt.expected {
t.Fatalf("Expected %d, got %d", tt.expected, result)
}
os.Unsetenv(defaultAttemptThresholdEnv)
})
}
}

View File

@@ -1,21 +0,0 @@
// Deprecated: This package is deprecated and will be removed in a future release.
package address
import (
"bytes"
"encoding/gob"
"fmt"
)
type Address struct {
URL string
}
func (addr *Address) Marshal() ([]byte, error) {
var buf bytes.Buffer
enc := gob.NewEncoder(&buf)
if err := enc.Encode(addr); err != nil {
return nil, fmt.Errorf("encode Address: %w", err)
}
return buf.Bytes(), nil
}

View File

@@ -1,43 +0,0 @@
// Deprecated: This package is deprecated and will be removed in a future release.
package auth
import (
"bytes"
"encoding/gob"
"fmt"
)
type Algorithm int
const (
AlgoUnknown Algorithm = iota
AlgoHMACSHA256
AlgoHMACSHA512
)
func (a Algorithm) String() string {
switch a {
case AlgoHMACSHA256:
return "HMAC-SHA256"
case AlgoHMACSHA512:
return "HMAC-SHA512"
default:
return "Unknown"
}
}
type Msg struct {
AuthAlgorithm Algorithm
AdditionalData []byte
}
func UnmarshalMsg(data []byte) (*Msg, error) {
var msg *Msg
buf := bytes.NewBuffer(data)
dec := gob.NewDecoder(buf)
if err := dec.Decode(&msg); err != nil {
return nil, fmt.Errorf("decode Msg: %w", err)
}
return msg, nil
}

View File

@@ -1,5 +0,0 @@
/*
Package messages provides the message types that are used to communicate between the relay and the client.
This package is used to determine the type of message that is being sent and received between the relay and the client.
*/
package messages

View File

@@ -1,31 +0,0 @@
package messages
import (
"crypto/sha256"
"encoding/base64"
"fmt"
)
const (
prefixLength = 4
peerIDSize = prefixLength + sha256.Size
)
var (
prefix = []byte("sha-") // 4 bytes
)
type PeerID [peerIDSize]byte
func (p PeerID) String() string {
return fmt.Sprintf("%s%s", p[:prefixLength], base64.StdEncoding.EncodeToString(p[prefixLength:]))
}
// HashID generates a sha256 hash from the peerID and returns the hash and the human-readable string
func HashID(peerID string) PeerID {
idHash := sha256.Sum256([]byte(peerID))
var prefixedHash [peerIDSize]byte
copy(prefixedHash[:prefixLength], prefix)
copy(prefixedHash[prefixLength:], idHash[:])
return prefixedHash
}

View File

@@ -1,337 +0,0 @@
package messages
import (
"bytes"
"errors"
"fmt"
)
const (
MaxHandshakeSize = 212
MaxHandshakeRespSize = 8192
MaxMessageSize = 8820
CurrentProtocolVersion = 1
MsgTypeUnknown MsgType = 0
// Deprecated: Use MsgTypeAuth instead.
MsgTypeHello = 1
// Deprecated: Use MsgTypeAuthResponse instead.
MsgTypeHelloResponse = 2
MsgTypeTransport = 3
MsgTypeClose = 4
MsgTypeHealthCheck = 5
MsgTypeAuth = 6
MsgTypeAuthResponse = 7
// Peers state messages
MsgTypeSubscribePeerState = 8
MsgTypeUnsubscribePeerState = 9
MsgTypePeersOnline = 10
MsgTypePeersWentOffline = 11
// base size of the message
sizeOfVersionByte = 1
sizeOfMsgType = 1
sizeOfProtoHeader = sizeOfVersionByte + sizeOfMsgType
// auth message
sizeOfMagicByte = 4
headerSizeAuth = sizeOfMagicByte + peerIDSize
offsetMagicByte = sizeOfProtoHeader
offsetAuthPeerID = sizeOfProtoHeader + sizeOfMagicByte
headerTotalSizeAuth = sizeOfProtoHeader + headerSizeAuth
// hello message
headerSizeHello = sizeOfMagicByte + peerIDSize
headerSizeHelloResp = 0
// transport
headerSizeTransport = peerIDSize
offsetTransportID = sizeOfProtoHeader
headerTotalSizeTransport = sizeOfProtoHeader + headerSizeTransport
)
var (
ErrInvalidMessageLength = errors.New("invalid message length")
ErrUnsupportedVersion = errors.New("unsupported version")
magicHeader = []byte{0x21, 0x12, 0xA4, 0x42}
healthCheckMsg = []byte{byte(CurrentProtocolVersion), byte(MsgTypeHealthCheck)}
)
type MsgType byte
func (m MsgType) String() string {
switch m {
case MsgTypeHello:
return "hello"
case MsgTypeHelloResponse:
return "hello response"
case MsgTypeAuth:
return "auth"
case MsgTypeAuthResponse:
return "auth response"
case MsgTypeTransport:
return "transport"
case MsgTypeClose:
return "close"
case MsgTypeHealthCheck:
return "health check"
case MsgTypeSubscribePeerState:
return "subscribe peer state"
case MsgTypeUnsubscribePeerState:
return "unsubscribe peer state"
case MsgTypePeersOnline:
return "peers online"
case MsgTypePeersWentOffline:
return "peers went offline"
default:
return "unknown"
}
}
// ValidateVersion checks if the given version is supported by the protocol
func ValidateVersion(msg []byte) (int, error) {
if len(msg) < sizeOfProtoHeader {
return 0, ErrInvalidMessageLength
}
version := int(msg[0])
if version != CurrentProtocolVersion {
return 0, fmt.Errorf("%d: %w", version, ErrUnsupportedVersion)
}
return version, nil
}
// DetermineClientMessageType determines the message type from the first the message
func DetermineClientMessageType(msg []byte) (MsgType, error) {
if len(msg) < sizeOfProtoHeader {
return 0, ErrInvalidMessageLength
}
msgType := MsgType(msg[1])
switch msgType {
case
MsgTypeHello,
MsgTypeAuth,
MsgTypeTransport,
MsgTypeClose,
MsgTypeHealthCheck,
MsgTypeSubscribePeerState,
MsgTypeUnsubscribePeerState:
return msgType, nil
default:
return MsgTypeUnknown, fmt.Errorf("invalid msg type %d", msgType)
}
}
// DetermineServerMessageType determines the message type from the first the message
func DetermineServerMessageType(msg []byte) (MsgType, error) {
if len(msg) < sizeOfProtoHeader {
return 0, ErrInvalidMessageLength
}
msgType := MsgType(msg[1])
switch msgType {
case
MsgTypeHelloResponse,
MsgTypeAuthResponse,
MsgTypeTransport,
MsgTypeClose,
MsgTypeHealthCheck,
MsgTypePeersOnline,
MsgTypePeersWentOffline:
return msgType, nil
default:
return MsgTypeUnknown, fmt.Errorf("invalid msg type %d", msgType)
}
}
// Deprecated: Use MarshalAuthMsg instead.
// MarshalHelloMsg initial hello message
// The Hello message is the first message sent by a client after establishing a connection with the Relay server. This
// message is used to authenticate the client with the server. The authentication is done using an HMAC method.
// The protocol does not limit to use HMAC, it can be any other method. If the authentication failed the server will
// close the network connection without any response.
func MarshalHelloMsg(peerID PeerID, additions []byte) ([]byte, error) {
msg := make([]byte, sizeOfProtoHeader+sizeOfMagicByte, sizeOfProtoHeader+headerSizeHello+len(additions))
msg[0] = byte(CurrentProtocolVersion)
msg[1] = byte(MsgTypeHello)
copy(msg[sizeOfProtoHeader:sizeOfProtoHeader+sizeOfMagicByte], magicHeader)
msg = append(msg, peerID[:]...)
msg = append(msg, additions...)
return msg, nil
}
// Deprecated: Use UnmarshalAuthMsg instead.
// UnmarshalHelloMsg extracts peerID and the additional data from the hello message. The Additional data is used to
// authenticate the client with the server.
func UnmarshalHelloMsg(msg []byte) (*PeerID, []byte, error) {
if len(msg) < sizeOfProtoHeader+headerSizeHello {
return nil, nil, ErrInvalidMessageLength
}
if !bytes.Equal(msg[sizeOfProtoHeader:sizeOfProtoHeader+sizeOfMagicByte], magicHeader) {
return nil, nil, errors.New("invalid magic header")
}
peerID := PeerID(msg[sizeOfProtoHeader+sizeOfMagicByte : sizeOfProtoHeader+headerSizeHello])
return &peerID, msg[headerSizeHello:], nil
}
// Deprecated: Use MarshalAuthResponse instead.
// MarshalHelloResponse creates a response message to the hello message.
// In case of success connection the server response with a Hello Response message. This message contains the server's
// instance URL. This URL will be used by choose the common Relay server in case if the peers are in different Relay
// servers.
func MarshalHelloResponse(additionalData []byte) ([]byte, error) {
msg := make([]byte, sizeOfProtoHeader, sizeOfProtoHeader+headerSizeHelloResp+len(additionalData))
msg[0] = byte(CurrentProtocolVersion)
msg[1] = byte(MsgTypeHelloResponse)
msg = append(msg, additionalData...)
return msg, nil
}
// Deprecated: Use UnmarshalAuthResponse instead.
// UnmarshalHelloResponse extracts the additional data from the hello response message.
func UnmarshalHelloResponse(msg []byte) ([]byte, error) {
if len(msg) < sizeOfProtoHeader+headerSizeHelloResp {
return nil, ErrInvalidMessageLength
}
return msg, nil
}
// MarshalAuthMsg initial authentication message
// The Auth message is the first message sent by a client after establishing a connection with the Relay server. This
// message is used to authenticate the client with the server. The authentication is done using an HMAC method.
// The protocol does not limit to use HMAC, it can be any other method. If the authentication failed the server will
// close the network connection without any response.
func MarshalAuthMsg(peerID PeerID, authPayload []byte) ([]byte, error) {
if headerTotalSizeAuth+len(authPayload) > MaxHandshakeSize {
return nil, fmt.Errorf("too large auth payload")
}
msg := make([]byte, headerTotalSizeAuth+len(authPayload))
msg[0] = byte(CurrentProtocolVersion)
msg[1] = byte(MsgTypeAuth)
copy(msg[sizeOfProtoHeader:], magicHeader)
copy(msg[offsetAuthPeerID:], peerID[:])
copy(msg[headerTotalSizeAuth:], authPayload)
return msg, nil
}
// UnmarshalAuthMsg extracts peerID and the auth payload from the message
func UnmarshalAuthMsg(msg []byte) (*PeerID, []byte, error) {
if len(msg) < headerTotalSizeAuth {
return nil, nil, ErrInvalidMessageLength
}
// Validate the magic header
if !bytes.Equal(msg[offsetMagicByte:offsetMagicByte+sizeOfMagicByte], magicHeader) {
return nil, nil, errors.New("invalid magic header")
}
peerID := PeerID(msg[offsetAuthPeerID:headerTotalSizeAuth])
return &peerID, msg[headerTotalSizeAuth:], nil
}
// MarshalAuthResponse creates a response message to the auth.
// In case of success connection the server response with a AuthResponse message. This message contains the server's
// instance URL. This URL will be used by choose the common Relay server in case if the peers are in different Relay
// servers.
func MarshalAuthResponse(address string) ([]byte, error) {
ab := []byte(address)
msg := make([]byte, sizeOfProtoHeader, sizeOfProtoHeader+len(ab))
msg[0] = byte(CurrentProtocolVersion)
msg[1] = byte(MsgTypeAuthResponse)
msg = append(msg, ab...)
if len(msg) > MaxHandshakeRespSize {
return nil, fmt.Errorf("invalid message length: %d", len(msg))
}
return msg, nil
}
// UnmarshalAuthResponse it is a confirmation message to auth success
func UnmarshalAuthResponse(msg []byte) (string, error) {
if len(msg) < sizeOfProtoHeader+1 {
return "", ErrInvalidMessageLength
}
return string(msg[sizeOfProtoHeader:]), nil
}
// MarshalCloseMsg creates a close message.
// The close message is used to close the connection gracefully between the client and the server. The server and the
// client can send this message. After receiving this message, the server or client will close the connection.
func MarshalCloseMsg() []byte {
return []byte{
byte(CurrentProtocolVersion),
byte(MsgTypeClose),
}
}
// MarshalTransportMsg creates a transport message.
// The transport message is used to exchange data between peers. The message contains the data to be exchanged and the
// destination peer hashed ID.
func MarshalTransportMsg(peerID PeerID, payload []byte) ([]byte, error) {
// todo validate size
msg := make([]byte, headerTotalSizeTransport+len(payload))
msg[0] = byte(CurrentProtocolVersion)
msg[1] = byte(MsgTypeTransport)
copy(msg[sizeOfProtoHeader:], peerID[:])
copy(msg[sizeOfProtoHeader+peerIDSize:], payload)
return msg, nil
}
// UnmarshalTransportMsg extracts the peerID and the payload from the transport message.
func UnmarshalTransportMsg(buf []byte) (*PeerID, []byte, error) {
if len(buf) < headerTotalSizeTransport {
return nil, nil, ErrInvalidMessageLength
}
const offsetEnd = offsetTransportID + peerIDSize
var peerID PeerID
copy(peerID[:], buf[offsetTransportID:offsetEnd])
return &peerID, buf[headerTotalSizeTransport:], nil
}
// UnmarshalTransportID extracts the peerID from the transport message.
func UnmarshalTransportID(buf []byte) (*PeerID, error) {
if len(buf) < headerTotalSizeTransport {
return nil, ErrInvalidMessageLength
}
const offsetEnd = offsetTransportID + peerIDSize
var id PeerID
copy(id[:], buf[offsetTransportID:offsetEnd])
return &id, nil
}
// UpdateTransportMsg updates the peerID in the transport message.
// With this function the server can reuse the given byte slice to update the peerID in the transport message. So do
// need to allocate a new byte slice.
func UpdateTransportMsg(msg []byte, peerID PeerID) error {
if len(msg) < offsetTransportID+peerIDSize {
return ErrInvalidMessageLength
}
copy(msg[offsetTransportID:], peerID[:])
return nil
}
// MarshalHealthcheck creates a health check message.
// Health check message is sent by the server periodically. The client will respond with a health check response
// message. If the client does not respond to the health check message, the server will close the connection.
func MarshalHealthcheck() []byte {
return healthCheckMsg
}

View File

@@ -1,138 +0,0 @@
package messages
import (
"testing"
)
func TestMarshalHelloMsg(t *testing.T) {
peerID := HashID("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
msg, err := MarshalHelloMsg(peerID, nil)
if err != nil {
t.Fatalf("error: %v", err)
}
msgType, err := DetermineClientMessageType(msg)
if err != nil {
t.Fatalf("error: %v", err)
}
if msgType != MsgTypeHello {
t.Errorf("expected %d, got %d", MsgTypeHello, msgType)
}
receivedPeerID, _, err := UnmarshalHelloMsg(msg)
if err != nil {
t.Fatalf("error: %v", err)
}
if receivedPeerID.String() != peerID.String() {
t.Errorf("expected %s, got %s", peerID, receivedPeerID)
}
}
func TestMarshalAuthMsg(t *testing.T) {
peerID := HashID("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
msg, err := MarshalAuthMsg(peerID, []byte{})
if err != nil {
t.Fatalf("error: %v", err)
}
msgType, err := DetermineClientMessageType(msg)
if err != nil {
t.Fatalf("error: %v", err)
}
if msgType != MsgTypeAuth {
t.Errorf("expected %d, got %d", MsgTypeAuth, msgType)
}
receivedPeerID, _, err := UnmarshalAuthMsg(msg)
if err != nil {
t.Fatalf("error: %v", err)
}
if receivedPeerID.String() != peerID.String() {
t.Errorf("expected %s, got %s", peerID, receivedPeerID)
}
}
func TestMarshalAuthResponse(t *testing.T) {
address := "myaddress"
msg, err := MarshalAuthResponse(address)
if err != nil {
t.Fatalf("error: %v", err)
}
msgType, err := DetermineServerMessageType(msg)
if err != nil {
t.Fatalf("error: %v", err)
}
if msgType != MsgTypeAuthResponse {
t.Errorf("expected %d, got %d", MsgTypeAuthResponse, msgType)
}
respAddr, err := UnmarshalAuthResponse(msg)
if err != nil {
t.Fatalf("error: %v", err)
}
if respAddr != address {
t.Errorf("expected %s, got %s", address, respAddr)
}
}
func TestMarshalTransportMsg(t *testing.T) {
peerID := HashID("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
payload := []byte("payload")
msg, err := MarshalTransportMsg(peerID, payload)
if err != nil {
t.Fatalf("error: %v", err)
}
msgType, err := DetermineClientMessageType(msg)
if err != nil {
t.Fatalf("error: %v", err)
}
if msgType != MsgTypeTransport {
t.Errorf("expected %d, got %d", MsgTypeTransport, msgType)
}
uPeerID, err := UnmarshalTransportID(msg)
if err != nil {
t.Fatalf("failed to unmarshal transport id: %v", err)
}
if uPeerID.String() != peerID.String() {
t.Errorf("expected %s, got %s", peerID, uPeerID)
}
id, respPayload, err := UnmarshalTransportMsg(msg)
if err != nil {
t.Fatalf("error: %v", err)
}
if id.String() != peerID.String() {
t.Errorf("expected: '%s', got: '%s'", peerID, id)
}
if string(respPayload) != string(payload) {
t.Errorf("expected %s, got %s", payload, respPayload)
}
}
func TestMarshalHealthcheck(t *testing.T) {
msg := MarshalHealthcheck()
_, err := ValidateVersion(msg)
if err != nil {
t.Fatalf("error: %v", err)
}
msgType, err := DetermineServerMessageType(msg)
if err != nil {
t.Fatalf("error: %v", err)
}
if msgType != MsgTypeHealthCheck {
t.Errorf("expected %d, got %d", MsgTypeHealthCheck, msgType)
}
}

View File

@@ -1,92 +0,0 @@
package messages
import (
"fmt"
)
func MarshalSubPeerStateMsg(ids []PeerID) ([][]byte, error) {
return marshalPeerIDs(ids, byte(MsgTypeSubscribePeerState))
}
func UnmarshalSubPeerStateMsg(buf []byte) ([]PeerID, error) {
return unmarshalPeerIDs(buf)
}
func MarshalUnsubPeerStateMsg(ids []PeerID) ([][]byte, error) {
return marshalPeerIDs(ids, byte(MsgTypeUnsubscribePeerState))
}
func UnmarshalUnsubPeerStateMsg(buf []byte) ([]PeerID, error) {
return unmarshalPeerIDs(buf)
}
func MarshalPeersOnline(ids []PeerID) ([][]byte, error) {
return marshalPeerIDs(ids, byte(MsgTypePeersOnline))
}
func UnmarshalPeersOnlineMsg(buf []byte) ([]PeerID, error) {
return unmarshalPeerIDs(buf)
}
func MarshalPeersWentOffline(ids []PeerID) ([][]byte, error) {
return marshalPeerIDs(ids, byte(MsgTypePeersWentOffline))
}
func UnMarshalPeersWentOffline(buf []byte) ([]PeerID, error) {
return unmarshalPeerIDs(buf)
}
// marshalPeerIDs is a generic function to marshal peer IDs with a specific message type
func marshalPeerIDs(ids []PeerID, msgType byte) ([][]byte, error) {
if len(ids) == 0 {
return nil, fmt.Errorf("no list of peer ids provided")
}
const maxPeersPerMessage = (MaxMessageSize - sizeOfProtoHeader) / peerIDSize
var messages [][]byte
for i := 0; i < len(ids); i += maxPeersPerMessage {
end := i + maxPeersPerMessage
if end > len(ids) {
end = len(ids)
}
chunk := ids[i:end]
totalSize := sizeOfProtoHeader + len(chunk)*peerIDSize
buf := make([]byte, totalSize)
buf[0] = byte(CurrentProtocolVersion)
buf[1] = msgType
offset := sizeOfProtoHeader
for _, id := range chunk {
copy(buf[offset:], id[:])
offset += peerIDSize
}
messages = append(messages, buf)
}
return messages, nil
}
// unmarshalPeerIDs is a generic function to unmarshal peer IDs from a buffer
func unmarshalPeerIDs(buf []byte) ([]PeerID, error) {
if len(buf) < sizeOfProtoHeader {
return nil, fmt.Errorf("invalid message format")
}
if (len(buf)-sizeOfProtoHeader)%peerIDSize != 0 {
return nil, fmt.Errorf("invalid peer list size: %d", len(buf)-sizeOfProtoHeader)
}
numIDs := (len(buf) - sizeOfProtoHeader) / peerIDSize
ids := make([]PeerID, numIDs)
offset := sizeOfProtoHeader
for i := 0; i < numIDs; i++ {
copy(ids[i][:], buf[offset:offset+peerIDSize])
offset += peerIDSize
}
return ids, nil
}

View File

@@ -1,144 +0,0 @@
package messages
import (
"bytes"
"testing"
)
const (
testPeerCount = 10
)
// Helper function to generate test PeerIDs
func generateTestPeerIDs(n int) []PeerID {
ids := make([]PeerID, n)
for i := 0; i < n; i++ {
for j := 0; j < peerIDSize; j++ {
ids[i][j] = byte(i + j)
}
}
return ids
}
// Helper function to compare slices of PeerID
func peerIDEqual(a, b []PeerID) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if !bytes.Equal(a[i][:], b[i][:]) {
return false
}
}
return true
}
func TestMarshalUnmarshalSubPeerState(t *testing.T) {
ids := generateTestPeerIDs(testPeerCount)
msgs, err := MarshalSubPeerStateMsg(ids)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
var allIDs []PeerID
for _, msg := range msgs {
decoded, err := UnmarshalSubPeerStateMsg(msg)
if err != nil {
t.Fatalf("unmarshal failed: %v", err)
}
allIDs = append(allIDs, decoded...)
}
if !peerIDEqual(ids, allIDs) {
t.Errorf("expected %v, got %v", ids, allIDs)
}
}
func TestMarshalSubPeerState_EmptyInput(t *testing.T) {
_, err := MarshalSubPeerStateMsg([]PeerID{})
if err == nil {
t.Errorf("expected error for empty input")
}
}
func TestUnmarshalSubPeerState_Invalid(t *testing.T) {
// Too short
_, err := UnmarshalSubPeerStateMsg([]byte{1})
if err == nil {
t.Errorf("expected error for short input")
}
// Misaligned length
buf := make([]byte, sizeOfProtoHeader+1)
_, err = UnmarshalSubPeerStateMsg(buf)
if err == nil {
t.Errorf("expected error for misaligned input")
}
}
func TestMarshalUnmarshalPeersOnline(t *testing.T) {
ids := generateTestPeerIDs(testPeerCount)
msgs, err := MarshalPeersOnline(ids)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
var allIDs []PeerID
for _, msg := range msgs {
decoded, err := UnmarshalPeersOnlineMsg(msg)
if err != nil {
t.Fatalf("unmarshal failed: %v", err)
}
allIDs = append(allIDs, decoded...)
}
if !peerIDEqual(ids, allIDs) {
t.Errorf("expected %v, got %v", ids, allIDs)
}
}
func TestMarshalPeersOnline_EmptyInput(t *testing.T) {
_, err := MarshalPeersOnline([]PeerID{})
if err == nil {
t.Errorf("expected error for empty input")
}
}
func TestUnmarshalPeersOnline_Invalid(t *testing.T) {
_, err := UnmarshalPeersOnlineMsg([]byte{1})
if err == nil {
t.Errorf("expected error for short input")
}
}
func TestMarshalUnmarshalPeersWentOffline(t *testing.T) {
ids := generateTestPeerIDs(testPeerCount)
msgs, err := MarshalPeersWentOffline(ids)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
var allIDs []PeerID
for _, msg := range msgs {
// MarshalPeersWentOffline shares no unmarshal function, so reuse PeersOnline
decoded, err := UnmarshalPeersOnlineMsg(msg)
if err != nil {
t.Fatalf("unmarshal failed: %v", err)
}
allIDs = append(allIDs, decoded...)
}
if !peerIDEqual(ids, allIDs) {
t.Errorf("expected %v, got %v", ids, allIDs)
}
}
func TestMarshalPeersWentOffline_EmptyInput(t *testing.T) {
_, err := MarshalPeersWentOffline([]PeerID{})
if err == nil {
t.Errorf("expected error for empty input")
}
}

View File

@@ -6,11 +6,11 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/messages"
"github.com/netbirdio/netbird/shared/relay/messages"
//nolint:staticcheck
"github.com/netbirdio/netbird/relay/messages/address"
"github.com/netbirdio/netbird/shared/relay/messages/address"
//nolint:staticcheck
authmsg "github.com/netbirdio/netbird/relay/messages/auth"
authmsg "github.com/netbirdio/netbird/shared/relay/messages/auth"
)
type Validator interface {

View File

@@ -10,10 +10,12 @@ import (
"github.com/coder/websocket"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/shared/relay"
)
// URLPath is the path for the websocket connection.
const URLPath = "/relay"
const URLPath = relay.WebSocketURLPath
type Listener struct {
// Address is the address to listen on.

View File

@@ -9,8 +9,8 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/healthcheck"
"github.com/netbirdio/netbird/relay/messages"
"github.com/netbirdio/netbird/shared/relay/healthcheck"
"github.com/netbirdio/netbird/shared/relay/messages"
"github.com/netbirdio/netbird/relay/metrics"
"github.com/netbirdio/netbird/relay/server/store"
)

View File

@@ -10,7 +10,7 @@ import (
"github.com/netbirdio/netbird/relay/server/listener"
"github.com/netbirdio/netbird/relay/server/listener/quic"
"github.com/netbirdio/netbird/relay/server/listener/ws"
quictls "github.com/netbirdio/netbird/relay/tls"
quictls "github.com/netbirdio/netbird/shared/relay/tls"
log "github.com/sirupsen/logrus"
)

View File

@@ -4,7 +4,7 @@ import (
"context"
"sync"
"github.com/netbirdio/netbird/relay/messages"
"github.com/netbirdio/netbird/shared/relay/messages"
)
type event struct {

View File

@@ -4,7 +4,7 @@ import (
"context"
"sync"
"github.com/netbirdio/netbird/relay/messages"
"github.com/netbirdio/netbird/shared/relay/messages"
)
type PeerNotifier struct {

View File

@@ -3,7 +3,7 @@ package store
import (
"sync"
"github.com/netbirdio/netbird/relay/messages"
"github.com/netbirdio/netbird/shared/relay/messages"
)
type IPeer interface {

View File

@@ -3,7 +3,7 @@ package store
import (
"testing"
"github.com/netbirdio/netbird/relay/messages"
"github.com/netbirdio/netbird/shared/relay/messages"
)
type MocPeer struct {

View File

@@ -1,3 +0,0 @@
package tls
const nbalpn = "nb-quic"

View File

@@ -1,26 +0,0 @@
//go:build devcert
package tls
import (
"crypto/tls"
"crypto/x509"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/util/embeddedroots"
)
func ClientQUICTLSConfig() *tls.Config {
certPool, err := x509.SystemCertPool()
if err != nil || certPool == nil {
log.Debugf("System cert pool not available; falling back to embedded cert, error: %v", err)
certPool = embeddedroots.Get()
}
return &tls.Config{
InsecureSkipVerify: true, // Debug mode allows insecure connections
NextProtos: []string{nbalpn}, // Ensure this matches the server's ALPN
RootCAs: certPool,
}
}

View File

@@ -1,25 +0,0 @@
//go:build !devcert
package tls
import (
"crypto/tls"
"crypto/x509"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/util/embeddedroots"
)
func ClientQUICTLSConfig() *tls.Config {
certPool, err := x509.SystemCertPool()
if err != nil || certPool == nil {
log.Debugf("System cert pool not available; falling back to embedded cert, error: %v", err)
certPool = embeddedroots.Get()
}
return &tls.Config{
NextProtos: []string{nbalpn},
RootCAs: certPool,
}
}

View File

@@ -1,36 +0,0 @@
// Package tls provides utilities for configuring and managing Transport Layer
// Security (TLS) in server and client environments, with a focus on QUIC
// protocol support and testing configurations.
//
// The package includes functions for cloning and customizing TLS
// configurations as well as generating self-signed certificates for
// development and testing purposes.
//
// Key Features:
//
// - `ServerQUICTLSConfig`: Creates a server-side TLS configuration tailored
// for QUIC protocol with specified or default settings. QUIC requires a
// specific TLS configuration with proper ALPN (Application-Layer Protocol
// Negotiation) support, making the TLS settings crucial for establishing
// secure connections.
//
// - `ClientQUICTLSConfig`: Provides a client-side TLS configuration suitable
// for QUIC protocol. The configuration differs between development
// (insecure testing) and production (strict verification).
//
// - `generateTestTLSConfig`: Generates a self-signed TLS configuration for
// use in local development and testing scenarios.
//
// Usage:
//
// This package provides separate implementations for development and production
// environments. The development implementation (guarded by `//go:build devcert`)
// supports testing configurations with self-signed certificates and insecure
// client connections. The production implementation (guarded by `//go:build
// !devcert`) ensures that valid and secure TLS configurations are supplied
// and used.
//
// The QUIC protocol is highly reliant on properly configured TLS settings,
// and this package ensures that configurations meet the requirements for
// secure and efficient QUIC communication.
package tls

View File

@@ -1,79 +0,0 @@
//go:build devcert
package tls
import (
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"math/big"
"net"
"time"
log "github.com/sirupsen/logrus"
)
func ServerQUICTLSConfig(originTLSCfg *tls.Config) (*tls.Config, error) {
if originTLSCfg == nil {
log.Warnf("QUIC server will use self signed certificate for testing!")
return generateTestTLSConfig()
}
cfg := originTLSCfg.Clone()
cfg.NextProtos = []string{nbalpn}
return cfg, nil
}
// GenerateTestTLSConfig creates a self-signed certificate for testing
func generateTestTLSConfig() (*tls.Config, error) {
log.Infof("generating test TLS config")
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return nil, err
}
template := x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
Organization: []string{"Test Organization"},
},
NotBefore: time.Now(),
NotAfter: time.Now().Add(time.Hour * 24 * 180), // Valid for 180 days
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{
x509.ExtKeyUsageServerAuth,
},
BasicConstraintsValid: true,
DNSNames: []string{"localhost"},
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
}
// Create certificate
certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
if err != nil {
return nil, err
}
certPEM := pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: certDER,
})
privateKeyPEM := pem.EncodeToMemory(&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
})
tlsCert, err := tls.X509KeyPair(certPEM, privateKeyPEM)
if err != nil {
return nil, err
}
return &tls.Config{
Certificates: []tls.Certificate{tlsCert},
NextProtos: []string{nbalpn},
}, nil
}

View File

@@ -1,17 +0,0 @@
//go:build !devcert
package tls
import (
"crypto/tls"
"fmt"
)
func ServerQUICTLSConfig(originTLSCfg *tls.Config) (*tls.Config, error) {
if originTLSCfg == nil {
return nil, fmt.Errorf("valid TLS config is required for QUIC listener")
}
cfg := originTLSCfg.Clone()
cfg.NextProtos = []string{nbalpn}
return cfg, nil
}