mirror of
https://github.com/netbirdio/netbird.git
synced 2026-03-31 06:34:14 -04:00
Fix a race condition where a concurrent user-issued Up or Down command (#5418)
could interleave with a sleep/wake event causing out-of-order state transitions. The mutex now covers the full duration of each handler including the status check, the Up/Down call, and the flag update. Note: if Up or Down commands are triggered in parallel with sleep/wake events, the overall ordering of up/down/sleep/wake operations is still not guaranteed beyond what the mutex provides within the handler itself.
This commit is contained in:
80
client/internal/sleep/handler/handler.go
Normal file
80
client/internal/sleep/handler/handler.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
)
|
||||
|
||||
type Agent interface {
|
||||
Up(ctx context.Context) error
|
||||
Down(ctx context.Context) error
|
||||
Status() (internal.StatusType, error)
|
||||
}
|
||||
|
||||
type SleepHandler struct {
|
||||
agent Agent
|
||||
|
||||
mu sync.Mutex
|
||||
// sleepTriggeredDown indicates whether the sleep handler triggered the last client down, to avoid unnecessary up on wake
|
||||
sleepTriggeredDown bool
|
||||
}
|
||||
|
||||
func New(agent Agent) *SleepHandler {
|
||||
return &SleepHandler{
|
||||
agent: agent,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SleepHandler) HandleWakeUp(ctx context.Context) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if !s.sleepTriggeredDown {
|
||||
log.Info("skipping up because wasn't sleep down")
|
||||
return nil
|
||||
}
|
||||
|
||||
// avoid other wakeup runs if sleep didn't make the computer sleep
|
||||
s.sleepTriggeredDown = false
|
||||
|
||||
log.Info("running up after wake up")
|
||||
err := s.agent.Up(ctx)
|
||||
if err != nil {
|
||||
log.Errorf("running up failed: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
log.Info("running up command executed successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SleepHandler) HandleSleep(ctx context.Context) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
status, err := s.agent.Status()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if status != internal.StatusConnecting && status != internal.StatusConnected {
|
||||
log.Infof("skipping setting the agent down because status is %s", status)
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Info("running down after system started sleeping")
|
||||
|
||||
if err = s.agent.Down(ctx); err != nil {
|
||||
log.Errorf("running down failed: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
s.sleepTriggeredDown = true
|
||||
|
||||
log.Info("running down executed successfully")
|
||||
return nil
|
||||
}
|
||||
153
client/internal/sleep/handler/handler_test.go
Normal file
153
client/internal/sleep/handler/handler_test.go
Normal file
@@ -0,0 +1,153 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
)
|
||||
|
||||
type mockAgent struct {
|
||||
upErr error
|
||||
downErr error
|
||||
statusErr error
|
||||
status internal.StatusType
|
||||
upCalls int
|
||||
}
|
||||
|
||||
func (m *mockAgent) Up(_ context.Context) error {
|
||||
m.upCalls++
|
||||
return m.upErr
|
||||
}
|
||||
|
||||
func (m *mockAgent) Down(_ context.Context) error {
|
||||
return m.downErr
|
||||
}
|
||||
|
||||
func (m *mockAgent) Status() (internal.StatusType, error) {
|
||||
return m.status, m.statusErr
|
||||
}
|
||||
|
||||
func newHandler(status internal.StatusType) (*SleepHandler, *mockAgent) {
|
||||
agent := &mockAgent{status: status}
|
||||
return New(agent), agent
|
||||
}
|
||||
|
||||
func TestHandleWakeUp_SkipsWhenFlagFalse(t *testing.T) {
|
||||
h, agent := newHandler(internal.StatusIdle)
|
||||
|
||||
err := h.HandleWakeUp(context.Background())
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 0, agent.upCalls, "Up should not be called when flag is false")
|
||||
}
|
||||
|
||||
func TestHandleWakeUp_ResetsFlagBeforeUp(t *testing.T) {
|
||||
h, _ := newHandler(internal.StatusIdle)
|
||||
h.sleepTriggeredDown = true
|
||||
|
||||
// Even if Up fails, flag should be reset
|
||||
_ = h.HandleWakeUp(context.Background())
|
||||
|
||||
assert.False(t, h.sleepTriggeredDown, "flag must be reset before calling Up")
|
||||
}
|
||||
|
||||
func TestHandleWakeUp_CallsUpWhenFlagSet(t *testing.T) {
|
||||
h, agent := newHandler(internal.StatusIdle)
|
||||
h.sleepTriggeredDown = true
|
||||
|
||||
err := h.HandleWakeUp(context.Background())
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, agent.upCalls)
|
||||
assert.False(t, h.sleepTriggeredDown)
|
||||
}
|
||||
|
||||
func TestHandleWakeUp_ReturnsErrorFromUp(t *testing.T) {
|
||||
h, agent := newHandler(internal.StatusIdle)
|
||||
h.sleepTriggeredDown = true
|
||||
agent.upErr = errors.New("up failed")
|
||||
|
||||
err := h.HandleWakeUp(context.Background())
|
||||
|
||||
assert.ErrorIs(t, err, agent.upErr)
|
||||
assert.False(t, h.sleepTriggeredDown, "flag should still be reset even when Up fails")
|
||||
}
|
||||
|
||||
func TestHandleWakeUp_SecondCallIsNoOp(t *testing.T) {
|
||||
h, agent := newHandler(internal.StatusIdle)
|
||||
h.sleepTriggeredDown = true
|
||||
|
||||
_ = h.HandleWakeUp(context.Background())
|
||||
err := h.HandleWakeUp(context.Background())
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, agent.upCalls, "second wakeup should be no-op")
|
||||
}
|
||||
|
||||
func TestHandleSleep_SkipsForNonActiveStates(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
status internal.StatusType
|
||||
}{
|
||||
{"Idle", internal.StatusIdle},
|
||||
{"NeedsLogin", internal.StatusNeedsLogin},
|
||||
{"LoginFailed", internal.StatusLoginFailed},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h, _ := newHandler(tt.status)
|
||||
|
||||
err := h.HandleSleep(context.Background())
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.False(t, h.sleepTriggeredDown)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleSleep_ProceedsForActiveStates(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
status internal.StatusType
|
||||
}{
|
||||
{"Connecting", internal.StatusConnecting},
|
||||
{"Connected", internal.StatusConnected},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h, _ := newHandler(tt.status)
|
||||
|
||||
err := h.HandleSleep(context.Background())
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.True(t, h.sleepTriggeredDown)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleSleep_ReturnsErrorFromStatus(t *testing.T) {
|
||||
agent := &mockAgent{statusErr: errors.New("status error")}
|
||||
h := New(agent)
|
||||
|
||||
err := h.HandleSleep(context.Background())
|
||||
|
||||
assert.ErrorIs(t, err, agent.statusErr)
|
||||
assert.False(t, h.sleepTriggeredDown)
|
||||
}
|
||||
|
||||
func TestHandleSleep_ReturnsErrorFromDown(t *testing.T) {
|
||||
agent := &mockAgent{status: internal.StatusConnected, downErr: errors.New("down failed")}
|
||||
h := New(agent)
|
||||
|
||||
err := h.HandleSleep(context.Background())
|
||||
|
||||
assert.ErrorIs(t, err, agent.downErr)
|
||||
assert.False(t, h.sleepTriggeredDown, "flag should not be set when Down fails")
|
||||
}
|
||||
@@ -1,77 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
// NotifyOSLifecycle handles operating system lifecycle events by executing appropriate logic based on the request type.
|
||||
func (s *Server) NotifyOSLifecycle(callerCtx context.Context, req *proto.OSLifecycleRequest) (*proto.OSLifecycleResponse, error) {
|
||||
switch req.GetType() {
|
||||
case proto.OSLifecycleRequest_WAKEUP:
|
||||
return s.handleWakeUp(callerCtx)
|
||||
case proto.OSLifecycleRequest_SLEEP:
|
||||
return s.handleSleep(callerCtx)
|
||||
default:
|
||||
log.Errorf("unknown OSLifecycleRequest type: %v", req.GetType())
|
||||
}
|
||||
return &proto.OSLifecycleResponse{}, nil
|
||||
}
|
||||
|
||||
// handleWakeUp processes a wake-up event by triggering the Up command if the system was previously put to sleep.
|
||||
// It resets the sleep state and logs the process. Returns a response or an error if the Up command fails.
|
||||
func (s *Server) handleWakeUp(callerCtx context.Context) (*proto.OSLifecycleResponse, error) {
|
||||
if !s.sleepTriggeredDown.Load() {
|
||||
log.Info("skipping up because wasn't sleep down")
|
||||
return &proto.OSLifecycleResponse{}, nil
|
||||
}
|
||||
|
||||
// avoid other wakeup runs if sleep didn't make the computer sleep
|
||||
s.sleepTriggeredDown.Store(false)
|
||||
|
||||
log.Info("running up after wake up")
|
||||
_, err := s.Up(callerCtx, &proto.UpRequest{})
|
||||
if err != nil {
|
||||
log.Errorf("running up failed: %v", err)
|
||||
return &proto.OSLifecycleResponse{}, err
|
||||
}
|
||||
|
||||
log.Info("running up command executed successfully")
|
||||
return &proto.OSLifecycleResponse{}, nil
|
||||
}
|
||||
|
||||
// handleSleep handles the sleep event by initiating a "down" sequence if the system is in a connected or connecting state.
|
||||
func (s *Server) handleSleep(callerCtx context.Context) (*proto.OSLifecycleResponse, error) {
|
||||
s.mutex.Lock()
|
||||
|
||||
state := internal.CtxGetState(s.rootCtx)
|
||||
status, err := state.Status()
|
||||
if err != nil {
|
||||
s.mutex.Unlock()
|
||||
return &proto.OSLifecycleResponse{}, err
|
||||
}
|
||||
|
||||
if status != internal.StatusConnecting && status != internal.StatusConnected {
|
||||
log.Infof("skipping setting the agent down because status is %s", status)
|
||||
s.mutex.Unlock()
|
||||
return &proto.OSLifecycleResponse{}, nil
|
||||
}
|
||||
s.mutex.Unlock()
|
||||
|
||||
log.Info("running down after system started sleeping")
|
||||
|
||||
_, err = s.Down(callerCtx, &proto.DownRequest{})
|
||||
if err != nil {
|
||||
log.Errorf("running down failed: %v", err)
|
||||
return &proto.OSLifecycleResponse{}, err
|
||||
}
|
||||
|
||||
s.sleepTriggeredDown.Store(true)
|
||||
|
||||
log.Info("running down executed successfully")
|
||||
return &proto.OSLifecycleResponse{}, nil
|
||||
}
|
||||
@@ -1,219 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
func newTestServer() *Server {
|
||||
ctx := internal.CtxInitState(context.Background())
|
||||
return &Server{
|
||||
rootCtx: ctx,
|
||||
statusRecorder: peer.NewRecorder(""),
|
||||
}
|
||||
}
|
||||
|
||||
func TestNotifyOSLifecycle_WakeUp_SkipsWhenNotSleepTriggered(t *testing.T) {
|
||||
s := newTestServer()
|
||||
|
||||
// sleepTriggeredDown is false by default
|
||||
assert.False(t, s.sleepTriggeredDown.Load())
|
||||
|
||||
resp, err := s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
|
||||
Type: proto.OSLifecycleRequest_WAKEUP,
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
assert.False(t, s.sleepTriggeredDown.Load(), "flag should remain false")
|
||||
}
|
||||
|
||||
func TestNotifyOSLifecycle_Sleep_SkipsWhenStatusIdle(t *testing.T) {
|
||||
s := newTestServer()
|
||||
|
||||
state := internal.CtxGetState(s.rootCtx)
|
||||
state.Set(internal.StatusIdle)
|
||||
|
||||
resp, err := s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
|
||||
Type: proto.OSLifecycleRequest_SLEEP,
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
assert.False(t, s.sleepTriggeredDown.Load(), "flag should remain false when status is Idle")
|
||||
}
|
||||
|
||||
func TestNotifyOSLifecycle_Sleep_SkipsWhenStatusNeedsLogin(t *testing.T) {
|
||||
s := newTestServer()
|
||||
|
||||
state := internal.CtxGetState(s.rootCtx)
|
||||
state.Set(internal.StatusNeedsLogin)
|
||||
|
||||
resp, err := s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
|
||||
Type: proto.OSLifecycleRequest_SLEEP,
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
assert.False(t, s.sleepTriggeredDown.Load(), "flag should remain false when status is NeedsLogin")
|
||||
}
|
||||
|
||||
func TestNotifyOSLifecycle_Sleep_SetsFlag_WhenConnecting(t *testing.T) {
|
||||
s := newTestServer()
|
||||
|
||||
state := internal.CtxGetState(s.rootCtx)
|
||||
state.Set(internal.StatusConnecting)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
s.actCancel = cancel
|
||||
|
||||
resp, err := s.NotifyOSLifecycle(ctx, &proto.OSLifecycleRequest{
|
||||
Type: proto.OSLifecycleRequest_SLEEP,
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, resp, "handleSleep returns not nil response on success")
|
||||
assert.True(t, s.sleepTriggeredDown.Load(), "flag should be set after sleep when connecting")
|
||||
}
|
||||
|
||||
func TestNotifyOSLifecycle_Sleep_SetsFlag_WhenConnected(t *testing.T) {
|
||||
s := newTestServer()
|
||||
|
||||
state := internal.CtxGetState(s.rootCtx)
|
||||
state.Set(internal.StatusConnected)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
s.actCancel = cancel
|
||||
|
||||
resp, err := s.NotifyOSLifecycle(ctx, &proto.OSLifecycleRequest{
|
||||
Type: proto.OSLifecycleRequest_SLEEP,
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, resp, "handleSleep returns not nil response on success")
|
||||
assert.True(t, s.sleepTriggeredDown.Load(), "flag should be set after sleep when connected")
|
||||
}
|
||||
|
||||
func TestNotifyOSLifecycle_WakeUp_ResetsFlag(t *testing.T) {
|
||||
s := newTestServer()
|
||||
|
||||
// Manually set the flag to simulate prior sleep down
|
||||
s.sleepTriggeredDown.Store(true)
|
||||
|
||||
// WakeUp will try to call Up which fails without proper setup, but flag should reset first
|
||||
_, _ = s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
|
||||
Type: proto.OSLifecycleRequest_WAKEUP,
|
||||
})
|
||||
|
||||
assert.False(t, s.sleepTriggeredDown.Load(), "flag should be reset after WakeUp attempt")
|
||||
}
|
||||
|
||||
func TestNotifyOSLifecycle_MultipleWakeUpCalls(t *testing.T) {
|
||||
s := newTestServer()
|
||||
|
||||
// First wakeup without prior sleep - should be no-op
|
||||
resp, err := s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
|
||||
Type: proto.OSLifecycleRequest_WAKEUP,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
assert.False(t, s.sleepTriggeredDown.Load())
|
||||
|
||||
// Simulate prior sleep
|
||||
s.sleepTriggeredDown.Store(true)
|
||||
|
||||
// First wakeup after sleep - should reset flag
|
||||
_, _ = s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
|
||||
Type: proto.OSLifecycleRequest_WAKEUP,
|
||||
})
|
||||
assert.False(t, s.sleepTriggeredDown.Load())
|
||||
|
||||
// Second wakeup - should be no-op
|
||||
resp, err = s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
|
||||
Type: proto.OSLifecycleRequest_WAKEUP,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
assert.False(t, s.sleepTriggeredDown.Load())
|
||||
}
|
||||
|
||||
func TestHandleWakeUp_SkipsWhenFlagFalse(t *testing.T) {
|
||||
s := newTestServer()
|
||||
|
||||
resp, err := s.handleWakeUp(context.Background())
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
}
|
||||
|
||||
func TestHandleWakeUp_ResetsFlagBeforeUp(t *testing.T) {
|
||||
s := newTestServer()
|
||||
s.sleepTriggeredDown.Store(true)
|
||||
|
||||
// Even if Up fails, flag should be reset
|
||||
_, _ = s.handleWakeUp(context.Background())
|
||||
|
||||
assert.False(t, s.sleepTriggeredDown.Load(), "flag must be reset before calling Up")
|
||||
}
|
||||
|
||||
func TestHandleSleep_SkipsForNonActiveStates(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
status internal.StatusType
|
||||
}{
|
||||
{"Idle", internal.StatusIdle},
|
||||
{"NeedsLogin", internal.StatusNeedsLogin},
|
||||
{"LoginFailed", internal.StatusLoginFailed},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := newTestServer()
|
||||
state := internal.CtxGetState(s.rootCtx)
|
||||
state.Set(tt.status)
|
||||
|
||||
resp, err := s.handleSleep(context.Background())
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
assert.False(t, s.sleepTriggeredDown.Load())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleSleep_ProceedsForActiveStates(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
status internal.StatusType
|
||||
}{
|
||||
{"Connecting", internal.StatusConnecting},
|
||||
{"Connected", internal.StatusConnected},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := newTestServer()
|
||||
state := internal.CtxGetState(s.rootCtx)
|
||||
state.Set(tt.status)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
s.actCancel = cancel
|
||||
|
||||
resp, err := s.handleSleep(ctx)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, resp)
|
||||
assert.True(t, s.sleepTriggeredDown.Load())
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -22,6 +22,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/auth"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
sleephandler "github.com/netbirdio/netbird/client/internal/sleep/handler"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
@@ -85,8 +86,7 @@ type Server struct {
|
||||
profilesDisabled bool
|
||||
updateSettingsDisabled bool
|
||||
|
||||
// sleepTriggeredDown holds a state indicated if the sleep handler triggered the last client down
|
||||
sleepTriggeredDown atomic.Bool
|
||||
sleepHandler *sleephandler.SleepHandler
|
||||
|
||||
jwtCache *jwtCache
|
||||
}
|
||||
@@ -100,7 +100,7 @@ type oauthAuthFlow struct {
|
||||
|
||||
// New server instance constructor.
|
||||
func New(ctx context.Context, logFile string, configFile string, profilesDisabled bool, updateSettingsDisabled bool) *Server {
|
||||
return &Server{
|
||||
s := &Server{
|
||||
rootCtx: ctx,
|
||||
logFile: logFile,
|
||||
persistSyncResponse: true,
|
||||
@@ -110,6 +110,10 @@ func New(ctx context.Context, logFile string, configFile string, profilesDisable
|
||||
updateSettingsDisabled: updateSettingsDisabled,
|
||||
jwtCache: newJWTCache(),
|
||||
}
|
||||
agent := &serverAgent{s}
|
||||
s.sleepHandler = sleephandler.New(agent)
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Server) Start() error {
|
||||
|
||||
46
client/server/sleep.go
Normal file
46
client/server/sleep.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
// serverAgent adapts Server to the handler.Agent and handler.StatusChecker interfaces
|
||||
type serverAgent struct {
|
||||
s *Server
|
||||
}
|
||||
|
||||
func (a *serverAgent) Up(ctx context.Context) error {
|
||||
_, err := a.s.Up(ctx, &proto.UpRequest{})
|
||||
return err
|
||||
}
|
||||
|
||||
func (a *serverAgent) Down(ctx context.Context) error {
|
||||
_, err := a.s.Down(ctx, &proto.DownRequest{})
|
||||
return err
|
||||
}
|
||||
|
||||
func (a *serverAgent) Status() (internal.StatusType, error) {
|
||||
return internal.CtxGetState(a.s.rootCtx).Status()
|
||||
}
|
||||
|
||||
// NotifyOSLifecycle handles operating system lifecycle events by executing appropriate logic based on the request type.
|
||||
func (s *Server) NotifyOSLifecycle(callerCtx context.Context, req *proto.OSLifecycleRequest) (*proto.OSLifecycleResponse, error) {
|
||||
switch req.GetType() {
|
||||
case proto.OSLifecycleRequest_WAKEUP:
|
||||
if err := s.sleepHandler.HandleWakeUp(callerCtx); err != nil {
|
||||
return &proto.OSLifecycleResponse{}, err
|
||||
}
|
||||
case proto.OSLifecycleRequest_SLEEP:
|
||||
if err := s.sleepHandler.HandleSleep(callerCtx); err != nil {
|
||||
return &proto.OSLifecycleResponse{}, err
|
||||
}
|
||||
default:
|
||||
log.Errorf("unknown OSLifecycleRequest type: %v", req.GetType())
|
||||
}
|
||||
return &proto.OSLifecycleResponse{}, nil
|
||||
}
|
||||
Reference in New Issue
Block a user