mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-08 18:46:34 -04:00
Compare commits
6 Commits
refactor/f
...
signal-sup
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
00e9d14652 | ||
|
|
a1fa0d79e5 | ||
|
|
00278a472f | ||
|
|
0bfda3c542 | ||
|
|
0d895750b4 | ||
|
|
5e4473310c |
@@ -11,6 +11,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/shared/signal/proto"
|
||||
"github.com/netbirdio/netbird/signal/metrics"
|
||||
"github.com/netbirdio/netbird/signal/suppressor"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -20,13 +21,11 @@ var (
|
||||
// Peer representation of a connected Peer
|
||||
type Peer struct {
|
||||
// a unique id of the Peer (e.g. sha256 fingerprint of the Wireguard public key)
|
||||
Id string
|
||||
|
||||
StreamID int64
|
||||
|
||||
Id string
|
||||
suppressor *suppressor.Suppressor
|
||||
StreamID int64
|
||||
// a gRpc connection stream to the Peer
|
||||
Stream proto.SignalExchange_ConnectStreamServer
|
||||
|
||||
// registration time
|
||||
RegisteredAt time.Time
|
||||
|
||||
@@ -41,6 +40,7 @@ func NewPeer(id string, stream proto.SignalExchange_ConnectStreamServer, cancel
|
||||
StreamID: time.Now().UnixNano(),
|
||||
RegisteredAt: time.Now(),
|
||||
Cancel: cancel,
|
||||
suppressor: suppressor.NewSuppressor(nil),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -117,3 +117,10 @@ func (registry *Registry) Deregister(peer *Peer) {
|
||||
registry.metrics.Deregistrations.Add(context.Background(), 1)
|
||||
}
|
||||
}
|
||||
|
||||
func (peer *Peer) SendMessageAllowed(destination string, size int, arrivedTime time.Time) bool {
|
||||
if peer == nil || peer.suppressor == nil {
|
||||
return false
|
||||
}
|
||||
return peer.suppressor.PackageReceived(suppressor.PeerID(destination), size, arrivedTime)
|
||||
}
|
||||
|
||||
@@ -23,14 +23,17 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
labelType = "type"
|
||||
labelTypeError = "error"
|
||||
labelTypeNotConnected = "not_connected"
|
||||
labelTypeNotRegistered = "not_registered"
|
||||
labelTypeStream = "stream"
|
||||
labelTypeMessage = "message"
|
||||
labelTypeTimeout = "timeout"
|
||||
labelTypeDisconnected = "disconnected"
|
||||
labelType = "type"
|
||||
labelTypeError = "error"
|
||||
labelTypeNotConnected = "not_connected"
|
||||
labelTypeNotRegistered = "not_registered"
|
||||
labelTypeSenderNotRegistered = "sender_not_registered"
|
||||
labelTypeMessageSuppressed = "message_suppressed"
|
||||
labelTypeMessageSuppressedDisconnected = "message_suppressed_disconnected"
|
||||
labelTypeStream = "stream"
|
||||
labelTypeMessage = "message"
|
||||
labelTypeTimeout = "timeout"
|
||||
labelTypeDisconnected = "disconnected"
|
||||
|
||||
labelError = "error"
|
||||
labelErrorMissingId = "missing_id"
|
||||
@@ -95,6 +98,21 @@ func NewServer(ctx context.Context, meter metric.Meter) (*Server, error) {
|
||||
func (s *Server) Send(ctx context.Context, msg *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
|
||||
log.Tracef("received a new message to send from peer [%s] to peer [%s]", msg.Key, msg.RemoteKey)
|
||||
|
||||
peer, found := s.registry.Get(msg.Key)
|
||||
if !found {
|
||||
s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeSenderNotRegistered)))
|
||||
log.Tracef("message from peer [%s] can't be forwarded to peer [%s] because sender peer is not registered", msg.Key, msg.RemoteKey)
|
||||
// return nil, status.Errorf(codes.FailedPrecondition, "peer not registered")
|
||||
} else if !peer.SendMessageAllowed(msg.RemoteKey, len(msg.Body), time.Now()) {
|
||||
if peer == nil {
|
||||
s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeMessageSuppressedDisconnected)))
|
||||
} else {
|
||||
s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeMessageSuppressed)))
|
||||
}
|
||||
s.metrics.MessageSize.Record(ctx, int64(len(msg.Body)), metric.WithAttributes(attribute.String(labelType, labelTypeMessageSuppressed)))
|
||||
log.Tracef("message from peer [%s] to peer [%s] suppressed due to repetition", msg.Key, msg.RemoteKey)
|
||||
}
|
||||
|
||||
if _, found := s.registry.Get(msg.RemoteKey); found {
|
||||
s.forwardMessageToPeer(ctx, msg)
|
||||
return &proto.EncryptedMessage{}, nil
|
||||
|
||||
155
signal/suppressor/suppressor.go
Normal file
155
signal/suppressor/suppressor.go
Normal file
@@ -0,0 +1,155 @@
|
||||
package suppressor
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
|
||||
// DefaultRepetitionThreshold determines after how many repetitions it will be suppressed. It is a counter
|
||||
DefaultRepetitionThreshold = 90 // If the peer repeats the packets every 10 seconds, suppress them after 15 minutes
|
||||
minRepetitionThreshold = 3
|
||||
|
||||
// minTimeBetweenPackages below this period do not check the repetitions
|
||||
minTimeBetweenPackages = 7 * time.Second
|
||||
toleranceRange = 1 * time.Second
|
||||
)
|
||||
|
||||
type PeerID string
|
||||
|
||||
type packageStat struct {
|
||||
lastSeen time.Time // last packet timestamp
|
||||
lastDelta *time.Duration // time between same size of packages
|
||||
lastSize int
|
||||
repetitionTimes int
|
||||
}
|
||||
|
||||
type Opts struct {
|
||||
RepetitionThreshold int
|
||||
}
|
||||
|
||||
// Suppressor filters repeated packages from peers to prevent spam or abuse.
|
||||
//
|
||||
// It works by keeping track of the timing and size of packages received
|
||||
// from each peer. For each peer, it stores the last package size, the
|
||||
// timestamp when it was seen, the time difference (delta) between consecutive
|
||||
// packages of the same size, and a repetition counter.
|
||||
//
|
||||
// The suppressor uses the following rules:
|
||||
//
|
||||
// 1. **Short intervals**: If a package arrives sooner than minTimeBetweenPackages
|
||||
// since the last package, it is accepted without repetition checks. This
|
||||
// allows bursts or backoff recovery to pass through.
|
||||
//
|
||||
// 2. **Clock skew / negative delta**: If the system clock goes backward
|
||||
// and produces a negative delta, the package is accepted and the state
|
||||
// is reset to prevent exploitation.
|
||||
//
|
||||
// 3. **Size changes**: If the new package size differs from the previous
|
||||
// one, the package is accepted and the repetition counter is reset.
|
||||
//
|
||||
// 4. **Tolerance-based repetition detection**: If a package arrives with a
|
||||
// delta close to the previous delta (within the toleranceRange), it is
|
||||
// considered a repeated pattern and the repetition counter is incremented.
|
||||
//
|
||||
// 5. **Suppression**: Once the repetition counter exceeds repetitionThreshold,
|
||||
// further packages with the same timing pattern are suppressed.
|
||||
//
|
||||
// This design ensures that repeated or spammy traffic patterns are filtered
|
||||
// while allowing legitimate variations due to network jitter or bursty traffic.
|
||||
type Suppressor struct {
|
||||
repetitionThreshold int
|
||||
peers map[PeerID]*packageStat
|
||||
}
|
||||
|
||||
func NewSuppressor(opts *Opts) *Suppressor {
|
||||
threshold := DefaultRepetitionThreshold
|
||||
if opts != nil {
|
||||
threshold = opts.RepetitionThreshold
|
||||
if opts.RepetitionThreshold < minRepetitionThreshold {
|
||||
threshold = DefaultRepetitionThreshold
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return &Suppressor{
|
||||
repetitionThreshold: threshold,
|
||||
peers: make(map[PeerID]*packageStat),
|
||||
}
|
||||
}
|
||||
|
||||
// PackageReceived handles a newly received package from a peer.
|
||||
//
|
||||
// Parameters:
|
||||
// - destination: the PeerID of the peer that sent the package
|
||||
// - size: the size of the package
|
||||
// - arrivedTime: the timestamp when the package arrived
|
||||
//
|
||||
// Returns:
|
||||
// - true if the package is accepted (not suppressed)
|
||||
// - false if the package is considered a repeated package and suppressed
|
||||
func (s *Suppressor) PackageReceived(destination PeerID, size int, arrivedTime time.Time) bool {
|
||||
p, ok := s.peers[destination]
|
||||
if !ok {
|
||||
s.peers[destination] = &packageStat{
|
||||
lastSeen: arrivedTime,
|
||||
lastSize: size,
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
if p.lastSize != size {
|
||||
p.lastSeen = arrivedTime
|
||||
p.lastSize = size
|
||||
p.lastDelta = nil
|
||||
p.repetitionTimes = 0
|
||||
return true
|
||||
}
|
||||
|
||||
// Calculate delta
|
||||
delta := arrivedTime.Sub(p.lastSeen)
|
||||
|
||||
// Clock went backwards - don't reset state to prevent exploitation
|
||||
// Just update timestamp and continue with existing state
|
||||
if delta < 0 {
|
||||
p.lastSeen = arrivedTime
|
||||
p.lastDelta = nil
|
||||
p.repetitionTimes = 0
|
||||
return true
|
||||
}
|
||||
|
||||
// if it is below the threshold we want to allow because the backoff ticker is active
|
||||
if delta < minTimeBetweenPackages {
|
||||
p.lastSeen = arrivedTime
|
||||
p.lastDelta = nil
|
||||
p.repetitionTimes = 0
|
||||
return true
|
||||
}
|
||||
|
||||
// case when we have only one package in the history
|
||||
if p.lastDelta == nil {
|
||||
p.lastSeen = arrivedTime
|
||||
p.lastDelta = &delta
|
||||
return true
|
||||
}
|
||||
|
||||
if abs(delta-*p.lastDelta) > toleranceRange {
|
||||
p.lastSeen = arrivedTime
|
||||
p.lastDelta = &delta
|
||||
p.repetitionTimes = 0
|
||||
return true
|
||||
|
||||
}
|
||||
p.lastSeen = arrivedTime
|
||||
p.lastDelta = &delta
|
||||
p.repetitionTimes++
|
||||
|
||||
return p.repetitionTimes < s.repetitionThreshold
|
||||
}
|
||||
|
||||
func abs(d time.Duration) time.Duration {
|
||||
if d < 0 {
|
||||
return -d
|
||||
}
|
||||
return d
|
||||
}
|
||||
143
signal/suppressor/suppressor_test.go
Normal file
143
signal/suppressor/suppressor_test.go
Normal file
@@ -0,0 +1,143 @@
|
||||
package suppressor
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestSuppressor_PackageReceived(t *testing.T) {
|
||||
destID := PeerID("remote")
|
||||
s := NewSuppressor(&Opts{RepetitionThreshold: 3})
|
||||
|
||||
// Define sequence with base deltas (s ±10% tolerance)
|
||||
deltas := []time.Duration{
|
||||
800 * time.Millisecond,
|
||||
1600 * time.Millisecond,
|
||||
3200 * time.Millisecond,
|
||||
6400 * time.Millisecond,
|
||||
10 * time.Second,
|
||||
10 * time.Second,
|
||||
10 * time.Second,
|
||||
10 * time.Second, // should be suppressed
|
||||
10 * time.Second,
|
||||
10 * time.Second,
|
||||
}
|
||||
sizes := []int{
|
||||
100,
|
||||
100,
|
||||
100,
|
||||
100,
|
||||
100,
|
||||
100,
|
||||
100,
|
||||
100,
|
||||
100,
|
||||
100,
|
||||
}
|
||||
|
||||
expected := []bool{
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
}
|
||||
|
||||
// Apply ±10% tolerance
|
||||
times := make([]time.Time, len(deltas)+1)
|
||||
times[0] = time.Now()
|
||||
for i, d := range deltas {
|
||||
// ±10% randomization
|
||||
offset := d / 10
|
||||
times[i+1] = times[i].Add(d + offset) // for deterministic test, using +10%
|
||||
}
|
||||
|
||||
for i, arrival := range times[1:] {
|
||||
allowed := s.PackageReceived(destID, sizes[i], arrival)
|
||||
if allowed != expected[i] {
|
||||
t.Errorf("Packet %d at %v: expected allowed=%v, got %v", i+1, arrival.Sub(times[0]), expected[i], allowed)
|
||||
}
|
||||
t.Logf("Packet %d at %v allowed: %v", i+1, arrival.Sub(times[0]), allowed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSuppressor_PackageReceivedReset(t *testing.T) {
|
||||
destID := PeerID("remote")
|
||||
s := NewSuppressor(&Opts{RepetitionThreshold: 5})
|
||||
|
||||
// Define sequence with base deltas (s ±10% tolerance)
|
||||
deltas := []time.Duration{
|
||||
800 * time.Millisecond,
|
||||
1600 * time.Millisecond,
|
||||
3200 * time.Millisecond,
|
||||
6400 * time.Millisecond,
|
||||
10 * time.Second,
|
||||
10 * time.Second,
|
||||
10 * time.Second,
|
||||
10 * time.Second,
|
||||
10 * time.Second,
|
||||
10 * time.Second,
|
||||
10 * time.Second,
|
||||
10 * time.Second,
|
||||
10 * time.Second,
|
||||
10 * time.Second,
|
||||
10 * time.Second,
|
||||
}
|
||||
sizes := []int{
|
||||
100,
|
||||
100,
|
||||
100,
|
||||
100,
|
||||
100,
|
||||
100,
|
||||
100,
|
||||
100,
|
||||
100,
|
||||
100,
|
||||
100,
|
||||
50,
|
||||
100,
|
||||
100,
|
||||
100,
|
||||
}
|
||||
|
||||
expected := []bool{
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
}
|
||||
|
||||
// Apply ±10% tolerance
|
||||
times := make([]time.Time, len(deltas)+1)
|
||||
times[0] = time.Now()
|
||||
for i, d := range deltas {
|
||||
// ±10% randomization
|
||||
offset := d / 10
|
||||
times[i+1] = times[i].Add(d + offset) // for deterministic test, using +10%
|
||||
}
|
||||
|
||||
for i, arrival := range times[1:] {
|
||||
allowed := s.PackageReceived(destID, sizes[i], arrival)
|
||||
if allowed != expected[i] {
|
||||
t.Errorf("Packet %d at %v: expected allowed=%v, got %v", i+1, arrival.Sub(times[0]), expected[i], allowed)
|
||||
}
|
||||
t.Logf("Packet %d at %v allowed: %v", i+1, arrival.Sub(times[0]), allowed)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user