mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-05 00:54:01 -04:00
Fix a race condition where a concurrent user-issued Up or Down command (#5418)
could interleave with a sleep/wake event causing out-of-order state transitions. The mutex now covers the full duration of each handler including the status check, the Up/Down call, and the flag update. Note: if Up or Down commands are triggered in parallel with sleep/wake events, the overall ordering of up/down/sleep/wake operations is still not guaranteed beyond what the mutex provides within the handler itself.
This commit is contained in:
80
client/internal/sleep/handler/handler.go
Normal file
80
client/internal/sleep/handler/handler.go
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Agent interface {
|
||||||
|
Up(ctx context.Context) error
|
||||||
|
Down(ctx context.Context) error
|
||||||
|
Status() (internal.StatusType, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type SleepHandler struct {
|
||||||
|
agent Agent
|
||||||
|
|
||||||
|
mu sync.Mutex
|
||||||
|
// sleepTriggeredDown indicates whether the sleep handler triggered the last client down, to avoid unnecessary up on wake
|
||||||
|
sleepTriggeredDown bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func New(agent Agent) *SleepHandler {
|
||||||
|
return &SleepHandler{
|
||||||
|
agent: agent,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SleepHandler) HandleWakeUp(ctx context.Context) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
if !s.sleepTriggeredDown {
|
||||||
|
log.Info("skipping up because wasn't sleep down")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// avoid other wakeup runs if sleep didn't make the computer sleep
|
||||||
|
s.sleepTriggeredDown = false
|
||||||
|
|
||||||
|
log.Info("running up after wake up")
|
||||||
|
err := s.agent.Up(ctx)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("running up failed: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info("running up command executed successfully")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SleepHandler) HandleSleep(ctx context.Context) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
status, err := s.agent.Status()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if status != internal.StatusConnecting && status != internal.StatusConnected {
|
||||||
|
log.Infof("skipping setting the agent down because status is %s", status)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info("running down after system started sleeping")
|
||||||
|
|
||||||
|
if err = s.agent.Down(ctx); err != nil {
|
||||||
|
log.Errorf("running down failed: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
s.sleepTriggeredDown = true
|
||||||
|
|
||||||
|
log.Info("running down executed successfully")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
153
client/internal/sleep/handler/handler_test.go
Normal file
153
client/internal/sleep/handler/handler_test.go
Normal file
@@ -0,0 +1,153 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
|
)
|
||||||
|
|
||||||
|
type mockAgent struct {
|
||||||
|
upErr error
|
||||||
|
downErr error
|
||||||
|
statusErr error
|
||||||
|
status internal.StatusType
|
||||||
|
upCalls int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockAgent) Up(_ context.Context) error {
|
||||||
|
m.upCalls++
|
||||||
|
return m.upErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockAgent) Down(_ context.Context) error {
|
||||||
|
return m.downErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockAgent) Status() (internal.StatusType, error) {
|
||||||
|
return m.status, m.statusErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func newHandler(status internal.StatusType) (*SleepHandler, *mockAgent) {
|
||||||
|
agent := &mockAgent{status: status}
|
||||||
|
return New(agent), agent
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleWakeUp_SkipsWhenFlagFalse(t *testing.T) {
|
||||||
|
h, agent := newHandler(internal.StatusIdle)
|
||||||
|
|
||||||
|
err := h.HandleWakeUp(context.Background())
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 0, agent.upCalls, "Up should not be called when flag is false")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleWakeUp_ResetsFlagBeforeUp(t *testing.T) {
|
||||||
|
h, _ := newHandler(internal.StatusIdle)
|
||||||
|
h.sleepTriggeredDown = true
|
||||||
|
|
||||||
|
// Even if Up fails, flag should be reset
|
||||||
|
_ = h.HandleWakeUp(context.Background())
|
||||||
|
|
||||||
|
assert.False(t, h.sleepTriggeredDown, "flag must be reset before calling Up")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleWakeUp_CallsUpWhenFlagSet(t *testing.T) {
|
||||||
|
h, agent := newHandler(internal.StatusIdle)
|
||||||
|
h.sleepTriggeredDown = true
|
||||||
|
|
||||||
|
err := h.HandleWakeUp(context.Background())
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 1, agent.upCalls)
|
||||||
|
assert.False(t, h.sleepTriggeredDown)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleWakeUp_ReturnsErrorFromUp(t *testing.T) {
|
||||||
|
h, agent := newHandler(internal.StatusIdle)
|
||||||
|
h.sleepTriggeredDown = true
|
||||||
|
agent.upErr = errors.New("up failed")
|
||||||
|
|
||||||
|
err := h.HandleWakeUp(context.Background())
|
||||||
|
|
||||||
|
assert.ErrorIs(t, err, agent.upErr)
|
||||||
|
assert.False(t, h.sleepTriggeredDown, "flag should still be reset even when Up fails")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleWakeUp_SecondCallIsNoOp(t *testing.T) {
|
||||||
|
h, agent := newHandler(internal.StatusIdle)
|
||||||
|
h.sleepTriggeredDown = true
|
||||||
|
|
||||||
|
_ = h.HandleWakeUp(context.Background())
|
||||||
|
err := h.HandleWakeUp(context.Background())
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 1, agent.upCalls, "second wakeup should be no-op")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleSleep_SkipsForNonActiveStates(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
status internal.StatusType
|
||||||
|
}{
|
||||||
|
{"Idle", internal.StatusIdle},
|
||||||
|
{"NeedsLogin", internal.StatusNeedsLogin},
|
||||||
|
{"LoginFailed", internal.StatusLoginFailed},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
h, _ := newHandler(tt.status)
|
||||||
|
|
||||||
|
err := h.HandleSleep(context.Background())
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.False(t, h.sleepTriggeredDown)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleSleep_ProceedsForActiveStates(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
status internal.StatusType
|
||||||
|
}{
|
||||||
|
{"Connecting", internal.StatusConnecting},
|
||||||
|
{"Connected", internal.StatusConnected},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
h, _ := newHandler(tt.status)
|
||||||
|
|
||||||
|
err := h.HandleSleep(context.Background())
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, h.sleepTriggeredDown)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleSleep_ReturnsErrorFromStatus(t *testing.T) {
|
||||||
|
agent := &mockAgent{statusErr: errors.New("status error")}
|
||||||
|
h := New(agent)
|
||||||
|
|
||||||
|
err := h.HandleSleep(context.Background())
|
||||||
|
|
||||||
|
assert.ErrorIs(t, err, agent.statusErr)
|
||||||
|
assert.False(t, h.sleepTriggeredDown)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleSleep_ReturnsErrorFromDown(t *testing.T) {
|
||||||
|
agent := &mockAgent{status: internal.StatusConnected, downErr: errors.New("down failed")}
|
||||||
|
h := New(agent)
|
||||||
|
|
||||||
|
err := h.HandleSleep(context.Background())
|
||||||
|
|
||||||
|
assert.ErrorIs(t, err, agent.downErr)
|
||||||
|
assert.False(t, h.sleepTriggeredDown, "flag should not be set when Down fails")
|
||||||
|
}
|
||||||
@@ -1,77 +0,0 @@
|
|||||||
package server
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
|
||||||
)
|
|
||||||
|
|
||||||
// NotifyOSLifecycle handles operating system lifecycle events by executing appropriate logic based on the request type.
|
|
||||||
func (s *Server) NotifyOSLifecycle(callerCtx context.Context, req *proto.OSLifecycleRequest) (*proto.OSLifecycleResponse, error) {
|
|
||||||
switch req.GetType() {
|
|
||||||
case proto.OSLifecycleRequest_WAKEUP:
|
|
||||||
return s.handleWakeUp(callerCtx)
|
|
||||||
case proto.OSLifecycleRequest_SLEEP:
|
|
||||||
return s.handleSleep(callerCtx)
|
|
||||||
default:
|
|
||||||
log.Errorf("unknown OSLifecycleRequest type: %v", req.GetType())
|
|
||||||
}
|
|
||||||
return &proto.OSLifecycleResponse{}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleWakeUp processes a wake-up event by triggering the Up command if the system was previously put to sleep.
|
|
||||||
// It resets the sleep state and logs the process. Returns a response or an error if the Up command fails.
|
|
||||||
func (s *Server) handleWakeUp(callerCtx context.Context) (*proto.OSLifecycleResponse, error) {
|
|
||||||
if !s.sleepTriggeredDown.Load() {
|
|
||||||
log.Info("skipping up because wasn't sleep down")
|
|
||||||
return &proto.OSLifecycleResponse{}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// avoid other wakeup runs if sleep didn't make the computer sleep
|
|
||||||
s.sleepTriggeredDown.Store(false)
|
|
||||||
|
|
||||||
log.Info("running up after wake up")
|
|
||||||
_, err := s.Up(callerCtx, &proto.UpRequest{})
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("running up failed: %v", err)
|
|
||||||
return &proto.OSLifecycleResponse{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Info("running up command executed successfully")
|
|
||||||
return &proto.OSLifecycleResponse{}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleSleep handles the sleep event by initiating a "down" sequence if the system is in a connected or connecting state.
|
|
||||||
func (s *Server) handleSleep(callerCtx context.Context) (*proto.OSLifecycleResponse, error) {
|
|
||||||
s.mutex.Lock()
|
|
||||||
|
|
||||||
state := internal.CtxGetState(s.rootCtx)
|
|
||||||
status, err := state.Status()
|
|
||||||
if err != nil {
|
|
||||||
s.mutex.Unlock()
|
|
||||||
return &proto.OSLifecycleResponse{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if status != internal.StatusConnecting && status != internal.StatusConnected {
|
|
||||||
log.Infof("skipping setting the agent down because status is %s", status)
|
|
||||||
s.mutex.Unlock()
|
|
||||||
return &proto.OSLifecycleResponse{}, nil
|
|
||||||
}
|
|
||||||
s.mutex.Unlock()
|
|
||||||
|
|
||||||
log.Info("running down after system started sleeping")
|
|
||||||
|
|
||||||
_, err = s.Down(callerCtx, &proto.DownRequest{})
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("running down failed: %v", err)
|
|
||||||
return &proto.OSLifecycleResponse{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
s.sleepTriggeredDown.Store(true)
|
|
||||||
|
|
||||||
log.Info("running down executed successfully")
|
|
||||||
return &proto.OSLifecycleResponse{}, nil
|
|
||||||
}
|
|
||||||
@@ -1,219 +0,0 @@
|
|||||||
package server
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
|
||||||
)
|
|
||||||
|
|
||||||
func newTestServer() *Server {
|
|
||||||
ctx := internal.CtxInitState(context.Background())
|
|
||||||
return &Server{
|
|
||||||
rootCtx: ctx,
|
|
||||||
statusRecorder: peer.NewRecorder(""),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNotifyOSLifecycle_WakeUp_SkipsWhenNotSleepTriggered(t *testing.T) {
|
|
||||||
s := newTestServer()
|
|
||||||
|
|
||||||
// sleepTriggeredDown is false by default
|
|
||||||
assert.False(t, s.sleepTriggeredDown.Load())
|
|
||||||
|
|
||||||
resp, err := s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
|
|
||||||
Type: proto.OSLifecycleRequest_WAKEUP,
|
|
||||||
})
|
|
||||||
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotNil(t, resp)
|
|
||||||
assert.False(t, s.sleepTriggeredDown.Load(), "flag should remain false")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNotifyOSLifecycle_Sleep_SkipsWhenStatusIdle(t *testing.T) {
|
|
||||||
s := newTestServer()
|
|
||||||
|
|
||||||
state := internal.CtxGetState(s.rootCtx)
|
|
||||||
state.Set(internal.StatusIdle)
|
|
||||||
|
|
||||||
resp, err := s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
|
|
||||||
Type: proto.OSLifecycleRequest_SLEEP,
|
|
||||||
})
|
|
||||||
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotNil(t, resp)
|
|
||||||
assert.False(t, s.sleepTriggeredDown.Load(), "flag should remain false when status is Idle")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNotifyOSLifecycle_Sleep_SkipsWhenStatusNeedsLogin(t *testing.T) {
|
|
||||||
s := newTestServer()
|
|
||||||
|
|
||||||
state := internal.CtxGetState(s.rootCtx)
|
|
||||||
state.Set(internal.StatusNeedsLogin)
|
|
||||||
|
|
||||||
resp, err := s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
|
|
||||||
Type: proto.OSLifecycleRequest_SLEEP,
|
|
||||||
})
|
|
||||||
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotNil(t, resp)
|
|
||||||
assert.False(t, s.sleepTriggeredDown.Load(), "flag should remain false when status is NeedsLogin")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNotifyOSLifecycle_Sleep_SetsFlag_WhenConnecting(t *testing.T) {
|
|
||||||
s := newTestServer()
|
|
||||||
|
|
||||||
state := internal.CtxGetState(s.rootCtx)
|
|
||||||
state.Set(internal.StatusConnecting)
|
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
defer cancel()
|
|
||||||
s.actCancel = cancel
|
|
||||||
|
|
||||||
resp, err := s.NotifyOSLifecycle(ctx, &proto.OSLifecycleRequest{
|
|
||||||
Type: proto.OSLifecycleRequest_SLEEP,
|
|
||||||
})
|
|
||||||
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.NotNil(t, resp, "handleSleep returns not nil response on success")
|
|
||||||
assert.True(t, s.sleepTriggeredDown.Load(), "flag should be set after sleep when connecting")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNotifyOSLifecycle_Sleep_SetsFlag_WhenConnected(t *testing.T) {
|
|
||||||
s := newTestServer()
|
|
||||||
|
|
||||||
state := internal.CtxGetState(s.rootCtx)
|
|
||||||
state.Set(internal.StatusConnected)
|
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
defer cancel()
|
|
||||||
s.actCancel = cancel
|
|
||||||
|
|
||||||
resp, err := s.NotifyOSLifecycle(ctx, &proto.OSLifecycleRequest{
|
|
||||||
Type: proto.OSLifecycleRequest_SLEEP,
|
|
||||||
})
|
|
||||||
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.NotNil(t, resp, "handleSleep returns not nil response on success")
|
|
||||||
assert.True(t, s.sleepTriggeredDown.Load(), "flag should be set after sleep when connected")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNotifyOSLifecycle_WakeUp_ResetsFlag(t *testing.T) {
|
|
||||||
s := newTestServer()
|
|
||||||
|
|
||||||
// Manually set the flag to simulate prior sleep down
|
|
||||||
s.sleepTriggeredDown.Store(true)
|
|
||||||
|
|
||||||
// WakeUp will try to call Up which fails without proper setup, but flag should reset first
|
|
||||||
_, _ = s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
|
|
||||||
Type: proto.OSLifecycleRequest_WAKEUP,
|
|
||||||
})
|
|
||||||
|
|
||||||
assert.False(t, s.sleepTriggeredDown.Load(), "flag should be reset after WakeUp attempt")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNotifyOSLifecycle_MultipleWakeUpCalls(t *testing.T) {
|
|
||||||
s := newTestServer()
|
|
||||||
|
|
||||||
// First wakeup without prior sleep - should be no-op
|
|
||||||
resp, err := s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
|
|
||||||
Type: proto.OSLifecycleRequest_WAKEUP,
|
|
||||||
})
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotNil(t, resp)
|
|
||||||
assert.False(t, s.sleepTriggeredDown.Load())
|
|
||||||
|
|
||||||
// Simulate prior sleep
|
|
||||||
s.sleepTriggeredDown.Store(true)
|
|
||||||
|
|
||||||
// First wakeup after sleep - should reset flag
|
|
||||||
_, _ = s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
|
|
||||||
Type: proto.OSLifecycleRequest_WAKEUP,
|
|
||||||
})
|
|
||||||
assert.False(t, s.sleepTriggeredDown.Load())
|
|
||||||
|
|
||||||
// Second wakeup - should be no-op
|
|
||||||
resp, err = s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
|
|
||||||
Type: proto.OSLifecycleRequest_WAKEUP,
|
|
||||||
})
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotNil(t, resp)
|
|
||||||
assert.False(t, s.sleepTriggeredDown.Load())
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHandleWakeUp_SkipsWhenFlagFalse(t *testing.T) {
|
|
||||||
s := newTestServer()
|
|
||||||
|
|
||||||
resp, err := s.handleWakeUp(context.Background())
|
|
||||||
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotNil(t, resp)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHandleWakeUp_ResetsFlagBeforeUp(t *testing.T) {
|
|
||||||
s := newTestServer()
|
|
||||||
s.sleepTriggeredDown.Store(true)
|
|
||||||
|
|
||||||
// Even if Up fails, flag should be reset
|
|
||||||
_, _ = s.handleWakeUp(context.Background())
|
|
||||||
|
|
||||||
assert.False(t, s.sleepTriggeredDown.Load(), "flag must be reset before calling Up")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHandleSleep_SkipsForNonActiveStates(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
status internal.StatusType
|
|
||||||
}{
|
|
||||||
{"Idle", internal.StatusIdle},
|
|
||||||
{"NeedsLogin", internal.StatusNeedsLogin},
|
|
||||||
{"LoginFailed", internal.StatusLoginFailed},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
s := newTestServer()
|
|
||||||
state := internal.CtxGetState(s.rootCtx)
|
|
||||||
state.Set(tt.status)
|
|
||||||
|
|
||||||
resp, err := s.handleSleep(context.Background())
|
|
||||||
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotNil(t, resp)
|
|
||||||
assert.False(t, s.sleepTriggeredDown.Load())
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHandleSleep_ProceedsForActiveStates(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
status internal.StatusType
|
|
||||||
}{
|
|
||||||
{"Connecting", internal.StatusConnecting},
|
|
||||||
{"Connected", internal.StatusConnected},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
s := newTestServer()
|
|
||||||
state := internal.CtxGetState(s.rootCtx)
|
|
||||||
state.Set(tt.status)
|
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
defer cancel()
|
|
||||||
s.actCancel = cancel
|
|
||||||
|
|
||||||
resp, err := s.handleSleep(ctx)
|
|
||||||
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.NotNil(t, resp)
|
|
||||||
assert.True(t, s.sleepTriggeredDown.Load())
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -22,6 +22,7 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/auth"
|
"github.com/netbirdio/netbird/client/internal/auth"
|
||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
|
sleephandler "github.com/netbirdio/netbird/client/internal/sleep/handler"
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||||
"github.com/netbirdio/netbird/shared/management/domain"
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
@@ -85,8 +86,7 @@ type Server struct {
|
|||||||
profilesDisabled bool
|
profilesDisabled bool
|
||||||
updateSettingsDisabled bool
|
updateSettingsDisabled bool
|
||||||
|
|
||||||
// sleepTriggeredDown holds a state indicated if the sleep handler triggered the last client down
|
sleepHandler *sleephandler.SleepHandler
|
||||||
sleepTriggeredDown atomic.Bool
|
|
||||||
|
|
||||||
jwtCache *jwtCache
|
jwtCache *jwtCache
|
||||||
}
|
}
|
||||||
@@ -100,7 +100,7 @@ type oauthAuthFlow struct {
|
|||||||
|
|
||||||
// New server instance constructor.
|
// New server instance constructor.
|
||||||
func New(ctx context.Context, logFile string, configFile string, profilesDisabled bool, updateSettingsDisabled bool) *Server {
|
func New(ctx context.Context, logFile string, configFile string, profilesDisabled bool, updateSettingsDisabled bool) *Server {
|
||||||
return &Server{
|
s := &Server{
|
||||||
rootCtx: ctx,
|
rootCtx: ctx,
|
||||||
logFile: logFile,
|
logFile: logFile,
|
||||||
persistSyncResponse: true,
|
persistSyncResponse: true,
|
||||||
@@ -110,6 +110,10 @@ func New(ctx context.Context, logFile string, configFile string, profilesDisable
|
|||||||
updateSettingsDisabled: updateSettingsDisabled,
|
updateSettingsDisabled: updateSettingsDisabled,
|
||||||
jwtCache: newJWTCache(),
|
jwtCache: newJWTCache(),
|
||||||
}
|
}
|
||||||
|
agent := &serverAgent{s}
|
||||||
|
s.sleepHandler = sleephandler.New(agent)
|
||||||
|
|
||||||
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) Start() error {
|
func (s *Server) Start() error {
|
||||||
|
|||||||
46
client/server/sleep.go
Normal file
46
client/server/sleep.go
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
// serverAgent adapts Server to the handler.Agent and handler.StatusChecker interfaces
|
||||||
|
type serverAgent struct {
|
||||||
|
s *Server
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *serverAgent) Up(ctx context.Context) error {
|
||||||
|
_, err := a.s.Up(ctx, &proto.UpRequest{})
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *serverAgent) Down(ctx context.Context) error {
|
||||||
|
_, err := a.s.Down(ctx, &proto.DownRequest{})
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *serverAgent) Status() (internal.StatusType, error) {
|
||||||
|
return internal.CtxGetState(a.s.rootCtx).Status()
|
||||||
|
}
|
||||||
|
|
||||||
|
// NotifyOSLifecycle handles operating system lifecycle events by executing appropriate logic based on the request type.
|
||||||
|
func (s *Server) NotifyOSLifecycle(callerCtx context.Context, req *proto.OSLifecycleRequest) (*proto.OSLifecycleResponse, error) {
|
||||||
|
switch req.GetType() {
|
||||||
|
case proto.OSLifecycleRequest_WAKEUP:
|
||||||
|
if err := s.sleepHandler.HandleWakeUp(callerCtx); err != nil {
|
||||||
|
return &proto.OSLifecycleResponse{}, err
|
||||||
|
}
|
||||||
|
case proto.OSLifecycleRequest_SLEEP:
|
||||||
|
if err := s.sleepHandler.HandleSleep(callerCtx); err != nil {
|
||||||
|
return &proto.OSLifecycleResponse{}, err
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
log.Errorf("unknown OSLifecycleRequest type: %v", req.GetType())
|
||||||
|
}
|
||||||
|
return &proto.OSLifecycleResponse{}, nil
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user