mirror of
https://github.com/netbirdio/netbird.git
synced 2026-03-31 06:34:14 -04:00
[misc] Separate shared code dependencies (#4288)
* Separate shared code dependencies * Fix import * Test respective shared code * Update openapi ref * Fix test * Fix test path
This commit is contained in:
@@ -1,17 +0,0 @@
|
||||
/*
|
||||
The `healthcheck` package is responsible for managing the health checks between the client and the relay server. It
|
||||
ensures that the connection between the client and the server are alive and functioning properly.
|
||||
|
||||
The `Sender` struct is responsible for sending health check signals to the receiver. The receiver listens for these
|
||||
signals and sends a new signal back to the sender to acknowledge that the signal has been received. If the sender does
|
||||
not receive an acknowledgment signal within a certain time frame, it will send a timeout signal via timeout channel
|
||||
and stop working.
|
||||
|
||||
The `Receiver` struct is responsible for receiving the health check signals from the sender. If the receiver does not
|
||||
receive a signal within a certain time frame, it will send a timeout signal via the OnTimeout channel and stop working.
|
||||
|
||||
In the Relay usage the signal is sent to the peer in message type Healthcheck. In case of timeout the connection is
|
||||
closed and the peer is removed from the relay.
|
||||
*/
|
||||
|
||||
package healthcheck
|
||||
@@ -1,94 +0,0 @@
|
||||
package healthcheck
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var (
|
||||
heartbeatTimeout = healthCheckInterval + 10*time.Second
|
||||
)
|
||||
|
||||
// Receiver is a healthcheck receiver
|
||||
// It will listen for heartbeat and check if the heartbeat is not received in a certain time
|
||||
// If the heartbeat is not received in a certain time, it will send a timeout signal and stop to work
|
||||
// The heartbeat timeout is a bit longer than the sender's healthcheck interval
|
||||
type Receiver struct {
|
||||
OnTimeout chan struct{}
|
||||
log *log.Entry
|
||||
ctx context.Context
|
||||
ctxCancel context.CancelFunc
|
||||
heartbeat chan struct{}
|
||||
alive bool
|
||||
attemptThreshold int
|
||||
}
|
||||
|
||||
// NewReceiver creates a new healthcheck receiver and start the timer in the background
|
||||
func NewReceiver(log *log.Entry) *Receiver {
|
||||
ctx, ctxCancel := context.WithCancel(context.Background())
|
||||
|
||||
r := &Receiver{
|
||||
OnTimeout: make(chan struct{}, 1),
|
||||
log: log,
|
||||
ctx: ctx,
|
||||
ctxCancel: ctxCancel,
|
||||
heartbeat: make(chan struct{}, 1),
|
||||
attemptThreshold: getAttemptThresholdFromEnv(),
|
||||
}
|
||||
|
||||
go r.waitForHealthcheck()
|
||||
return r
|
||||
}
|
||||
|
||||
// Heartbeat acknowledge the heartbeat has been received
|
||||
func (r *Receiver) Heartbeat() {
|
||||
select {
|
||||
case r.heartbeat <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// Stop check the timeout and do not send new notifications
|
||||
func (r *Receiver) Stop() {
|
||||
r.ctxCancel()
|
||||
}
|
||||
|
||||
func (r *Receiver) waitForHealthcheck() {
|
||||
ticker := time.NewTicker(heartbeatTimeout)
|
||||
defer ticker.Stop()
|
||||
defer r.ctxCancel()
|
||||
defer close(r.OnTimeout)
|
||||
|
||||
failureCounter := 0
|
||||
for {
|
||||
select {
|
||||
case <-r.heartbeat:
|
||||
r.alive = true
|
||||
failureCounter = 0
|
||||
case <-ticker.C:
|
||||
if r.alive {
|
||||
r.alive = false
|
||||
continue
|
||||
}
|
||||
|
||||
failureCounter++
|
||||
if failureCounter < r.attemptThreshold {
|
||||
r.log.Warnf("healthcheck failed, attempt %d", failureCounter)
|
||||
continue
|
||||
}
|
||||
r.notifyTimeout()
|
||||
return
|
||||
case <-r.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Receiver) notifyTimeout() {
|
||||
select {
|
||||
case r.OnTimeout <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
@@ -1,140 +0,0 @@
|
||||
package healthcheck
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// Mutex to protect global variable access in tests
|
||||
var testMutex sync.Mutex
|
||||
|
||||
func TestNewReceiver(t *testing.T) {
|
||||
testMutex.Lock()
|
||||
originalTimeout := heartbeatTimeout
|
||||
heartbeatTimeout = 5 * time.Second
|
||||
testMutex.Unlock()
|
||||
|
||||
defer func() {
|
||||
testMutex.Lock()
|
||||
heartbeatTimeout = originalTimeout
|
||||
testMutex.Unlock()
|
||||
}()
|
||||
|
||||
r := NewReceiver(log.WithContext(context.Background()))
|
||||
defer r.Stop()
|
||||
|
||||
select {
|
||||
case <-r.OnTimeout:
|
||||
t.Error("unexpected timeout")
|
||||
case <-time.After(1 * time.Second):
|
||||
// Test passes if no timeout received
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewReceiverNotReceive(t *testing.T) {
|
||||
testMutex.Lock()
|
||||
originalTimeout := heartbeatTimeout
|
||||
heartbeatTimeout = 1 * time.Second
|
||||
testMutex.Unlock()
|
||||
|
||||
defer func() {
|
||||
testMutex.Lock()
|
||||
heartbeatTimeout = originalTimeout
|
||||
testMutex.Unlock()
|
||||
}()
|
||||
|
||||
r := NewReceiver(log.WithContext(context.Background()))
|
||||
defer r.Stop()
|
||||
|
||||
select {
|
||||
case <-r.OnTimeout:
|
||||
// Test passes if timeout is received
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Error("timeout not received")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewReceiverAck(t *testing.T) {
|
||||
testMutex.Lock()
|
||||
originalTimeout := heartbeatTimeout
|
||||
heartbeatTimeout = 2 * time.Second
|
||||
testMutex.Unlock()
|
||||
|
||||
defer func() {
|
||||
testMutex.Lock()
|
||||
heartbeatTimeout = originalTimeout
|
||||
testMutex.Unlock()
|
||||
}()
|
||||
|
||||
r := NewReceiver(log.WithContext(context.Background()))
|
||||
defer r.Stop()
|
||||
|
||||
r.Heartbeat()
|
||||
|
||||
select {
|
||||
case <-r.OnTimeout:
|
||||
t.Error("unexpected timeout")
|
||||
case <-time.After(3 * time.Second):
|
||||
}
|
||||
}
|
||||
|
||||
func TestReceiverHealthCheckAttemptThreshold(t *testing.T) {
|
||||
testsCases := []struct {
|
||||
name string
|
||||
threshold int
|
||||
resetCounterOnce bool
|
||||
}{
|
||||
{"Default attempt threshold", defaultAttemptThreshold, false},
|
||||
{"Custom attempt threshold", 3, false},
|
||||
{"Should reset threshold once", 2, true},
|
||||
}
|
||||
|
||||
for _, tc := range testsCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
testMutex.Lock()
|
||||
originalInterval := healthCheckInterval
|
||||
originalTimeout := heartbeatTimeout
|
||||
healthCheckInterval = 1 * time.Second
|
||||
heartbeatTimeout = healthCheckInterval + 500*time.Millisecond
|
||||
testMutex.Unlock()
|
||||
|
||||
defer func() {
|
||||
testMutex.Lock()
|
||||
healthCheckInterval = originalInterval
|
||||
heartbeatTimeout = originalTimeout
|
||||
testMutex.Unlock()
|
||||
}()
|
||||
//nolint:tenv
|
||||
os.Setenv(defaultAttemptThresholdEnv, fmt.Sprintf("%d", tc.threshold))
|
||||
defer os.Unsetenv(defaultAttemptThresholdEnv)
|
||||
|
||||
receiver := NewReceiver(log.WithField("test_name", tc.name))
|
||||
|
||||
testTimeout := heartbeatTimeout*time.Duration(tc.threshold) + healthCheckInterval
|
||||
|
||||
if tc.resetCounterOnce {
|
||||
receiver.Heartbeat()
|
||||
t.Logf("reset counter once")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-receiver.OnTimeout:
|
||||
if tc.resetCounterOnce {
|
||||
t.Fatalf("should not have timed out before %s", testTimeout)
|
||||
}
|
||||
case <-time.After(testTimeout):
|
||||
if tc.resetCounterOnce {
|
||||
return
|
||||
}
|
||||
t.Fatalf("should have timed out before %s", testTimeout)
|
||||
}
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,110 +0,0 @@
|
||||
package healthcheck
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultAttemptThreshold = 1
|
||||
defaultAttemptThresholdEnv = "NB_RELAY_HC_ATTEMPT_THRESHOLD"
|
||||
)
|
||||
|
||||
var (
|
||||
healthCheckInterval = 25 * time.Second
|
||||
healthCheckTimeout = 20 * time.Second
|
||||
)
|
||||
|
||||
// Sender is a healthcheck sender
|
||||
// It will send healthcheck signal to the receiver
|
||||
// If the receiver does not receive the signal in a certain time, it will send a timeout signal and stop to work
|
||||
// It will also stop if the context is canceled
|
||||
type Sender struct {
|
||||
log *log.Entry
|
||||
// HealthCheck is a channel to send health check signal to the peer
|
||||
HealthCheck chan struct{}
|
||||
// Timeout is a channel to the health check signal is not received in a certain time
|
||||
Timeout chan struct{}
|
||||
|
||||
ack chan struct{}
|
||||
alive bool
|
||||
attemptThreshold int
|
||||
}
|
||||
|
||||
// NewSender creates a new healthcheck sender
|
||||
func NewSender(log *log.Entry) *Sender {
|
||||
hc := &Sender{
|
||||
log: log,
|
||||
HealthCheck: make(chan struct{}, 1),
|
||||
Timeout: make(chan struct{}, 1),
|
||||
ack: make(chan struct{}, 1),
|
||||
attemptThreshold: getAttemptThresholdFromEnv(),
|
||||
}
|
||||
|
||||
return hc
|
||||
}
|
||||
|
||||
// OnHCResponse sends an acknowledgment signal to the sender
|
||||
func (hc *Sender) OnHCResponse() {
|
||||
select {
|
||||
case hc.ack <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (hc *Sender) StartHealthCheck(ctx context.Context) {
|
||||
ticker := time.NewTicker(healthCheckInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
timeoutTicker := time.NewTicker(hc.getTimeoutTime())
|
||||
defer timeoutTicker.Stop()
|
||||
|
||||
defer close(hc.HealthCheck)
|
||||
defer close(hc.Timeout)
|
||||
|
||||
failureCounter := 0
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
hc.HealthCheck <- struct{}{}
|
||||
case <-timeoutTicker.C:
|
||||
if hc.alive {
|
||||
hc.alive = false
|
||||
continue
|
||||
}
|
||||
|
||||
failureCounter++
|
||||
if failureCounter < hc.attemptThreshold {
|
||||
hc.log.Warnf("Health check failed attempt %d.", failureCounter)
|
||||
continue
|
||||
}
|
||||
hc.Timeout <- struct{}{}
|
||||
return
|
||||
case <-hc.ack:
|
||||
failureCounter = 0
|
||||
hc.alive = true
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (hc *Sender) getTimeoutTime() time.Duration {
|
||||
return healthCheckInterval + healthCheckTimeout
|
||||
}
|
||||
|
||||
func getAttemptThresholdFromEnv() int {
|
||||
if attemptThreshold := os.Getenv(defaultAttemptThresholdEnv); attemptThreshold != "" {
|
||||
threshold, err := strconv.ParseInt(attemptThreshold, 10, 64)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to parse attempt threshold from environment variable \"%s\" should be an integer. Using default value", attemptThreshold)
|
||||
return defaultAttemptThreshold
|
||||
}
|
||||
return int(threshold)
|
||||
}
|
||||
return defaultAttemptThreshold
|
||||
}
|
||||
@@ -1,213 +0,0 @@
|
||||
package healthcheck
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
// override the health check interval to speed up the test
|
||||
healthCheckInterval = 2 * time.Second
|
||||
healthCheckTimeout = 100 * time.Millisecond
|
||||
code := m.Run()
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
func TestNewHealthPeriod(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
hc := NewSender(log.WithContext(ctx))
|
||||
go hc.StartHealthCheck(ctx)
|
||||
|
||||
iterations := 0
|
||||
for i := 0; i < 3; i++ {
|
||||
select {
|
||||
case <-hc.HealthCheck:
|
||||
iterations++
|
||||
hc.OnHCResponse()
|
||||
case <-hc.Timeout:
|
||||
t.Fatalf("health check is timed out")
|
||||
case <-time.After(healthCheckInterval + 100*time.Millisecond):
|
||||
t.Fatalf("health check not received")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewHealthFailed(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
hc := NewSender(log.WithContext(ctx))
|
||||
go hc.StartHealthCheck(ctx)
|
||||
|
||||
select {
|
||||
case <-hc.Timeout:
|
||||
case <-time.After(healthCheckInterval + healthCheckTimeout + 100*time.Millisecond):
|
||||
t.Fatalf("health check is not timed out")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewHealthcheckStop(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
hc := NewSender(log.WithContext(ctx))
|
||||
go hc.StartHealthCheck(ctx)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
select {
|
||||
case _, ok := <-hc.HealthCheck:
|
||||
if ok {
|
||||
t.Fatalf("health check on received")
|
||||
}
|
||||
case _, ok := <-hc.Timeout:
|
||||
if ok {
|
||||
t.Fatalf("health check on received")
|
||||
}
|
||||
case <-ctx.Done():
|
||||
// expected
|
||||
case <-time.After(10 * time.Second):
|
||||
t.Fatalf("is not exited")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTimeoutReset(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
hc := NewSender(log.WithContext(ctx))
|
||||
go hc.StartHealthCheck(ctx)
|
||||
|
||||
iterations := 0
|
||||
for i := 0; i < 3; i++ {
|
||||
select {
|
||||
case <-hc.HealthCheck:
|
||||
iterations++
|
||||
hc.OnHCResponse()
|
||||
case <-hc.Timeout:
|
||||
t.Fatalf("health check is timed out")
|
||||
case <-time.After(healthCheckInterval + 100*time.Millisecond):
|
||||
t.Fatalf("health check not received")
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case <-hc.HealthCheck:
|
||||
case <-hc.Timeout:
|
||||
// expected
|
||||
case <-ctx.Done():
|
||||
t.Fatalf("context is done")
|
||||
case <-time.After(10 * time.Second):
|
||||
t.Fatalf("is not exited")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSenderHealthCheckAttemptThreshold(t *testing.T) {
|
||||
testsCases := []struct {
|
||||
name string
|
||||
threshold int
|
||||
resetCounterOnce bool
|
||||
}{
|
||||
{"Default attempt threshold", defaultAttemptThreshold, false},
|
||||
{"Custom attempt threshold", 3, false},
|
||||
{"Should reset threshold once", 2, true},
|
||||
}
|
||||
|
||||
for _, tc := range testsCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
originalInterval := healthCheckInterval
|
||||
originalTimeout := healthCheckTimeout
|
||||
healthCheckInterval = 1 * time.Second
|
||||
healthCheckTimeout = 500 * time.Millisecond
|
||||
|
||||
//nolint:tenv
|
||||
os.Setenv(defaultAttemptThresholdEnv, fmt.Sprintf("%d", tc.threshold))
|
||||
defer os.Unsetenv(defaultAttemptThresholdEnv)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
sender := NewSender(log.WithField("test_name", tc.name))
|
||||
senderExit := make(chan struct{})
|
||||
go func() {
|
||||
sender.StartHealthCheck(ctx)
|
||||
close(senderExit)
|
||||
}()
|
||||
|
||||
go func() {
|
||||
responded := false
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case _, ok := <-sender.HealthCheck:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if tc.resetCounterOnce && !responded {
|
||||
responded = true
|
||||
sender.OnHCResponse()
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
testTimeout := sender.getTimeoutTime()*time.Duration(tc.threshold) + healthCheckInterval
|
||||
|
||||
select {
|
||||
case <-sender.Timeout:
|
||||
if tc.resetCounterOnce {
|
||||
t.Errorf("should not have timed out before %s", testTimeout)
|
||||
}
|
||||
case <-time.After(testTimeout):
|
||||
if tc.resetCounterOnce {
|
||||
return
|
||||
}
|
||||
t.Errorf("should have timed out before %s", testTimeout)
|
||||
}
|
||||
|
||||
cancel()
|
||||
select {
|
||||
case <-senderExit:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatalf("sender did not exit in time")
|
||||
}
|
||||
healthCheckInterval = originalInterval
|
||||
healthCheckTimeout = originalTimeout
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
//nolint:tenv
|
||||
func TestGetAttemptThresholdFromEnv(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
envValue string
|
||||
expected int
|
||||
}{
|
||||
{"Default attempt threshold when env is not set", "", defaultAttemptThreshold},
|
||||
{"Custom attempt threshold when env is set to a valid integer", "3", 3},
|
||||
{"Default attempt threshold when env is set to an invalid value", "invalid", defaultAttemptThreshold},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.envValue == "" {
|
||||
os.Unsetenv(defaultAttemptThresholdEnv)
|
||||
} else {
|
||||
os.Setenv(defaultAttemptThresholdEnv, tt.envValue)
|
||||
}
|
||||
|
||||
result := getAttemptThresholdFromEnv()
|
||||
if result != tt.expected {
|
||||
t.Fatalf("Expected %d, got %d", tt.expected, result)
|
||||
}
|
||||
|
||||
os.Unsetenv(defaultAttemptThresholdEnv)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,21 +0,0 @@
|
||||
// Deprecated: This package is deprecated and will be removed in a future release.
|
||||
package address
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/gob"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type Address struct {
|
||||
URL string
|
||||
}
|
||||
|
||||
func (addr *Address) Marshal() ([]byte, error) {
|
||||
var buf bytes.Buffer
|
||||
enc := gob.NewEncoder(&buf)
|
||||
if err := enc.Encode(addr); err != nil {
|
||||
return nil, fmt.Errorf("encode Address: %w", err)
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
@@ -1,43 +0,0 @@
|
||||
// Deprecated: This package is deprecated and will be removed in a future release.
|
||||
package auth
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/gob"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type Algorithm int
|
||||
|
||||
const (
|
||||
AlgoUnknown Algorithm = iota
|
||||
AlgoHMACSHA256
|
||||
AlgoHMACSHA512
|
||||
)
|
||||
|
||||
func (a Algorithm) String() string {
|
||||
switch a {
|
||||
case AlgoHMACSHA256:
|
||||
return "HMAC-SHA256"
|
||||
case AlgoHMACSHA512:
|
||||
return "HMAC-SHA512"
|
||||
default:
|
||||
return "Unknown"
|
||||
}
|
||||
}
|
||||
|
||||
type Msg struct {
|
||||
AuthAlgorithm Algorithm
|
||||
AdditionalData []byte
|
||||
}
|
||||
|
||||
func UnmarshalMsg(data []byte) (*Msg, error) {
|
||||
var msg *Msg
|
||||
|
||||
buf := bytes.NewBuffer(data)
|
||||
dec := gob.NewDecoder(buf)
|
||||
if err := dec.Decode(&msg); err != nil {
|
||||
return nil, fmt.Errorf("decode Msg: %w", err)
|
||||
}
|
||||
return msg, nil
|
||||
}
|
||||
@@ -1,5 +0,0 @@
|
||||
/*
|
||||
Package messages provides the message types that are used to communicate between the relay and the client.
|
||||
This package is used to determine the type of message that is being sent and received between the relay and the client.
|
||||
*/
|
||||
package messages
|
||||
@@ -1,31 +0,0 @@
|
||||
package messages
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
const (
|
||||
prefixLength = 4
|
||||
peerIDSize = prefixLength + sha256.Size
|
||||
)
|
||||
|
||||
var (
|
||||
prefix = []byte("sha-") // 4 bytes
|
||||
)
|
||||
|
||||
type PeerID [peerIDSize]byte
|
||||
|
||||
func (p PeerID) String() string {
|
||||
return fmt.Sprintf("%s%s", p[:prefixLength], base64.StdEncoding.EncodeToString(p[prefixLength:]))
|
||||
}
|
||||
|
||||
// HashID generates a sha256 hash from the peerID and returns the hash and the human-readable string
|
||||
func HashID(peerID string) PeerID {
|
||||
idHash := sha256.Sum256([]byte(peerID))
|
||||
var prefixedHash [peerIDSize]byte
|
||||
copy(prefixedHash[:prefixLength], prefix)
|
||||
copy(prefixedHash[prefixLength:], idHash[:])
|
||||
return prefixedHash
|
||||
}
|
||||
@@ -1,337 +0,0 @@
|
||||
package messages
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
const (
|
||||
MaxHandshakeSize = 212
|
||||
MaxHandshakeRespSize = 8192
|
||||
MaxMessageSize = 8820
|
||||
|
||||
CurrentProtocolVersion = 1
|
||||
|
||||
MsgTypeUnknown MsgType = 0
|
||||
// Deprecated: Use MsgTypeAuth instead.
|
||||
MsgTypeHello = 1
|
||||
// Deprecated: Use MsgTypeAuthResponse instead.
|
||||
MsgTypeHelloResponse = 2
|
||||
MsgTypeTransport = 3
|
||||
MsgTypeClose = 4
|
||||
MsgTypeHealthCheck = 5
|
||||
MsgTypeAuth = 6
|
||||
MsgTypeAuthResponse = 7
|
||||
|
||||
// Peers state messages
|
||||
MsgTypeSubscribePeerState = 8
|
||||
MsgTypeUnsubscribePeerState = 9
|
||||
MsgTypePeersOnline = 10
|
||||
MsgTypePeersWentOffline = 11
|
||||
|
||||
// base size of the message
|
||||
sizeOfVersionByte = 1
|
||||
sizeOfMsgType = 1
|
||||
sizeOfProtoHeader = sizeOfVersionByte + sizeOfMsgType
|
||||
|
||||
// auth message
|
||||
sizeOfMagicByte = 4
|
||||
headerSizeAuth = sizeOfMagicByte + peerIDSize
|
||||
offsetMagicByte = sizeOfProtoHeader
|
||||
offsetAuthPeerID = sizeOfProtoHeader + sizeOfMagicByte
|
||||
headerTotalSizeAuth = sizeOfProtoHeader + headerSizeAuth
|
||||
|
||||
// hello message
|
||||
headerSizeHello = sizeOfMagicByte + peerIDSize
|
||||
headerSizeHelloResp = 0
|
||||
|
||||
// transport
|
||||
headerSizeTransport = peerIDSize
|
||||
offsetTransportID = sizeOfProtoHeader
|
||||
headerTotalSizeTransport = sizeOfProtoHeader + headerSizeTransport
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidMessageLength = errors.New("invalid message length")
|
||||
ErrUnsupportedVersion = errors.New("unsupported version")
|
||||
|
||||
magicHeader = []byte{0x21, 0x12, 0xA4, 0x42}
|
||||
|
||||
healthCheckMsg = []byte{byte(CurrentProtocolVersion), byte(MsgTypeHealthCheck)}
|
||||
)
|
||||
|
||||
type MsgType byte
|
||||
|
||||
func (m MsgType) String() string {
|
||||
switch m {
|
||||
case MsgTypeHello:
|
||||
return "hello"
|
||||
case MsgTypeHelloResponse:
|
||||
return "hello response"
|
||||
case MsgTypeAuth:
|
||||
return "auth"
|
||||
case MsgTypeAuthResponse:
|
||||
return "auth response"
|
||||
case MsgTypeTransport:
|
||||
return "transport"
|
||||
case MsgTypeClose:
|
||||
return "close"
|
||||
case MsgTypeHealthCheck:
|
||||
return "health check"
|
||||
case MsgTypeSubscribePeerState:
|
||||
return "subscribe peer state"
|
||||
case MsgTypeUnsubscribePeerState:
|
||||
return "unsubscribe peer state"
|
||||
case MsgTypePeersOnline:
|
||||
return "peers online"
|
||||
case MsgTypePeersWentOffline:
|
||||
return "peers went offline"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateVersion checks if the given version is supported by the protocol
|
||||
func ValidateVersion(msg []byte) (int, error) {
|
||||
if len(msg) < sizeOfProtoHeader {
|
||||
return 0, ErrInvalidMessageLength
|
||||
}
|
||||
version := int(msg[0])
|
||||
if version != CurrentProtocolVersion {
|
||||
return 0, fmt.Errorf("%d: %w", version, ErrUnsupportedVersion)
|
||||
}
|
||||
return version, nil
|
||||
}
|
||||
|
||||
// DetermineClientMessageType determines the message type from the first the message
|
||||
func DetermineClientMessageType(msg []byte) (MsgType, error) {
|
||||
if len(msg) < sizeOfProtoHeader {
|
||||
return 0, ErrInvalidMessageLength
|
||||
}
|
||||
|
||||
msgType := MsgType(msg[1])
|
||||
switch msgType {
|
||||
case
|
||||
MsgTypeHello,
|
||||
MsgTypeAuth,
|
||||
MsgTypeTransport,
|
||||
MsgTypeClose,
|
||||
MsgTypeHealthCheck,
|
||||
MsgTypeSubscribePeerState,
|
||||
MsgTypeUnsubscribePeerState:
|
||||
return msgType, nil
|
||||
default:
|
||||
return MsgTypeUnknown, fmt.Errorf("invalid msg type %d", msgType)
|
||||
}
|
||||
}
|
||||
|
||||
// DetermineServerMessageType determines the message type from the first the message
|
||||
func DetermineServerMessageType(msg []byte) (MsgType, error) {
|
||||
if len(msg) < sizeOfProtoHeader {
|
||||
return 0, ErrInvalidMessageLength
|
||||
}
|
||||
|
||||
msgType := MsgType(msg[1])
|
||||
switch msgType {
|
||||
case
|
||||
MsgTypeHelloResponse,
|
||||
MsgTypeAuthResponse,
|
||||
MsgTypeTransport,
|
||||
MsgTypeClose,
|
||||
MsgTypeHealthCheck,
|
||||
MsgTypePeersOnline,
|
||||
MsgTypePeersWentOffline:
|
||||
return msgType, nil
|
||||
default:
|
||||
return MsgTypeUnknown, fmt.Errorf("invalid msg type %d", msgType)
|
||||
}
|
||||
}
|
||||
|
||||
// Deprecated: Use MarshalAuthMsg instead.
|
||||
// MarshalHelloMsg initial hello message
|
||||
// The Hello message is the first message sent by a client after establishing a connection with the Relay server. This
|
||||
// message is used to authenticate the client with the server. The authentication is done using an HMAC method.
|
||||
// The protocol does not limit to use HMAC, it can be any other method. If the authentication failed the server will
|
||||
// close the network connection without any response.
|
||||
func MarshalHelloMsg(peerID PeerID, additions []byte) ([]byte, error) {
|
||||
msg := make([]byte, sizeOfProtoHeader+sizeOfMagicByte, sizeOfProtoHeader+headerSizeHello+len(additions))
|
||||
|
||||
msg[0] = byte(CurrentProtocolVersion)
|
||||
msg[1] = byte(MsgTypeHello)
|
||||
|
||||
copy(msg[sizeOfProtoHeader:sizeOfProtoHeader+sizeOfMagicByte], magicHeader)
|
||||
|
||||
msg = append(msg, peerID[:]...)
|
||||
msg = append(msg, additions...)
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// Deprecated: Use UnmarshalAuthMsg instead.
|
||||
// UnmarshalHelloMsg extracts peerID and the additional data from the hello message. The Additional data is used to
|
||||
// authenticate the client with the server.
|
||||
func UnmarshalHelloMsg(msg []byte) (*PeerID, []byte, error) {
|
||||
if len(msg) < sizeOfProtoHeader+headerSizeHello {
|
||||
return nil, nil, ErrInvalidMessageLength
|
||||
}
|
||||
if !bytes.Equal(msg[sizeOfProtoHeader:sizeOfProtoHeader+sizeOfMagicByte], magicHeader) {
|
||||
return nil, nil, errors.New("invalid magic header")
|
||||
}
|
||||
|
||||
peerID := PeerID(msg[sizeOfProtoHeader+sizeOfMagicByte : sizeOfProtoHeader+headerSizeHello])
|
||||
|
||||
return &peerID, msg[headerSizeHello:], nil
|
||||
}
|
||||
|
||||
// Deprecated: Use MarshalAuthResponse instead.
|
||||
// MarshalHelloResponse creates a response message to the hello message.
|
||||
// In case of success connection the server response with a Hello Response message. This message contains the server's
|
||||
// instance URL. This URL will be used by choose the common Relay server in case if the peers are in different Relay
|
||||
// servers.
|
||||
func MarshalHelloResponse(additionalData []byte) ([]byte, error) {
|
||||
msg := make([]byte, sizeOfProtoHeader, sizeOfProtoHeader+headerSizeHelloResp+len(additionalData))
|
||||
|
||||
msg[0] = byte(CurrentProtocolVersion)
|
||||
msg[1] = byte(MsgTypeHelloResponse)
|
||||
|
||||
msg = append(msg, additionalData...)
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// Deprecated: Use UnmarshalAuthResponse instead.
|
||||
// UnmarshalHelloResponse extracts the additional data from the hello response message.
|
||||
func UnmarshalHelloResponse(msg []byte) ([]byte, error) {
|
||||
if len(msg) < sizeOfProtoHeader+headerSizeHelloResp {
|
||||
return nil, ErrInvalidMessageLength
|
||||
}
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// MarshalAuthMsg initial authentication message
|
||||
// The Auth message is the first message sent by a client after establishing a connection with the Relay server. This
|
||||
// message is used to authenticate the client with the server. The authentication is done using an HMAC method.
|
||||
// The protocol does not limit to use HMAC, it can be any other method. If the authentication failed the server will
|
||||
// close the network connection without any response.
|
||||
func MarshalAuthMsg(peerID PeerID, authPayload []byte) ([]byte, error) {
|
||||
if headerTotalSizeAuth+len(authPayload) > MaxHandshakeSize {
|
||||
return nil, fmt.Errorf("too large auth payload")
|
||||
}
|
||||
|
||||
msg := make([]byte, headerTotalSizeAuth+len(authPayload))
|
||||
msg[0] = byte(CurrentProtocolVersion)
|
||||
msg[1] = byte(MsgTypeAuth)
|
||||
copy(msg[sizeOfProtoHeader:], magicHeader)
|
||||
copy(msg[offsetAuthPeerID:], peerID[:])
|
||||
copy(msg[headerTotalSizeAuth:], authPayload)
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// UnmarshalAuthMsg extracts peerID and the auth payload from the message
|
||||
func UnmarshalAuthMsg(msg []byte) (*PeerID, []byte, error) {
|
||||
if len(msg) < headerTotalSizeAuth {
|
||||
return nil, nil, ErrInvalidMessageLength
|
||||
}
|
||||
|
||||
// Validate the magic header
|
||||
if !bytes.Equal(msg[offsetMagicByte:offsetMagicByte+sizeOfMagicByte], magicHeader) {
|
||||
return nil, nil, errors.New("invalid magic header")
|
||||
}
|
||||
|
||||
peerID := PeerID(msg[offsetAuthPeerID:headerTotalSizeAuth])
|
||||
return &peerID, msg[headerTotalSizeAuth:], nil
|
||||
}
|
||||
|
||||
// MarshalAuthResponse creates a response message to the auth.
|
||||
// In case of success connection the server response with a AuthResponse message. This message contains the server's
|
||||
// instance URL. This URL will be used by choose the common Relay server in case if the peers are in different Relay
|
||||
// servers.
|
||||
func MarshalAuthResponse(address string) ([]byte, error) {
|
||||
ab := []byte(address)
|
||||
msg := make([]byte, sizeOfProtoHeader, sizeOfProtoHeader+len(ab))
|
||||
|
||||
msg[0] = byte(CurrentProtocolVersion)
|
||||
msg[1] = byte(MsgTypeAuthResponse)
|
||||
|
||||
msg = append(msg, ab...)
|
||||
|
||||
if len(msg) > MaxHandshakeRespSize {
|
||||
return nil, fmt.Errorf("invalid message length: %d", len(msg))
|
||||
}
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// UnmarshalAuthResponse it is a confirmation message to auth success
|
||||
func UnmarshalAuthResponse(msg []byte) (string, error) {
|
||||
if len(msg) < sizeOfProtoHeader+1 {
|
||||
return "", ErrInvalidMessageLength
|
||||
}
|
||||
return string(msg[sizeOfProtoHeader:]), nil
|
||||
}
|
||||
|
||||
// MarshalCloseMsg creates a close message.
|
||||
// The close message is used to close the connection gracefully between the client and the server. The server and the
|
||||
// client can send this message. After receiving this message, the server or client will close the connection.
|
||||
func MarshalCloseMsg() []byte {
|
||||
return []byte{
|
||||
byte(CurrentProtocolVersion),
|
||||
byte(MsgTypeClose),
|
||||
}
|
||||
}
|
||||
|
||||
// MarshalTransportMsg creates a transport message.
|
||||
// The transport message is used to exchange data between peers. The message contains the data to be exchanged and the
|
||||
// destination peer hashed ID.
|
||||
func MarshalTransportMsg(peerID PeerID, payload []byte) ([]byte, error) {
|
||||
// todo validate size
|
||||
msg := make([]byte, headerTotalSizeTransport+len(payload))
|
||||
msg[0] = byte(CurrentProtocolVersion)
|
||||
msg[1] = byte(MsgTypeTransport)
|
||||
copy(msg[sizeOfProtoHeader:], peerID[:])
|
||||
copy(msg[sizeOfProtoHeader+peerIDSize:], payload)
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// UnmarshalTransportMsg extracts the peerID and the payload from the transport message.
|
||||
func UnmarshalTransportMsg(buf []byte) (*PeerID, []byte, error) {
|
||||
if len(buf) < headerTotalSizeTransport {
|
||||
return nil, nil, ErrInvalidMessageLength
|
||||
}
|
||||
|
||||
const offsetEnd = offsetTransportID + peerIDSize
|
||||
var peerID PeerID
|
||||
copy(peerID[:], buf[offsetTransportID:offsetEnd])
|
||||
return &peerID, buf[headerTotalSizeTransport:], nil
|
||||
}
|
||||
|
||||
// UnmarshalTransportID extracts the peerID from the transport message.
|
||||
func UnmarshalTransportID(buf []byte) (*PeerID, error) {
|
||||
if len(buf) < headerTotalSizeTransport {
|
||||
return nil, ErrInvalidMessageLength
|
||||
}
|
||||
|
||||
const offsetEnd = offsetTransportID + peerIDSize
|
||||
var id PeerID
|
||||
copy(id[:], buf[offsetTransportID:offsetEnd])
|
||||
return &id, nil
|
||||
}
|
||||
|
||||
// UpdateTransportMsg updates the peerID in the transport message.
|
||||
// With this function the server can reuse the given byte slice to update the peerID in the transport message. So do
|
||||
// need to allocate a new byte slice.
|
||||
func UpdateTransportMsg(msg []byte, peerID PeerID) error {
|
||||
if len(msg) < offsetTransportID+peerIDSize {
|
||||
return ErrInvalidMessageLength
|
||||
}
|
||||
copy(msg[offsetTransportID:], peerID[:])
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarshalHealthcheck creates a health check message.
|
||||
// Health check message is sent by the server periodically. The client will respond with a health check response
|
||||
// message. If the client does not respond to the health check message, the server will close the connection.
|
||||
func MarshalHealthcheck() []byte {
|
||||
return healthCheckMsg
|
||||
}
|
||||
@@ -1,138 +0,0 @@
|
||||
package messages
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMarshalHelloMsg(t *testing.T) {
|
||||
peerID := HashID("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
|
||||
msg, err := MarshalHelloMsg(peerID, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
|
||||
msgType, err := DetermineClientMessageType(msg)
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
|
||||
if msgType != MsgTypeHello {
|
||||
t.Errorf("expected %d, got %d", MsgTypeHello, msgType)
|
||||
}
|
||||
|
||||
receivedPeerID, _, err := UnmarshalHelloMsg(msg)
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
if receivedPeerID.String() != peerID.String() {
|
||||
t.Errorf("expected %s, got %s", peerID, receivedPeerID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalAuthMsg(t *testing.T) {
|
||||
peerID := HashID("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
|
||||
msg, err := MarshalAuthMsg(peerID, []byte{})
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
|
||||
msgType, err := DetermineClientMessageType(msg)
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
|
||||
if msgType != MsgTypeAuth {
|
||||
t.Errorf("expected %d, got %d", MsgTypeAuth, msgType)
|
||||
}
|
||||
|
||||
receivedPeerID, _, err := UnmarshalAuthMsg(msg)
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
if receivedPeerID.String() != peerID.String() {
|
||||
t.Errorf("expected %s, got %s", peerID, receivedPeerID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalAuthResponse(t *testing.T) {
|
||||
address := "myaddress"
|
||||
msg, err := MarshalAuthResponse(address)
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
|
||||
msgType, err := DetermineServerMessageType(msg)
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
|
||||
if msgType != MsgTypeAuthResponse {
|
||||
t.Errorf("expected %d, got %d", MsgTypeAuthResponse, msgType)
|
||||
}
|
||||
|
||||
respAddr, err := UnmarshalAuthResponse(msg)
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
if respAddr != address {
|
||||
t.Errorf("expected %s, got %s", address, respAddr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalTransportMsg(t *testing.T) {
|
||||
peerID := HashID("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
|
||||
payload := []byte("payload")
|
||||
msg, err := MarshalTransportMsg(peerID, payload)
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
|
||||
msgType, err := DetermineClientMessageType(msg)
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
|
||||
if msgType != MsgTypeTransport {
|
||||
t.Errorf("expected %d, got %d", MsgTypeTransport, msgType)
|
||||
}
|
||||
|
||||
uPeerID, err := UnmarshalTransportID(msg)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to unmarshal transport id: %v", err)
|
||||
}
|
||||
|
||||
if uPeerID.String() != peerID.String() {
|
||||
t.Errorf("expected %s, got %s", peerID, uPeerID)
|
||||
}
|
||||
|
||||
id, respPayload, err := UnmarshalTransportMsg(msg)
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
|
||||
if id.String() != peerID.String() {
|
||||
t.Errorf("expected: '%s', got: '%s'", peerID, id)
|
||||
}
|
||||
|
||||
if string(respPayload) != string(payload) {
|
||||
t.Errorf("expected %s, got %s", payload, respPayload)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalHealthcheck(t *testing.T) {
|
||||
msg := MarshalHealthcheck()
|
||||
|
||||
_, err := ValidateVersion(msg)
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
|
||||
msgType, err := DetermineServerMessageType(msg)
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
|
||||
if msgType != MsgTypeHealthCheck {
|
||||
t.Errorf("expected %d, got %d", MsgTypeHealthCheck, msgType)
|
||||
}
|
||||
}
|
||||
@@ -1,92 +0,0 @@
|
||||
package messages
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func MarshalSubPeerStateMsg(ids []PeerID) ([][]byte, error) {
|
||||
return marshalPeerIDs(ids, byte(MsgTypeSubscribePeerState))
|
||||
}
|
||||
|
||||
func UnmarshalSubPeerStateMsg(buf []byte) ([]PeerID, error) {
|
||||
return unmarshalPeerIDs(buf)
|
||||
}
|
||||
|
||||
func MarshalUnsubPeerStateMsg(ids []PeerID) ([][]byte, error) {
|
||||
return marshalPeerIDs(ids, byte(MsgTypeUnsubscribePeerState))
|
||||
}
|
||||
|
||||
func UnmarshalUnsubPeerStateMsg(buf []byte) ([]PeerID, error) {
|
||||
return unmarshalPeerIDs(buf)
|
||||
}
|
||||
|
||||
func MarshalPeersOnline(ids []PeerID) ([][]byte, error) {
|
||||
return marshalPeerIDs(ids, byte(MsgTypePeersOnline))
|
||||
}
|
||||
|
||||
func UnmarshalPeersOnlineMsg(buf []byte) ([]PeerID, error) {
|
||||
return unmarshalPeerIDs(buf)
|
||||
}
|
||||
|
||||
func MarshalPeersWentOffline(ids []PeerID) ([][]byte, error) {
|
||||
return marshalPeerIDs(ids, byte(MsgTypePeersWentOffline))
|
||||
}
|
||||
|
||||
func UnMarshalPeersWentOffline(buf []byte) ([]PeerID, error) {
|
||||
return unmarshalPeerIDs(buf)
|
||||
}
|
||||
|
||||
// marshalPeerIDs is a generic function to marshal peer IDs with a specific message type
|
||||
func marshalPeerIDs(ids []PeerID, msgType byte) ([][]byte, error) {
|
||||
if len(ids) == 0 {
|
||||
return nil, fmt.Errorf("no list of peer ids provided")
|
||||
}
|
||||
|
||||
const maxPeersPerMessage = (MaxMessageSize - sizeOfProtoHeader) / peerIDSize
|
||||
var messages [][]byte
|
||||
|
||||
for i := 0; i < len(ids); i += maxPeersPerMessage {
|
||||
end := i + maxPeersPerMessage
|
||||
if end > len(ids) {
|
||||
end = len(ids)
|
||||
}
|
||||
chunk := ids[i:end]
|
||||
|
||||
totalSize := sizeOfProtoHeader + len(chunk)*peerIDSize
|
||||
buf := make([]byte, totalSize)
|
||||
buf[0] = byte(CurrentProtocolVersion)
|
||||
buf[1] = msgType
|
||||
|
||||
offset := sizeOfProtoHeader
|
||||
for _, id := range chunk {
|
||||
copy(buf[offset:], id[:])
|
||||
offset += peerIDSize
|
||||
}
|
||||
|
||||
messages = append(messages, buf)
|
||||
}
|
||||
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
// unmarshalPeerIDs is a generic function to unmarshal peer IDs from a buffer
|
||||
func unmarshalPeerIDs(buf []byte) ([]PeerID, error) {
|
||||
if len(buf) < sizeOfProtoHeader {
|
||||
return nil, fmt.Errorf("invalid message format")
|
||||
}
|
||||
|
||||
if (len(buf)-sizeOfProtoHeader)%peerIDSize != 0 {
|
||||
return nil, fmt.Errorf("invalid peer list size: %d", len(buf)-sizeOfProtoHeader)
|
||||
}
|
||||
|
||||
numIDs := (len(buf) - sizeOfProtoHeader) / peerIDSize
|
||||
|
||||
ids := make([]PeerID, numIDs)
|
||||
offset := sizeOfProtoHeader
|
||||
for i := 0; i < numIDs; i++ {
|
||||
copy(ids[i][:], buf[offset:offset+peerIDSize])
|
||||
offset += peerIDSize
|
||||
}
|
||||
|
||||
return ids, nil
|
||||
}
|
||||
@@ -1,144 +0,0 @@
|
||||
package messages
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
)
|
||||
|
||||
const (
|
||||
testPeerCount = 10
|
||||
)
|
||||
|
||||
// Helper function to generate test PeerIDs
|
||||
func generateTestPeerIDs(n int) []PeerID {
|
||||
ids := make([]PeerID, n)
|
||||
for i := 0; i < n; i++ {
|
||||
for j := 0; j < peerIDSize; j++ {
|
||||
ids[i][j] = byte(i + j)
|
||||
}
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
// Helper function to compare slices of PeerID
|
||||
func peerIDEqual(a, b []PeerID) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for i := range a {
|
||||
if !bytes.Equal(a[i][:], b[i][:]) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func TestMarshalUnmarshalSubPeerState(t *testing.T) {
|
||||
ids := generateTestPeerIDs(testPeerCount)
|
||||
|
||||
msgs, err := MarshalSubPeerStateMsg(ids)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
var allIDs []PeerID
|
||||
for _, msg := range msgs {
|
||||
decoded, err := UnmarshalSubPeerStateMsg(msg)
|
||||
if err != nil {
|
||||
t.Fatalf("unmarshal failed: %v", err)
|
||||
}
|
||||
allIDs = append(allIDs, decoded...)
|
||||
}
|
||||
|
||||
if !peerIDEqual(ids, allIDs) {
|
||||
t.Errorf("expected %v, got %v", ids, allIDs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalSubPeerState_EmptyInput(t *testing.T) {
|
||||
_, err := MarshalSubPeerStateMsg([]PeerID{})
|
||||
if err == nil {
|
||||
t.Errorf("expected error for empty input")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnmarshalSubPeerState_Invalid(t *testing.T) {
|
||||
// Too short
|
||||
_, err := UnmarshalSubPeerStateMsg([]byte{1})
|
||||
if err == nil {
|
||||
t.Errorf("expected error for short input")
|
||||
}
|
||||
|
||||
// Misaligned length
|
||||
buf := make([]byte, sizeOfProtoHeader+1)
|
||||
_, err = UnmarshalSubPeerStateMsg(buf)
|
||||
if err == nil {
|
||||
t.Errorf("expected error for misaligned input")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalUnmarshalPeersOnline(t *testing.T) {
|
||||
ids := generateTestPeerIDs(testPeerCount)
|
||||
|
||||
msgs, err := MarshalPeersOnline(ids)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
var allIDs []PeerID
|
||||
for _, msg := range msgs {
|
||||
decoded, err := UnmarshalPeersOnlineMsg(msg)
|
||||
if err != nil {
|
||||
t.Fatalf("unmarshal failed: %v", err)
|
||||
}
|
||||
allIDs = append(allIDs, decoded...)
|
||||
}
|
||||
|
||||
if !peerIDEqual(ids, allIDs) {
|
||||
t.Errorf("expected %v, got %v", ids, allIDs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalPeersOnline_EmptyInput(t *testing.T) {
|
||||
_, err := MarshalPeersOnline([]PeerID{})
|
||||
if err == nil {
|
||||
t.Errorf("expected error for empty input")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnmarshalPeersOnline_Invalid(t *testing.T) {
|
||||
_, err := UnmarshalPeersOnlineMsg([]byte{1})
|
||||
if err == nil {
|
||||
t.Errorf("expected error for short input")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalUnmarshalPeersWentOffline(t *testing.T) {
|
||||
ids := generateTestPeerIDs(testPeerCount)
|
||||
|
||||
msgs, err := MarshalPeersWentOffline(ids)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
var allIDs []PeerID
|
||||
for _, msg := range msgs {
|
||||
// MarshalPeersWentOffline shares no unmarshal function, so reuse PeersOnline
|
||||
decoded, err := UnmarshalPeersOnlineMsg(msg)
|
||||
if err != nil {
|
||||
t.Fatalf("unmarshal failed: %v", err)
|
||||
}
|
||||
allIDs = append(allIDs, decoded...)
|
||||
}
|
||||
|
||||
if !peerIDEqual(ids, allIDs) {
|
||||
t.Errorf("expected %v, got %v", ids, allIDs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalPeersWentOffline_EmptyInput(t *testing.T) {
|
||||
_, err := MarshalPeersWentOffline([]PeerID{})
|
||||
if err == nil {
|
||||
t.Errorf("expected error for empty input")
|
||||
}
|
||||
}
|
||||
@@ -6,11 +6,11 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/relay/messages"
|
||||
"github.com/netbirdio/netbird/shared/relay/messages"
|
||||
//nolint:staticcheck
|
||||
"github.com/netbirdio/netbird/relay/messages/address"
|
||||
"github.com/netbirdio/netbird/shared/relay/messages/address"
|
||||
//nolint:staticcheck
|
||||
authmsg "github.com/netbirdio/netbird/relay/messages/auth"
|
||||
authmsg "github.com/netbirdio/netbird/shared/relay/messages/auth"
|
||||
)
|
||||
|
||||
type Validator interface {
|
||||
|
||||
@@ -10,10 +10,12 @@ import (
|
||||
|
||||
"github.com/coder/websocket"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/relay"
|
||||
)
|
||||
|
||||
// URLPath is the path for the websocket connection.
|
||||
const URLPath = "/relay"
|
||||
const URLPath = relay.WebSocketURLPath
|
||||
|
||||
type Listener struct {
|
||||
// Address is the address to listen on.
|
||||
|
||||
@@ -9,8 +9,8 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/relay/healthcheck"
|
||||
"github.com/netbirdio/netbird/relay/messages"
|
||||
"github.com/netbirdio/netbird/shared/relay/healthcheck"
|
||||
"github.com/netbirdio/netbird/shared/relay/messages"
|
||||
"github.com/netbirdio/netbird/relay/metrics"
|
||||
"github.com/netbirdio/netbird/relay/server/store"
|
||||
)
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
"github.com/netbirdio/netbird/relay/server/listener"
|
||||
"github.com/netbirdio/netbird/relay/server/listener/quic"
|
||||
"github.com/netbirdio/netbird/relay/server/listener/ws"
|
||||
quictls "github.com/netbirdio/netbird/relay/tls"
|
||||
quictls "github.com/netbirdio/netbird/shared/relay/tls"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/netbirdio/netbird/relay/messages"
|
||||
"github.com/netbirdio/netbird/shared/relay/messages"
|
||||
)
|
||||
|
||||
type event struct {
|
||||
|
||||
@@ -4,7 +4,7 @@ import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/netbirdio/netbird/relay/messages"
|
||||
"github.com/netbirdio/netbird/shared/relay/messages"
|
||||
)
|
||||
|
||||
type PeerNotifier struct {
|
||||
|
||||
@@ -3,7 +3,7 @@ package store
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/netbirdio/netbird/relay/messages"
|
||||
"github.com/netbirdio/netbird/shared/relay/messages"
|
||||
)
|
||||
|
||||
type IPeer interface {
|
||||
|
||||
@@ -3,7 +3,7 @@ package store
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/netbirdio/netbird/relay/messages"
|
||||
"github.com/netbirdio/netbird/shared/relay/messages"
|
||||
)
|
||||
|
||||
type MocPeer struct {
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
package tls
|
||||
|
||||
const nbalpn = "nb-quic"
|
||||
@@ -1,26 +0,0 @@
|
||||
//go:build devcert
|
||||
|
||||
package tls
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/util/embeddedroots"
|
||||
)
|
||||
|
||||
func ClientQUICTLSConfig() *tls.Config {
|
||||
certPool, err := x509.SystemCertPool()
|
||||
if err != nil || certPool == nil {
|
||||
log.Debugf("System cert pool not available; falling back to embedded cert, error: %v", err)
|
||||
certPool = embeddedroots.Get()
|
||||
}
|
||||
|
||||
return &tls.Config{
|
||||
InsecureSkipVerify: true, // Debug mode allows insecure connections
|
||||
NextProtos: []string{nbalpn}, // Ensure this matches the server's ALPN
|
||||
RootCAs: certPool,
|
||||
}
|
||||
}
|
||||
@@ -1,25 +0,0 @@
|
||||
//go:build !devcert
|
||||
|
||||
package tls
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/util/embeddedroots"
|
||||
)
|
||||
|
||||
func ClientQUICTLSConfig() *tls.Config {
|
||||
certPool, err := x509.SystemCertPool()
|
||||
if err != nil || certPool == nil {
|
||||
log.Debugf("System cert pool not available; falling back to embedded cert, error: %v", err)
|
||||
certPool = embeddedroots.Get()
|
||||
}
|
||||
|
||||
return &tls.Config{
|
||||
NextProtos: []string{nbalpn},
|
||||
RootCAs: certPool,
|
||||
}
|
||||
}
|
||||
@@ -1,36 +0,0 @@
|
||||
// Package tls provides utilities for configuring and managing Transport Layer
|
||||
// Security (TLS) in server and client environments, with a focus on QUIC
|
||||
// protocol support and testing configurations.
|
||||
//
|
||||
// The package includes functions for cloning and customizing TLS
|
||||
// configurations as well as generating self-signed certificates for
|
||||
// development and testing purposes.
|
||||
//
|
||||
// Key Features:
|
||||
//
|
||||
// - `ServerQUICTLSConfig`: Creates a server-side TLS configuration tailored
|
||||
// for QUIC protocol with specified or default settings. QUIC requires a
|
||||
// specific TLS configuration with proper ALPN (Application-Layer Protocol
|
||||
// Negotiation) support, making the TLS settings crucial for establishing
|
||||
// secure connections.
|
||||
//
|
||||
// - `ClientQUICTLSConfig`: Provides a client-side TLS configuration suitable
|
||||
// for QUIC protocol. The configuration differs between development
|
||||
// (insecure testing) and production (strict verification).
|
||||
//
|
||||
// - `generateTestTLSConfig`: Generates a self-signed TLS configuration for
|
||||
// use in local development and testing scenarios.
|
||||
//
|
||||
// Usage:
|
||||
//
|
||||
// This package provides separate implementations for development and production
|
||||
// environments. The development implementation (guarded by `//go:build devcert`)
|
||||
// supports testing configurations with self-signed certificates and insecure
|
||||
// client connections. The production implementation (guarded by `//go:build
|
||||
// !devcert`) ensures that valid and secure TLS configurations are supplied
|
||||
// and used.
|
||||
//
|
||||
// The QUIC protocol is highly reliant on properly configured TLS settings,
|
||||
// and this package ensures that configurations meet the requirements for
|
||||
// secure and efficient QUIC communication.
|
||||
package tls
|
||||
@@ -1,79 +0,0 @@
|
||||
//go:build devcert
|
||||
|
||||
package tls
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"math/big"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func ServerQUICTLSConfig(originTLSCfg *tls.Config) (*tls.Config, error) {
|
||||
if originTLSCfg == nil {
|
||||
log.Warnf("QUIC server will use self signed certificate for testing!")
|
||||
return generateTestTLSConfig()
|
||||
}
|
||||
|
||||
cfg := originTLSCfg.Clone()
|
||||
cfg.NextProtos = []string{nbalpn}
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// GenerateTestTLSConfig creates a self-signed certificate for testing
|
||||
func generateTestTLSConfig() (*tls.Config, error) {
|
||||
log.Infof("generating test TLS config")
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
template := x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
Subject: pkix.Name{
|
||||
Organization: []string{"Test Organization"},
|
||||
},
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().Add(time.Hour * 24 * 180), // Valid for 180 days
|
||||
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{
|
||||
x509.ExtKeyUsageServerAuth,
|
||||
},
|
||||
BasicConstraintsValid: true,
|
||||
DNSNames: []string{"localhost"},
|
||||
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
|
||||
}
|
||||
|
||||
// Create certificate
|
||||
certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
certPEM := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: certDER,
|
||||
})
|
||||
|
||||
privateKeyPEM := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "RSA PRIVATE KEY",
|
||||
Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
|
||||
})
|
||||
|
||||
tlsCert, err := tls.X509KeyPair(certPEM, privateKeyPEM)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &tls.Config{
|
||||
Certificates: []tls.Certificate{tlsCert},
|
||||
NextProtos: []string{nbalpn},
|
||||
}, nil
|
||||
}
|
||||
@@ -1,17 +0,0 @@
|
||||
//go:build !devcert
|
||||
|
||||
package tls
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func ServerQUICTLSConfig(originTLSCfg *tls.Config) (*tls.Config, error) {
|
||||
if originTLSCfg == nil {
|
||||
return nil, fmt.Errorf("valid TLS config is required for QUIC listener")
|
||||
}
|
||||
cfg := originTLSCfg.Clone()
|
||||
cfg.NextProtos = []string{nbalpn}
|
||||
return cfg, nil
|
||||
}
|
||||
Reference in New Issue
Block a user