Compare commits


1 Commit

Author: Viktor Liu
SHA1: 417fa6e833
Message: Change default interface name from wt0 to nb0
Date: 2025-07-21 17:58:56 +02:00
23 changed files with 122 additions and 239 deletions

View File

@@ -211,11 +211,7 @@ jobs:
strategy:
fail-fast: false
matrix:
include:
- arch: "386"
raceFlag: ""
- arch: "amd64"
raceFlag: "-race"
arch: [ '386','amd64' ]
runs-on: ubuntu-22.04
steps:
- name: Install Go
@@ -255,9 +251,9 @@ jobs:
- name: Test
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
go test ${{ matrix.raceFlag }} \
go test \
-exec 'sudo' \
-timeout 10m ./relay/...
-timeout 10m ./signal/...
test_signal:
name: "Signal / Unit"

View File

@@ -3,4 +3,4 @@
package configurer
// WgInterfaceDefault is a default interface name of Netbird
const WgInterfaceDefault = "wt0"
const WgInterfaceDefault = "nb0"

View File

@@ -39,7 +39,7 @@ const (
)
var defaultInterfaceBlacklist = []string{
iface.WgInterfaceDefault, "wt", "utun", "tun0", "zt", "ZeroTier", "wg", "ts",
iface.WgInterfaceDefault, "nb", "wt", "utun", "tun0", "zt", "ZeroTier", "wg", "ts",
"Tailscale", "tailscale", "docker", "veth", "br-", "lo",
}

View File

@@ -1393,7 +1393,7 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
if runtime.GOOS == "darwin" {
ifaceName = fmt.Sprintf("utun1%d", i)
} else {
ifaceName = fmt.Sprintf("wt%d", i)
ifaceName = fmt.Sprintf("nb%d", i)
}
wgPort := 33100 + i

View File

@@ -19,7 +19,7 @@ type mockIFaceMapper struct {
}
func (m *mockIFaceMapper) Name() string {
return "wt0"
return "nb0"
}
func (m *mockIFaceMapper) Address() wgaddr.Address {

View File

@@ -24,7 +24,7 @@ type WorkerRelay struct {
isController bool
config ConnConfig
conn *Conn
relayManager *relayClient.Manager
relayManager relayClient.ManagerService
relayedConn net.Conn
relayLock sync.Mutex
@@ -34,7 +34,7 @@ type WorkerRelay struct {
wgWatcher *WGWatcher
}
func NewWorkerRelay(ctx context.Context, log *log.Entry, ctrl bool, config ConnConfig, conn *Conn, relayManager *relayClient.Manager, stateDump *stateDump) *WorkerRelay {
func NewWorkerRelay(ctx context.Context, log *log.Entry, ctrl bool, config ConnConfig, conn *Conn, relayManager relayClient.ManagerService, stateDump *stateDump) *WorkerRelay {
r := &WorkerRelay{
peerCtx: ctx,
log: log,

View File

@@ -252,7 +252,7 @@ func TestSysOps_validateRoute_InvalidPrefix(t *testing.T) {
IP: wgNetwork.Addr(),
Network: wgNetwork,
},
name: "wt0",
name: "nb0",
}
sysOps := &SysOps{

View File

@@ -292,7 +292,7 @@ func (c *Client) Close() error {
}
func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, c.connectionURL, quic.Dialer{}, ws.Dialer{})
rd := dialer.NewRaceDial(c.log, c.connectionURL, quic.Dialer{}, ws.Dialer{})
conn, err := rd.Dial()
if err != nil {
return nil, err

View File

@@ -572,14 +572,10 @@ func TestCloseByServer(t *testing.T) {
idAlice := "alice"
log.Debugf("connect by alice")
relayClient := NewClient(serverURL, hmacTokenStore, idAlice)
if err = relayClient.Connect(ctx); err != nil {
err = relayClient.Connect(ctx)
if err != nil {
log.Fatalf("failed to connect to server: %s", err)
}
defer func() {
if err := relayClient.Close(); err != nil {
log.Errorf("failed to close client: %s", err)
}
}()
disconnected := make(chan struct{})
relayClient.SetOnDisconnectListener(func(_ string) {
@@ -595,7 +591,7 @@ func TestCloseByServer(t *testing.T) {
select {
case <-disconnected:
case <-time.After(3 * time.Second):
log.Errorf("timeout waiting for client to disconnect")
log.Fatalf("timeout waiting for client to disconnect")
}
_, err = relayClient.OpenConn(ctx, "bob")

View File

@@ -9,8 +9,8 @@ import (
log "github.com/sirupsen/logrus"
)
const (
DefaultConnectionTimeout = 30 * time.Second
var (
connectionTimeout = 30 * time.Second
)
type DialeFn interface {
@@ -25,18 +25,16 @@ type dialResult struct {
}
type RaceDial struct {
log *log.Entry
serverURL string
dialerFns []DialeFn
connectionTimeout time.Duration
log *log.Entry
serverURL string
dialerFns []DialeFn
}
func NewRaceDial(log *log.Entry, connectionTimeout time.Duration, serverURL string, dialerFns ...DialeFn) *RaceDial {
func NewRaceDial(log *log.Entry, serverURL string, dialerFns ...DialeFn) *RaceDial {
return &RaceDial{
log: log,
serverURL: serverURL,
dialerFns: dialerFns,
connectionTimeout: connectionTimeout,
log: log,
serverURL: serverURL,
dialerFns: dialerFns,
}
}
@@ -60,7 +58,7 @@ func (r *RaceDial) Dial() (net.Conn, error) {
}
func (r *RaceDial) dial(dfn DialeFn, abortCtx context.Context, connChan chan dialResult) {
ctx, cancel := context.WithTimeout(abortCtx, r.connectionTimeout)
ctx, cancel := context.WithTimeout(abortCtx, connectionTimeout)
defer cancel()
r.log.Infof("dialing Relay server via %s", dfn.Protocol())

View File

@@ -77,7 +77,7 @@ func TestRaceDialEmptyDialers(t *testing.T) {
logger := logrus.NewEntry(logrus.New())
serverURL := "test.server.com"
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL)
rd := NewRaceDial(logger, serverURL)
conn, err := rd.Dial()
if err == nil {
t.Errorf("Expected an error with empty dialers, got nil")
@@ -103,7 +103,7 @@ func TestRaceDialSingleSuccessfulDialer(t *testing.T) {
protocolStr: proto,
}
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer)
rd := NewRaceDial(logger, serverURL, mockDialer)
conn, err := rd.Dial()
if err != nil {
t.Errorf("Expected no error, got %v", err)
@@ -136,7 +136,7 @@ func TestRaceDialMultipleDialersWithOneSuccess(t *testing.T) {
protocolStr: "proto2",
}
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer1, mockDialer2)
rd := NewRaceDial(logger, serverURL, mockDialer1, mockDialer2)
conn, err := rd.Dial()
if err != nil {
t.Errorf("Expected no error, got %v", err)
@@ -144,13 +144,13 @@ func TestRaceDialMultipleDialersWithOneSuccess(t *testing.T) {
if conn.RemoteAddr().Network() != proto2 {
t.Errorf("Expected connection with protocol %s, got %s", proto2, conn.RemoteAddr().Network())
}
_ = conn.Close()
}
func TestRaceDialTimeout(t *testing.T) {
logger := logrus.NewEntry(logrus.New())
serverURL := "test.server.com"
connectionTimeout = 3 * time.Second
mockDialer := &MockDialer{
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
<-ctx.Done()
@@ -159,7 +159,7 @@ func TestRaceDialTimeout(t *testing.T) {
protocolStr: "proto1",
}
rd := NewRaceDial(logger, 3*time.Second, serverURL, mockDialer)
rd := NewRaceDial(logger, serverURL, mockDialer)
conn, err := rd.Dial()
if err == nil {
t.Errorf("Expected an error, got nil")
@@ -187,7 +187,7 @@ func TestRaceDialAllDialersFail(t *testing.T) {
protocolStr: "protocol2",
}
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer1, mockDialer2)
rd := NewRaceDial(logger, serverURL, mockDialer1, mockDialer2)
conn, err := rd.Dial()
if err == nil {
t.Errorf("Expected an error, got nil")
@@ -229,7 +229,7 @@ func TestRaceDialFirstSuccessfulDialerWins(t *testing.T) {
protocolStr: proto2,
}
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer1, mockDialer2)
rd := NewRaceDial(logger, serverURL, mockDialer1, mockDialer2)
conn, err := rd.Dial()
if err != nil {
t.Errorf("Expected no error, got %v", err)

View File

@@ -8,8 +8,7 @@ import (
log "github.com/sirupsen/logrus"
)
const (
// TODO: make it configurable, the manager should validate all configurable parameters
var (
reconnectingTimeout = 60 * time.Second
)

View File

@@ -39,6 +39,17 @@ func NewRelayTrack() *RelayTrack {
type OnServerCloseListener func()
// ManagerService is the interface for the relay manager.
type ManagerService interface {
Serve() error
OpenConn(ctx context.Context, serverAddress, peerKey string) (net.Conn, error)
AddCloseListener(serverAddress string, onClosedListener OnServerCloseListener) error
RelayInstanceAddress() (string, error)
ServerURLs() []string
HasRelayAddress() bool
UpdateToken(token *relayAuth.Token) error
}
// Manager is a manager for the relay client instances. It establishes one persistent connection to the given relay URL
// and automatically reconnect to them in case disconnection.
// The manager also manage temporary relay connection. If a client wants to communicate with a client on a
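
The new ManagerService interface is what lets NewWorkerRelay (see the WorkerRelay hunk earlier in this diff) accept any implementation, such as a test double, instead of the concrete *Manager. A hedged compile-time sketch, assuming the relay client import path used by the rest of the codebase:

package main

import relayClient "github.com/netbirdio/netbird/relay/client"

// Sketch: the concrete *Manager is expected to satisfy the new interface, so
// existing production call sites of NewWorkerRelay keep compiling while tests
// can now pass a mock that implements ManagerService.
var _ relayClient.ManagerService = (*relayClient.Manager)(nil)

func main() {}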

View File

@@ -13,9 +13,7 @@ import (
)
func TestEmptyURL(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
mgr := NewManager(ctx, nil, "alice")
mgr := NewManager(context.Background(), nil, "alice")
err := mgr.Serve()
if err == nil {
t.Errorf("expected error, got nil")
@@ -218,11 +216,9 @@ func TestForeginConnClose(t *testing.T) {
}
}
func TestForeignAutoClose(t *testing.T) {
func TestForeginAutoClose(t *testing.T) {
ctx := context.Background()
relayCleanupInterval = 1 * time.Second
keepUnusedServerTime = 2 * time.Second
srvCfg1 := server.ListenerConfig{
Address: "localhost:1234",
}
@@ -288,35 +284,16 @@ func TestForeignAutoClose(t *testing.T) {
t.Fatalf("failed to serve manager: %s", err)
}
// Set up a disconnect listener to track when foreign server disconnects
foreignServerURL := toURL(srvCfg2)[0]
disconnected := make(chan struct{})
onDisconnect := func() {
select {
case disconnected <- struct{}{}:
default:
}
}
t.Log("open connection to another peer")
if _, err = mgr.OpenConn(ctx, foreignServerURL, "anotherpeer"); err == nil {
if _, err = mgr.OpenConn(ctx, toURL(srvCfg2)[0], "anotherpeer"); err == nil {
t.Fatalf("should have failed to open connection to another peer")
}
// Add the disconnect listener after the connection attempt
if err := mgr.AddCloseListener(foreignServerURL, onDisconnect); err != nil {
t.Logf("failed to add close listener (expected if connection failed): %s", err)
}
// Wait for cleanup to happen
timeout := relayCleanupInterval + keepUnusedServerTime + 2*time.Second
timeout := relayCleanupInterval + keepUnusedServerTime + 1*time.Second
t.Logf("waiting for relay cleanup: %s", timeout)
select {
case <-disconnected:
t.Log("foreign relay connection cleaned up successfully")
case <-time.After(timeout):
t.Log("timeout waiting for cleanup - this might be expected if connection never established")
time.Sleep(timeout)
if len(mgr.relayClients) != 0 {
t.Errorf("expected 0, got %d", len(mgr.relayClients))
}
t.Logf("closing manager")
@@ -324,6 +301,7 @@ func TestForeignAutoClose(t *testing.T) {
func TestAutoReconnect(t *testing.T) {
ctx := context.Background()
reconnectingTimeout = 2 * time.Second
srvCfg := server.ListenerConfig{
Address: "localhost:1234",
@@ -334,7 +312,8 @@ func TestAutoReconnect(t *testing.T) {
}
errChan := make(chan error, 1)
go func() {
if err := srv.Listen(srvCfg); err != nil {
err := srv.Listen(srvCfg)
if err != nil {
errChan <- err
}
}()

View File

@@ -4,76 +4,38 @@ import (
"context"
"fmt"
"os"
"sync"
"testing"
"time"
log "github.com/sirupsen/logrus"
)
// Mutex to protect global variable access in tests
var testMutex sync.Mutex
func TestNewReceiver(t *testing.T) {
testMutex.Lock()
originalTimeout := heartbeatTimeout
heartbeatTimeout = 5 * time.Second
testMutex.Unlock()
defer func() {
testMutex.Lock()
heartbeatTimeout = originalTimeout
testMutex.Unlock()
}()
r := NewReceiver(log.WithContext(context.Background()))
defer r.Stop()
select {
case <-r.OnTimeout:
t.Error("unexpected timeout")
case <-time.After(1 * time.Second):
// Test passes if no timeout received
}
}
func TestNewReceiverNotReceive(t *testing.T) {
testMutex.Lock()
originalTimeout := heartbeatTimeout
heartbeatTimeout = 1 * time.Second
testMutex.Unlock()
defer func() {
testMutex.Lock()
heartbeatTimeout = originalTimeout
testMutex.Unlock()
}()
r := NewReceiver(log.WithContext(context.Background()))
defer r.Stop()
select {
case <-r.OnTimeout:
// Test passes if timeout is received
case <-time.After(2 * time.Second):
t.Error("timeout not received")
}
}
func TestNewReceiverAck(t *testing.T) {
testMutex.Lock()
originalTimeout := heartbeatTimeout
heartbeatTimeout = 2 * time.Second
testMutex.Unlock()
defer func() {
testMutex.Lock()
heartbeatTimeout = originalTimeout
testMutex.Unlock()
}()
r := NewReceiver(log.WithContext(context.Background()))
defer r.Stop()
r.Heartbeat()
@@ -97,18 +59,13 @@ func TestReceiverHealthCheckAttemptThreshold(t *testing.T) {
for _, tc := range testsCases {
t.Run(tc.name, func(t *testing.T) {
testMutex.Lock()
originalInterval := healthCheckInterval
originalTimeout := heartbeatTimeout
healthCheckInterval = 1 * time.Second
heartbeatTimeout = healthCheckInterval + 500*time.Millisecond
testMutex.Unlock()
defer func() {
testMutex.Lock()
healthCheckInterval = originalInterval
heartbeatTimeout = originalTimeout
testMutex.Unlock()
}()
//nolint:tenv
os.Setenv(defaultAttemptThresholdEnv, fmt.Sprintf("%d", tc.threshold))

View File

@@ -135,11 +135,7 @@ func TestSenderHealthCheckAttemptThreshold(t *testing.T) {
defer cancel()
sender := NewSender(log.WithField("test_name", tc.name))
senderExit := make(chan struct{})
go func() {
sender.StartHealthCheck(ctx)
close(senderExit)
}()
go sender.StartHealthCheck(ctx)
go func() {
responded := false
@@ -173,11 +169,6 @@ func TestSenderHealthCheckAttemptThreshold(t *testing.T) {
t.Fatalf("should have timed out before %s", testTimeout)
}
select {
case <-senderExit:
case <-time.After(2 * time.Second):
t.Fatalf("sender did not exit in time")
}
})
}

View File

@@ -20,12 +20,12 @@ type Metrics struct {
TransferBytesRecv metric.Int64Counter
AuthenticationTime metric.Float64Histogram
PeerStoreTime metric.Float64Histogram
peerReconnections metric.Int64Counter
peers metric.Int64UpDownCounter
peerActivityChan chan string
peerLastActive map[string]time.Time
mutexActivity sync.Mutex
ctx context.Context
peers metric.Int64UpDownCounter
peerActivityChan chan string
peerLastActive map[string]time.Time
mutexActivity sync.Mutex
ctx context.Context
}
func NewMetrics(ctx context.Context, meter metric.Meter) (*Metrics, error) {
@@ -80,13 +80,6 @@ func NewMetrics(ctx context.Context, meter metric.Meter) (*Metrics, error) {
return nil, err
}
peerReconnections, err := meter.Int64Counter("relay_peer_reconnections_total",
metric.WithDescription("Total number of times peers have reconnected and closed old connections"),
)
if err != nil {
return nil, err
}
m := &Metrics{
Meter: meter,
TransferBytesSent: bytesSent,
@@ -94,7 +87,6 @@ func NewMetrics(ctx context.Context, meter metric.Meter) (*Metrics, error) {
AuthenticationTime: authTime,
PeerStoreTime: peerStoreTime,
peers: peers,
peerReconnections: peerReconnections,
ctx: ctx,
peerActivityChan: make(chan string, 10),
@@ -146,10 +138,6 @@ func (m *Metrics) PeerDisconnected(id string) {
delete(m.peerLastActive, id)
}
func (m *Metrics) RecordPeerReconnection() {
m.peerReconnections.Add(m.ctx, 1)
}
// PeerActivity increases the active connections
func (m *Metrics) PeerActivity(peerID string) {
select {

View File

@@ -18,9 +18,12 @@ type Listener struct {
TLSConfig *tls.Config
listener *quic.Listener
acceptFn func(conn net.Conn)
}
func (l *Listener) Listen(acceptFn func(conn net.Conn)) error {
l.acceptFn = acceptFn
quicCfg := &quic.Config{
EnableDatagrams: true,
InitialPacketSize: 1452,
@@ -46,7 +49,7 @@ func (l *Listener) Listen(acceptFn func(conn net.Conn)) error {
log.Infof("QUIC client connected from: %s", session.RemoteAddr())
conn := NewConn(session)
acceptFn(conn)
l.acceptFn(conn)
}
}

View File

@@ -32,9 +32,6 @@ type Peer struct {
notifier *store.PeerNotifier
peersListener *store.Listener
// between the online peer collection step and the notification sending should not be sent offline notifications from another thread
notificationMutex sync.Mutex
}
// NewPeer creates a new Peer instance and prepare custom logging
@@ -244,16 +241,10 @@ func (p *Peer) handleSubscribePeerState(msg []byte) {
}
p.log.Debugf("received subscription message for %d peers", len(peerIDs))
// collect online peers to response back to the caller
p.notificationMutex.Lock()
defer p.notificationMutex.Unlock()
onlinePeers := p.store.GetOnlinePeersAndRegisterInterest(peerIDs, p.peersListener)
onlinePeers := p.peersListener.AddInterestedPeers(peerIDs)
if len(onlinePeers) == 0 {
return
}
p.log.Debugf("response with %d online peers", len(onlinePeers))
p.sendPeersOnline(onlinePeers)
}
@@ -283,9 +274,6 @@ func (p *Peer) sendPeersOnline(peers []messages.PeerID) {
}
func (p *Peer) sendPeersWentOffline(peers []messages.PeerID) {
p.notificationMutex.Lock()
defer p.notificationMutex.Unlock()
msgs, err := messages.MarshalPeersWentOffline(peers)
if err != nil {
p.log.Errorf("failed to marshal peer location message: %s", err)

View File

@@ -86,13 +86,14 @@ func NewRelay(config Config) (*Relay, error) {
return nil, fmt.Errorf("creating app metrics: %v", err)
}
peerStore := store.NewStore()
r := &Relay{
metrics: m,
metricsCancel: metricsCancel,
validator: config.AuthValidator,
instanceURL: config.instanceURL,
store: store.NewStore(),
notifier: store.NewPeerNotifier(),
store: peerStore,
notifier: store.NewPeerNotifier(peerStore),
}
r.preparedMsg, err = newPreparedMsg(r.instanceURL)
@@ -130,18 +131,15 @@ func (r *Relay) Accept(conn net.Conn) {
peer := NewPeer(r.metrics, *peerID, conn, r.store, r.notifier)
peer.log.Infof("peer connected from: %s", conn.RemoteAddr())
storeTime := time.Now()
if isReconnection := r.store.AddPeer(peer); isReconnection {
r.metrics.RecordPeerReconnection()
}
r.store.AddPeer(peer)
r.notifier.PeerCameOnline(peer.ID())
r.metrics.RecordPeerStoreTime(time.Since(storeTime))
r.metrics.PeerConnected(peer.String())
go func() {
peer.Work()
if deleted := r.store.DeletePeer(peer); deleted {
r.notifier.PeerWentOffline(peer.ID())
}
r.notifier.PeerWentOffline(peer.ID())
r.store.DeletePeer(peer)
peer.log.Debugf("relay connection closed")
r.metrics.PeerDisconnected(peer.String())
}()

View File

@@ -7,27 +7,24 @@ import (
"github.com/netbirdio/netbird/relay/messages"
)
type event struct {
peerID messages.PeerID
online bool
}
type Listener struct {
ctx context.Context
ctx context.Context
store *Store
eventChan chan *event
onlineChan chan messages.PeerID
offlineChan chan messages.PeerID
interestedPeersForOffline map[messages.PeerID]struct{}
interestedPeersForOnline map[messages.PeerID]struct{}
mu sync.RWMutex
}
func newListener(ctx context.Context) *Listener {
func newListener(ctx context.Context, store *Store) *Listener {
l := &Listener{
ctx: ctx,
ctx: ctx,
store: store,
// important to use a single channel for offline and online events because with it we can ensure all events
// will be processed in the order they were sent
eventChan: make(chan *event, 244), //244 is the message size limit in the relay protocol
onlineChan: make(chan messages.PeerID, 244), //244 is the message size limit in the relay protocol
offlineChan: make(chan messages.PeerID, 244), //244 is the message size limit in the relay protocol
interestedPeersForOffline: make(map[messages.PeerID]struct{}),
interestedPeersForOnline: make(map[messages.PeerID]struct{}),
}
@@ -35,7 +32,8 @@ func newListener(ctx context.Context) *Listener {
return l
}
func (l *Listener) AddInterestedPeers(peerIDs []messages.PeerID) {
func (l *Listener) AddInterestedPeers(peerIDs []messages.PeerID) []messages.PeerID {
availablePeers := make([]messages.PeerID, 0)
l.mu.Lock()
defer l.mu.Unlock()
@@ -43,6 +41,17 @@ func (l *Listener) AddInterestedPeers(peerIDs []messages.PeerID) {
l.interestedPeersForOnline[id] = struct{}{}
l.interestedPeersForOffline[id] = struct{}{}
}
// collect online peers to response back to the caller
for _, id := range peerIDs {
_, ok := l.store.Peer(id)
if !ok {
continue
}
availablePeers = append(availablePeers, id)
}
return availablePeers
}
func (l *Listener) RemoveInterestedPeer(peerIDs []messages.PeerID) {
@@ -52,6 +61,7 @@ func (l *Listener) RemoveInterestedPeer(peerIDs []messages.PeerID) {
for _, id := range peerIDs {
delete(l.interestedPeersForOffline, id)
delete(l.interestedPeersForOnline, id)
}
}
@@ -60,31 +70,26 @@ func (l *Listener) listenForEvents(onPeersComeOnline, onPeersWentOffline func([]
select {
case <-l.ctx.Done():
return
case e := <-l.eventChan:
peersOffline := make([]messages.PeerID, 0)
peersOnline := make([]messages.PeerID, 0)
if e.online {
peersOnline = append(peersOnline, e.peerID)
} else {
peersOffline = append(peersOffline, e.peerID)
case pID := <-l.onlineChan:
peers := make([]messages.PeerID, 0)
peers = append(peers, pID)
for len(l.onlineChan) > 0 {
pID = <-l.onlineChan
peers = append(peers, pID)
}
// Drain the channel to collect all events
for len(l.eventChan) > 0 {
e = <-l.eventChan
if e.online {
peersOnline = append(peersOnline, e.peerID)
} else {
peersOffline = append(peersOffline, e.peerID)
}
onPeersComeOnline(peers)
case pID := <-l.offlineChan:
peers := make([]messages.PeerID, 0)
peers = append(peers, pID)
for len(l.offlineChan) > 0 {
pID = <-l.offlineChan
peers = append(peers, pID)
}
if len(peersOnline) > 0 {
onPeersComeOnline(peersOnline)
}
if len(peersOffline) > 0 {
onPeersWentOffline(peersOffline)
}
onPeersWentOffline(peers)
}
}
}
@@ -95,10 +100,7 @@ func (l *Listener) peerWentOffline(peerID messages.PeerID) {
if _, ok := l.interestedPeersForOffline[peerID]; ok {
select {
case l.eventChan <- &event{
peerID: peerID,
online: false,
}:
case l.offlineChan <- peerID:
case <-l.ctx.Done():
}
}
@@ -110,13 +112,9 @@ func (l *Listener) peerComeOnline(peerID messages.PeerID) {
if _, ok := l.interestedPeersForOnline[peerID]; ok {
select {
case l.eventChan <- &event{
peerID: peerID,
online: true,
}:
case l.onlineChan <- peerID:
case <-l.ctx.Done():
}
delete(l.interestedPeersForOnline, peerID)
}
}
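
The listener now batches per channel: it blocks on one receive, then drains whatever is already buffered and delivers the whole slice in a single callback. A self-contained sketch of that pattern, with illustrative names:

package main

import "fmt"

// drainBatch takes one item, drains anything already buffered in the channel,
// and hands the whole batch to the callback in a single call.
func drainBatch(ch chan string, onBatch func([]string)) {
	first := <-ch
	batch := []string{first}
	for len(ch) > 0 {
		batch = append(batch, <-ch)
	}
	onBatch(batch)
}

func main() {
	ch := make(chan string, 8)
	ch <- "peer-1"
	ch <- "peer-2"
	ch <- "peer-3"
	drainBatch(ch, func(b []string) { fmt.Println("online:", b) })
}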

View File

@@ -8,12 +8,15 @@ import (
)
type PeerNotifier struct {
store *Store
listeners map[*Listener]context.CancelFunc
listenersMutex sync.RWMutex
}
func NewPeerNotifier() *PeerNotifier {
func NewPeerNotifier(store *Store) *PeerNotifier {
pn := &PeerNotifier{
store: store,
listeners: make(map[*Listener]context.CancelFunc),
}
return pn
@@ -21,7 +24,7 @@ func NewPeerNotifier() *PeerNotifier {
func (pn *PeerNotifier) NewListener(onPeersComeOnline, onPeersWentOffline func([]messages.PeerID)) *Listener {
ctx, cancel := context.WithCancel(context.Background())
listener := newListener(ctx)
listener := newListener(ctx, pn.store)
go listener.listenForEvents(onPeersComeOnline, onPeersWentOffline)
pn.listenersMutex.Lock()

View File

@@ -26,9 +26,7 @@ func NewStore() *Store {
}
// AddPeer adds a peer to the store
// If the peer already exists, it will be replaced and the old peer will be closed
// Returns true if the peer was replaced, false if it was added for the first time.
func (s *Store) AddPeer(peer IPeer) bool {
func (s *Store) AddPeer(peer IPeer) {
s.peersLock.Lock()
defer s.peersLock.Unlock()
odlPeer, ok := s.peers[peer.ID()]
@@ -37,24 +35,22 @@ func (s *Store) AddPeer(peer IPeer) bool {
}
s.peers[peer.ID()] = peer
return ok
}
// DeletePeer deletes a peer from the store
func (s *Store) DeletePeer(peer IPeer) bool {
func (s *Store) DeletePeer(peer IPeer) {
s.peersLock.Lock()
defer s.peersLock.Unlock()
dp, ok := s.peers[peer.ID()]
if !ok {
return false
return
}
if dp != peer {
return false
return
}
delete(s.peers, peer.ID())
return true
}
// Peer returns a peer by its ID
@@ -77,21 +73,3 @@ func (s *Store) Peers() []IPeer {
}
return peers
}
func (s *Store) GetOnlinePeersAndRegisterInterest(peerIDs []messages.PeerID, listener *Listener) []messages.PeerID {
s.peersLock.RLock()
defer s.peersLock.RUnlock()
onlinePeers := make([]messages.PeerID, 0, len(peerIDs))
listener.AddInterestedPeers(peerIDs)
// Check for currently online peers
for _, id := range peerIDs {
if _, ok := s.peers[id]; ok {
onlinePeers = append(onlinePeers, id)
}
}
return onlinePeers
}
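
With GetOnlinePeersAndRegisterInterest removed from the store, interest registration and the online lookup both happen in Listener.AddInterestedPeers, which uses the store handle injected through newListener (see the listener, notifier, and peer hunks above). A self-contained sketch of that shape, with illustrative names rather than the real types:

package main

import (
	"fmt"
	"sync"
)

type store struct {
	mu    sync.RWMutex
	peers map[string]struct{}
}

func (s *store) has(id string) bool {
	s.mu.RLock()
	defer s.mu.RUnlock()
	_, ok := s.peers[id]
	return ok
}

type listener struct {
	store      *store
	interested map[string]struct{}
	mu         sync.Mutex
}

// AddInterestedPeers registers interest in all requested IDs and returns the
// subset that is currently present in the store, mirroring the refactored flow.
func (l *listener) AddInterestedPeers(ids []string) []string {
	l.mu.Lock()
	defer l.mu.Unlock()
	online := make([]string, 0, len(ids))
	for _, id := range ids {
		l.interested[id] = struct{}{}
		if l.store.has(id) {
			online = append(online, id)
		}
	}
	return online
}

func main() {
	s := &store{peers: map[string]struct{}{"alice": {}}}
	l := &listener{store: s, interested: map[string]struct{}{}}
	fmt.Println(l.AddInterestedPeers([]string{"alice", "bob"})) // [alice]
}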