mirror of
https://github.com/netbirdio/netbird.git
synced 2026-03-31 06:24:18 -04:00
Compare commits
36 Commits
v0.62.0
...
feat/auto-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
74eebeb95a | ||
|
|
8aa1b23a22 | ||
|
|
030ddae51e | ||
|
|
6eee52b56e | ||
|
|
9313b49625 | ||
|
|
18f884f769 | ||
|
|
1354096c4d | ||
|
|
cd19f4d910 | ||
|
|
bab5cd4b41 | ||
|
|
7d846bf9ba | ||
|
|
6200aaf0b0 | ||
|
|
7fa926d397 | ||
|
|
9ae48a062a | ||
|
|
582ff1ff8c | ||
|
|
5556ff36af | ||
|
|
d5ea408cb3 | ||
|
|
436d74094b | ||
|
|
b37ba44015 | ||
|
|
0d2ce56e12 | ||
|
|
723c418966 | ||
|
|
e04b989a12 | ||
|
|
b070304d46 | ||
|
|
ad3985ac63 | ||
|
|
50423399f2 | ||
|
|
02afd4e849 | ||
|
|
d19f829f65 | ||
|
|
ec47a84afe | ||
|
|
ecf1e9013e | ||
|
|
6025eb1962 | ||
|
|
59ae92cf8f | ||
|
|
d2e198bd76 | ||
|
|
58d48127e0 | ||
|
|
84501a3f56 | ||
|
|
762b9b7b56 | ||
|
|
c6328788ca | ||
|
|
bc59749859 |
@@ -25,6 +25,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
cProto "github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
@@ -34,7 +35,6 @@ import (
|
||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||
signal "github.com/netbirdio/netbird/shared/signal/client"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
@@ -280,6 +280,10 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
return wrapErr(err)
|
||||
}
|
||||
|
||||
if loginResp.PeerConfig != nil && loginResp.PeerConfig.AutoUpdate != nil {
|
||||
c.engine.InitialUpdateHandling(loginResp.PeerConfig.AutoUpdate)
|
||||
}
|
||||
|
||||
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
|
||||
state.Set(StatusConnected)
|
||||
|
||||
|
||||
@@ -50,6 +50,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/updatemanager"
|
||||
cProto "github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
|
||||
@@ -75,6 +76,7 @@ const (
|
||||
PeerConnectionTimeoutMax = 45000 // ms
|
||||
PeerConnectionTimeoutMin = 30000 // ms
|
||||
connInitLimit = 200
|
||||
disableAutoUpdate = "disabled"
|
||||
)
|
||||
|
||||
var ErrResetConnection = fmt.Errorf("reset connection")
|
||||
@@ -199,6 +201,9 @@ type Engine struct {
|
||||
connSemaphore *semaphoregroup.SemaphoreGroup
|
||||
flowManager nftypes.FlowManager
|
||||
|
||||
// auto-update
|
||||
updateManager *updatemanager.UpdateManager
|
||||
|
||||
// WireGuard interface monitor
|
||||
wgIfaceMonitor *WGIfaceMonitor
|
||||
wgIfaceMonitorWg sync.WaitGroup
|
||||
@@ -314,6 +319,10 @@ func (e *Engine) Stop() error {
|
||||
e.srWatcher.Close()
|
||||
}
|
||||
|
||||
if e.updateManager != nil {
|
||||
e.updateManager.Stop()
|
||||
}
|
||||
|
||||
e.statusRecorder.ReplaceOfflinePeers([]peer.State{})
|
||||
e.statusRecorder.UpdateDNSStates([]peer.NSGroupState{})
|
||||
e.statusRecorder.UpdateRelayStates([]relay.ProbeResult{})
|
||||
@@ -500,6 +509,19 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Engine) InitialUpdateHandling(autoUpdateSettings *mgmProto.AutoUpdateSettings) {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
|
||||
if e.updateManager == nil {
|
||||
e.updateManager = updatemanager.NewUpdateManager(e.statusRecorder, e.stateManager)
|
||||
}
|
||||
|
||||
e.updateManager.CheckUpdateSuccess(e.ctx)
|
||||
|
||||
e.handleAutoUpdateVersion(autoUpdateSettings, true)
|
||||
}
|
||||
|
||||
func (e *Engine) createFirewall() error {
|
||||
if e.config.DisableFirewall {
|
||||
log.Infof("firewall is disabled")
|
||||
@@ -712,10 +734,44 @@ func (e *Engine) PopulateNetbirdConfig(netbirdConfig *mgmProto.NetbirdConfig, mg
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Engine) handleAutoUpdateVersion(autoUpdateSettings *mgmProto.AutoUpdateSettings, initialCheck bool) {
|
||||
if autoUpdateSettings == nil {
|
||||
return
|
||||
}
|
||||
|
||||
disabled := autoUpdateSettings.Version == disableAutoUpdate
|
||||
|
||||
// Stop and cleanup if disabled
|
||||
if e.updateManager != nil && disabled {
|
||||
log.Infof("auto-update is disabled, stopping update manager")
|
||||
e.updateManager.Stop()
|
||||
e.updateManager = nil
|
||||
return
|
||||
}
|
||||
|
||||
// Skip check unless AlwaysUpdate is enabled or this is the initial check at startup
|
||||
if !autoUpdateSettings.AlwaysUpdate && !initialCheck {
|
||||
log.Debugf("skipping auto-update check, AlwaysUpdate is false and this is not the initial check")
|
||||
return
|
||||
}
|
||||
|
||||
// Start manager if needed
|
||||
if e.updateManager == nil {
|
||||
log.Infof("starting auto-update manager")
|
||||
e.updateManager = updatemanager.NewUpdateManager(e.statusRecorder, e.stateManager)
|
||||
}
|
||||
e.updateManager.Start(e.ctx)
|
||||
log.Infof("handling auto-update version: %s", autoUpdateSettings.Version)
|
||||
e.updateManager.SetVersion(autoUpdateSettings.Version)
|
||||
}
|
||||
|
||||
func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
|
||||
if update.NetworkMap != nil && update.NetworkMap.PeerConfig != nil {
|
||||
e.handleAutoUpdateVersion(update.NetworkMap.PeerConfig.AutoUpdate, false)
|
||||
}
|
||||
if update.GetNetbirdConfig() != nil {
|
||||
wCfg := update.GetNetbirdConfig()
|
||||
err := e.updateTURNs(wCfg.GetTurns())
|
||||
@@ -1386,16 +1442,9 @@ func (e *Engine) receiveSignalEvents() {
|
||||
|
||||
switch msg.GetBody().Type {
|
||||
case sProto.Body_OFFER, sProto.Body_ANSWER:
|
||||
offerAnswer, err := convertToOfferAnswer(msg)
|
||||
if err != nil {
|
||||
if err := e.handleOfferAnswer(msg, conn); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if msg.Body.Type == sProto.Body_OFFER {
|
||||
conn.OnRemoteOffer(*offerAnswer)
|
||||
} else {
|
||||
conn.OnRemoteAnswer(*offerAnswer)
|
||||
}
|
||||
case sProto.Body_CANDIDATE:
|
||||
candidate, err := ice.UnmarshalCandidate(msg.GetBody().Payload)
|
||||
if err != nil {
|
||||
@@ -1423,6 +1472,20 @@ func (e *Engine) receiveSignalEvents() {
|
||||
e.signal.WaitStreamConnected()
|
||||
}
|
||||
|
||||
func (e *Engine) handleOfferAnswer(msg *sProto.Message, conn *peer.Conn) error {
|
||||
offerAnswer, err := convertToOfferAnswer(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if msg.Body.Type == sProto.Body_OFFER {
|
||||
conn.OnRemoteOffer(*offerAnswer)
|
||||
} else {
|
||||
conn.OnRemoteAnswer(*offerAnswer)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Engine) parseNATExternalIPMappings() []string {
|
||||
var mappedIPs []string
|
||||
var ignoredIFaces = make(map[string]interface{})
|
||||
|
||||
388
client/internal/updatemanager/manager.go
Normal file
388
client/internal/updatemanager/manager.go
Normal file
@@ -0,0 +1,388 @@
|
||||
package updatemanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
v "github.com/hashicorp/go-version"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
cProto "github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
const (
|
||||
latestVersion = "latest"
|
||||
// this version will be ignored
|
||||
developmentVersion = "development"
|
||||
)
|
||||
|
||||
type UpdateInterface interface {
|
||||
StopWatch()
|
||||
SetDaemonVersion(newVersion string) bool
|
||||
SetOnUpdateListener(updateFn func())
|
||||
LatestVersion() *v.Version
|
||||
StartFetcher()
|
||||
}
|
||||
|
||||
type UpdateState struct {
|
||||
PreUpdateVersion string
|
||||
TargetVersion string
|
||||
}
|
||||
|
||||
func (u UpdateState) Name() string {
|
||||
return "autoUpdate"
|
||||
}
|
||||
|
||||
type UpdateManager struct {
|
||||
statusRecorder *peer.Status
|
||||
stateManager *statemanager.Manager
|
||||
|
||||
lastTrigger time.Time
|
||||
mgmUpdateChan chan struct{}
|
||||
updateChannel chan struct{}
|
||||
currentVersion string
|
||||
update UpdateInterface
|
||||
wg sync.WaitGroup
|
||||
|
||||
cancel context.CancelFunc
|
||||
|
||||
expectedVersion *v.Version
|
||||
updateToLatestVersion bool
|
||||
|
||||
// updateMutex protect update and expectedVersion fields
|
||||
updateMutex sync.Mutex
|
||||
|
||||
// updateFunc is used for testing to mock the triggerUpdate behavior
|
||||
updateFunc func(ctx context.Context, targetVersion string) error
|
||||
}
|
||||
|
||||
func NewUpdateManager(statusRecorder *peer.Status, stateManager *statemanager.Manager) *UpdateManager {
|
||||
manager := &UpdateManager{
|
||||
statusRecorder: statusRecorder,
|
||||
stateManager: stateManager,
|
||||
mgmUpdateChan: make(chan struct{}, 1),
|
||||
updateChannel: make(chan struct{}, 1),
|
||||
currentVersion: version.NetbirdVersion(),
|
||||
update: version.NewUpdate("nb/client"),
|
||||
}
|
||||
|
||||
return manager
|
||||
}
|
||||
|
||||
// CheckUpdateSuccess checks if the update was successful. It works without to start the update manager.
|
||||
func (u *UpdateManager) CheckUpdateSuccess(ctx context.Context) {
|
||||
u.updateStateManager(ctx)
|
||||
}
|
||||
|
||||
func (u *UpdateManager) Start(ctx context.Context) {
|
||||
if u.cancel != nil {
|
||||
log.Errorf("UpdateManager already started")
|
||||
return
|
||||
}
|
||||
|
||||
u.update.SetDaemonVersion(u.currentVersion)
|
||||
u.update.SetOnUpdateListener(func() {
|
||||
select {
|
||||
case u.updateChannel <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
})
|
||||
go u.update.StartFetcher()
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
u.cancel = cancel
|
||||
|
||||
u.wg.Add(1)
|
||||
go u.updateLoop(ctx)
|
||||
}
|
||||
|
||||
func (u *UpdateManager) SetVersion(expectedVersion string) {
|
||||
log.Infof("set expected agent version for upgrade: %s", expectedVersion)
|
||||
if u.cancel == nil {
|
||||
log.Errorf("UpdateManager not started")
|
||||
return
|
||||
}
|
||||
|
||||
u.updateMutex.Lock()
|
||||
defer u.updateMutex.Unlock()
|
||||
if expectedVersion == latestVersion {
|
||||
u.updateToLatestVersion = true
|
||||
u.expectedVersion = nil
|
||||
} else {
|
||||
expectedSemVer, err := v.NewVersion(expectedVersion)
|
||||
if err != nil {
|
||||
log.Errorf("Error parsing version: %v", err)
|
||||
return
|
||||
}
|
||||
if u.expectedVersion != nil && u.expectedVersion.Equal(expectedSemVer) {
|
||||
return
|
||||
}
|
||||
u.expectedVersion = expectedSemVer
|
||||
u.updateToLatestVersion = false
|
||||
}
|
||||
|
||||
select {
|
||||
case u.mgmUpdateChan <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (u *UpdateManager) Stop() {
|
||||
if u.cancel == nil {
|
||||
return
|
||||
}
|
||||
|
||||
u.cancel()
|
||||
u.updateMutex.Lock()
|
||||
if u.update != nil {
|
||||
u.update.StopWatch()
|
||||
u.update = nil
|
||||
}
|
||||
u.updateMutex.Unlock()
|
||||
|
||||
u.wg.Wait()
|
||||
}
|
||||
|
||||
func (u *UpdateManager) onContextCancel() {
|
||||
if u.cancel == nil {
|
||||
return
|
||||
}
|
||||
|
||||
u.updateMutex.Lock()
|
||||
defer u.updateMutex.Unlock()
|
||||
if u.update != nil {
|
||||
u.update.StopWatch()
|
||||
u.update = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (u *UpdateManager) updateLoop(ctx context.Context) {
|
||||
defer u.wg.Done()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
u.onContextCancel()
|
||||
return
|
||||
case <-u.mgmUpdateChan:
|
||||
case <-u.updateChannel:
|
||||
log.Infof("fetched new version info")
|
||||
}
|
||||
|
||||
u.handleUpdate(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
func (u *UpdateManager) handleUpdate(ctx context.Context) {
|
||||
var updateVersion *v.Version
|
||||
|
||||
u.updateMutex.Lock()
|
||||
if u.update == nil {
|
||||
u.updateMutex.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
expectedVersion := u.expectedVersion
|
||||
useLatest := u.updateToLatestVersion
|
||||
curLatestVersion := u.update.LatestVersion()
|
||||
u.updateMutex.Unlock()
|
||||
|
||||
switch {
|
||||
// Resolve "latest" to actual version
|
||||
case useLatest:
|
||||
if curLatestVersion == nil {
|
||||
log.Tracef("latest version not fetched yet")
|
||||
return
|
||||
}
|
||||
updateVersion = curLatestVersion
|
||||
// Update to specific version
|
||||
case expectedVersion != nil:
|
||||
updateVersion = expectedVersion
|
||||
default:
|
||||
log.Debugf("no expected version information set")
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("checking update option, current version: %s, target version: %s", u.currentVersion, updateVersion)
|
||||
if !u.shouldUpdate(updateVersion) {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, time.Minute)
|
||||
defer cancel()
|
||||
|
||||
u.lastTrigger = time.Now()
|
||||
log.Debugf("Auto-update triggered, current version: %s, target version: %s", u.currentVersion, updateVersion)
|
||||
u.statusRecorder.PublishEvent(
|
||||
cProto.SystemEvent_INFO,
|
||||
cProto.SystemEvent_SYSTEM,
|
||||
"Automatically updating client",
|
||||
"Your client version is older than auto-update version set in Management, updating client now.",
|
||||
nil,
|
||||
)
|
||||
|
||||
u.statusRecorder.PublishEvent(
|
||||
cProto.SystemEvent_INFO,
|
||||
cProto.SystemEvent_SYSTEM,
|
||||
"",
|
||||
"",
|
||||
map[string]string{"progress_window": "show"},
|
||||
)
|
||||
|
||||
updateState := UpdateState{
|
||||
PreUpdateVersion: u.currentVersion,
|
||||
TargetVersion: updateVersion.String(),
|
||||
}
|
||||
|
||||
if err := u.stateManager.UpdateState(updateState); err != nil {
|
||||
log.Warnf("failed to update state: %v", err)
|
||||
} else {
|
||||
if err = u.stateManager.PersistState(ctx); err != nil {
|
||||
log.Warnf("failed to persist state: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := u.triggerUpdate(ctx, updateVersion.String()); err != nil {
|
||||
log.Errorf("Error triggering auto-update: %v", err)
|
||||
u.statusRecorder.PublishEvent(
|
||||
cProto.SystemEvent_ERROR,
|
||||
cProto.SystemEvent_SYSTEM,
|
||||
"Auto-update failed",
|
||||
fmt.Sprintf("Auto-update failed: %v", err),
|
||||
nil,
|
||||
)
|
||||
u.statusRecorder.PublishEvent(
|
||||
cProto.SystemEvent_INFO,
|
||||
cProto.SystemEvent_SYSTEM,
|
||||
"",
|
||||
"",
|
||||
map[string]string{"progress_window": "hide"},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func (u *UpdateManager) updateStateManager(ctx context.Context) {
|
||||
stateType := &UpdateState{}
|
||||
|
||||
u.stateManager.RegisterState(stateType)
|
||||
if err := u.stateManager.LoadState(stateType); err != nil {
|
||||
log.Errorf("failed to load state: %v", err)
|
||||
return
|
||||
}
|
||||
state := u.stateManager.GetState(stateType)
|
||||
if state == nil {
|
||||
return
|
||||
}
|
||||
|
||||
updateState, ok := state.(*UpdateState)
|
||||
if !ok {
|
||||
log.Errorf("failed to cast state to UpdateState")
|
||||
return
|
||||
}
|
||||
log.Debugf("autoUpdate state loaded, %v", *updateState)
|
||||
if updateState.TargetVersion == u.currentVersion {
|
||||
log.Infof("published notification event")
|
||||
u.statusRecorder.PublishEvent(
|
||||
cProto.SystemEvent_INFO,
|
||||
cProto.SystemEvent_SYSTEM,
|
||||
"Auto-update completed",
|
||||
fmt.Sprintf("Your NetBird Client was auto-updated to version %s", u.currentVersion),
|
||||
nil,
|
||||
)
|
||||
}
|
||||
if err := u.stateManager.DeleteState(updateState); err != nil {
|
||||
log.Errorf("failed to delete state: %v", err)
|
||||
} else if err = u.stateManager.PersistState(ctx); err != nil {
|
||||
log.Errorf("failed to persist state: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (u *UpdateManager) shouldUpdate(updateVersion *v.Version) bool {
|
||||
if u.currentVersion == developmentVersion {
|
||||
log.Debugf("skipping auto-update, running development version")
|
||||
return false
|
||||
}
|
||||
currentVersion, err := v.NewVersion(u.currentVersion)
|
||||
if err != nil {
|
||||
log.Errorf("error checking for update, error parsing version `%s`: %v", u.currentVersion, err)
|
||||
return false
|
||||
}
|
||||
if currentVersion.GreaterThanOrEqual(updateVersion) {
|
||||
log.Infof("current version (%s) is equal to or higher than auto-update version (%s)", u.currentVersion, updateVersion)
|
||||
return false
|
||||
}
|
||||
|
||||
if time.Since(u.lastTrigger) < 5*time.Minute {
|
||||
log.Debugf("skipping auto-update, last update was %s ago", time.Since(u.lastTrigger))
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func downloadFileToTemporaryDir(ctx context.Context, fileURL string) (string, error) { //nolint:unused
|
||||
tempDir, err := os.MkdirTemp("", "netbird-installer-*")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error creating temporary directory: %w", err)
|
||||
}
|
||||
|
||||
// Clean up temp directory on error
|
||||
var success bool
|
||||
defer func() {
|
||||
if !success {
|
||||
if err := os.RemoveAll(tempDir); err != nil {
|
||||
log.Errorf("error cleaning up temporary directory: %v", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
fileNameParts := strings.Split(fileURL, "/")
|
||||
out, err := os.Create(filepath.Join(tempDir, fileNameParts[len(fileNameParts)-1]))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error creating temporary file: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := out.Close(); err != nil {
|
||||
log.Errorf("error closing temporary file: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", fileURL, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error creating file download request: %w", err)
|
||||
}
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error downloading file: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := resp.Body.Close(); err != nil {
|
||||
log.Errorf("Error closing response body: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
log.Errorf("error downloading update file, received status code: %d", resp.StatusCode)
|
||||
return "", fmt.Errorf("error downloading file, received status code: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
_, err = io.Copy(out, resp.Body)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error downloading file: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("downloaded update file to %s", out.Name())
|
||||
|
||||
success = true // Mark success to prevent cleanup
|
||||
return out.Name(), nil
|
||||
}
|
||||
213
client/internal/updatemanager/manager_test.go
Normal file
213
client/internal/updatemanager/manager_test.go
Normal file
@@ -0,0 +1,213 @@
|
||||
package updatemanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
v "github.com/hashicorp/go-version"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
"path"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func (u *UpdateManager) WithCustomVersionUpdate(versionUpdate UpdateInterface) *UpdateManager {
|
||||
u.update = versionUpdate
|
||||
return u
|
||||
}
|
||||
|
||||
type versionUpdateMock struct {
|
||||
latestVersion *v.Version
|
||||
onUpdate func()
|
||||
}
|
||||
|
||||
func (v versionUpdateMock) StopWatch() {}
|
||||
|
||||
func (v versionUpdateMock) SetDaemonVersion(newVersion string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (v *versionUpdateMock) SetOnUpdateListener(updateFn func()) {
|
||||
v.onUpdate = updateFn
|
||||
}
|
||||
|
||||
func (v versionUpdateMock) LatestVersion() *v.Version {
|
||||
return v.latestVersion
|
||||
}
|
||||
|
||||
func (v versionUpdateMock) StartFetcher() {}
|
||||
|
||||
func Test_LatestVersion(t *testing.T) {
|
||||
testMatrix := []struct {
|
||||
name string
|
||||
daemonVersion string
|
||||
initialLatestVersion *v.Version
|
||||
latestVersion *v.Version
|
||||
shouldUpdateInit bool
|
||||
shouldUpdateLater bool
|
||||
}{
|
||||
{
|
||||
name: "Should only trigger update once due to time between triggers being < 5 Minutes",
|
||||
daemonVersion: "1.0.0",
|
||||
initialLatestVersion: v.Must(v.NewSemver("1.0.1")),
|
||||
latestVersion: v.Must(v.NewSemver("1.0.2")),
|
||||
shouldUpdateInit: true,
|
||||
shouldUpdateLater: false,
|
||||
},
|
||||
{
|
||||
name: "Shouldn't update initially, but should update as soon as latest version is fetched",
|
||||
daemonVersion: "1.0.0",
|
||||
initialLatestVersion: nil,
|
||||
latestVersion: v.Must(v.NewSemver("1.0.1")),
|
||||
shouldUpdateInit: false,
|
||||
shouldUpdateLater: true,
|
||||
},
|
||||
}
|
||||
|
||||
for idx, c := range testMatrix {
|
||||
mockUpdate := &versionUpdateMock{latestVersion: c.initialLatestVersion}
|
||||
tmpFile := path.Join(t.TempDir(), fmt.Sprintf("update-test-%d.json", idx))
|
||||
m := NewUpdateManager(peer.NewRecorder(""), statemanager.New(tmpFile)).WithCustomVersionUpdate(mockUpdate)
|
||||
|
||||
targetVersionChan := make(chan string, 1)
|
||||
|
||||
m.updateFunc = func(ctx context.Context, targetVersion string) error {
|
||||
targetVersionChan <- targetVersion
|
||||
return nil
|
||||
}
|
||||
m.currentVersion = c.daemonVersion
|
||||
m.Start(context.Background())
|
||||
m.SetVersion("latest")
|
||||
var triggeredInit bool
|
||||
select {
|
||||
case targetVersion := <-targetVersionChan:
|
||||
if targetVersion != c.initialLatestVersion.String() {
|
||||
t.Errorf("%s: Initial update version mismatch, expected %v, got %v", c.name, c.initialLatestVersion.String(), targetVersion)
|
||||
}
|
||||
triggeredInit = true
|
||||
case <-time.After(10 * time.Millisecond):
|
||||
triggeredInit = false
|
||||
}
|
||||
if triggeredInit != c.shouldUpdateInit {
|
||||
t.Errorf("%s: Initial update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdateInit, triggeredInit)
|
||||
}
|
||||
|
||||
mockUpdate.latestVersion = c.latestVersion
|
||||
mockUpdate.onUpdate()
|
||||
|
||||
var triggeredLater bool
|
||||
select {
|
||||
case targetVersion := <-targetVersionChan:
|
||||
if targetVersion != c.latestVersion.String() {
|
||||
t.Errorf("%s: Update version mismatch, expected %v, got %v", c.name, c.latestVersion.String(), targetVersion)
|
||||
}
|
||||
triggeredLater = true
|
||||
case <-time.After(10 * time.Millisecond):
|
||||
triggeredLater = false
|
||||
}
|
||||
if triggeredLater != c.shouldUpdateLater {
|
||||
t.Errorf("%s: Update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdateLater, triggeredLater)
|
||||
}
|
||||
|
||||
m.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
func Test_HandleUpdate(t *testing.T) {
|
||||
testMatrix := []struct {
|
||||
name string
|
||||
daemonVersion string
|
||||
latestVersion *v.Version
|
||||
expectedVersion string
|
||||
shouldUpdate bool
|
||||
}{
|
||||
{
|
||||
name: "Update to a specific version should update regardless of if latestVersion is available yet",
|
||||
daemonVersion: "0.55.0",
|
||||
latestVersion: nil,
|
||||
expectedVersion: "0.56.0",
|
||||
shouldUpdate: true,
|
||||
},
|
||||
{
|
||||
name: "Update to specific version should not update if version matches",
|
||||
daemonVersion: "0.55.0",
|
||||
latestVersion: nil,
|
||||
expectedVersion: "0.55.0",
|
||||
shouldUpdate: false,
|
||||
},
|
||||
{
|
||||
name: "Update to specific version should not update if current version is newer",
|
||||
daemonVersion: "0.55.0",
|
||||
latestVersion: nil,
|
||||
expectedVersion: "0.54.0",
|
||||
shouldUpdate: false,
|
||||
},
|
||||
{
|
||||
name: "Update to latest version should update if latest is newer",
|
||||
daemonVersion: "0.55.0",
|
||||
latestVersion: v.Must(v.NewSemver("0.56.0")),
|
||||
expectedVersion: "latest",
|
||||
shouldUpdate: true,
|
||||
},
|
||||
{
|
||||
name: "Update to latest version should not update if latest == current",
|
||||
daemonVersion: "0.56.0",
|
||||
latestVersion: v.Must(v.NewSemver("0.56.0")),
|
||||
expectedVersion: "latest",
|
||||
shouldUpdate: false,
|
||||
},
|
||||
{
|
||||
name: "Should not update if daemon version is invalid",
|
||||
daemonVersion: "development",
|
||||
latestVersion: v.Must(v.NewSemver("1.0.0")),
|
||||
expectedVersion: "latest",
|
||||
shouldUpdate: false,
|
||||
},
|
||||
{
|
||||
name: "Should not update if expecting latest and latest version is unavailable",
|
||||
daemonVersion: "0.55.0",
|
||||
latestVersion: nil,
|
||||
expectedVersion: "latest",
|
||||
shouldUpdate: false,
|
||||
},
|
||||
{
|
||||
name: "Should not update if expected version is invalid",
|
||||
daemonVersion: "0.55.0",
|
||||
latestVersion: nil,
|
||||
expectedVersion: "development",
|
||||
shouldUpdate: false,
|
||||
},
|
||||
}
|
||||
for idx, c := range testMatrix {
|
||||
tmpFile := path.Join(t.TempDir(), fmt.Sprintf("update-test-%d.json", idx))
|
||||
m := NewUpdateManager(peer.NewRecorder(""), statemanager.New(tmpFile)).WithCustomVersionUpdate(&versionUpdateMock{latestVersion: c.latestVersion})
|
||||
targetVersionChan := make(chan string, 1)
|
||||
|
||||
m.updateFunc = func(ctx context.Context, targetVersion string) error {
|
||||
targetVersionChan <- targetVersion
|
||||
return nil
|
||||
}
|
||||
|
||||
m.currentVersion = c.daemonVersion
|
||||
m.Start(context.Background())
|
||||
m.SetVersion(c.expectedVersion)
|
||||
|
||||
var updateTriggered bool
|
||||
select {
|
||||
case targetVersion := <-targetVersionChan:
|
||||
if c.expectedVersion == "latest" && targetVersion != c.latestVersion.String() {
|
||||
t.Errorf("%s: Update version mismatch, expected %v, got %v", c.name, c.latestVersion.String(), targetVersion)
|
||||
} else if c.expectedVersion != "latest" && targetVersion != c.expectedVersion {
|
||||
t.Errorf("%s: Update version mismatch, expected %v, got %v", c.name, c.expectedVersion, targetVersion)
|
||||
}
|
||||
updateTriggered = true
|
||||
case <-time.After(10 * time.Millisecond):
|
||||
updateTriggered = false
|
||||
}
|
||||
|
||||
if updateTriggered != c.shouldUpdate {
|
||||
t.Errorf("%s: Update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdate, updateTriggered)
|
||||
}
|
||||
m.Stop()
|
||||
}
|
||||
}
|
||||
123
client/internal/updatemanager/update_darwin.go
Normal file
123
client/internal/updatemanager/update_darwin.go
Normal file
@@ -0,0 +1,123 @@
|
||||
//go:build darwin
|
||||
|
||||
package updatemanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"runtime"
|
||||
"strings"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
const (
|
||||
pkgDownloadURL = "https://github.com/netbirdio/netbird/releases/download/v%version/netbird_%version_darwin_%arch.pkg"
|
||||
)
|
||||
|
||||
func (u *UpdateManager) triggerUpdate(ctx context.Context, targetVersion string) error {
|
||||
// Use test function if set (for testing only)
|
||||
if u.updateFunc != nil {
|
||||
return u.updateFunc(ctx, targetVersion)
|
||||
}
|
||||
|
||||
cmd := exec.CommandContext(ctx, "pkgutil", "--pkg-info", "io.netbird.client")
|
||||
outBytes, err := cmd.Output()
|
||||
if err != nil && cmd.ProcessState.ExitCode() == 1 {
|
||||
// Not installed using pkg file, thus installed using Homebrew
|
||||
|
||||
return updateHomeBrew(ctx)
|
||||
}
|
||||
// Installed using pkg file
|
||||
path, err := downloadFileToTemporaryDir(ctx, urlWithVersionArch(targetVersion))
|
||||
if err != nil {
|
||||
return fmt.Errorf("error downloading update file: %w", err)
|
||||
}
|
||||
|
||||
volume := "/"
|
||||
for _, v := range strings.Split(string(outBytes), "\n") {
|
||||
trimmed := strings.TrimSpace(v)
|
||||
if strings.HasPrefix(trimmed, "volume: ") {
|
||||
volume = strings.Split(trimmed, ": ")[1]
|
||||
}
|
||||
}
|
||||
|
||||
cmd = exec.CommandContext(ctx, "installer", "-pkg", path, "-target", volume)
|
||||
|
||||
err = cmd.Start()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error running pkg file: %w", err)
|
||||
}
|
||||
err = cmd.Process.Release()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func updateHomeBrew(ctx context.Context) error {
|
||||
// Homebrew must be run as a non-root user
|
||||
// To find out which user installed NetBird using HomeBrew we can check the owner of our brew tap directory
|
||||
fileInfo, err := os.Stat("/opt/homebrew/Library/Taps/netbirdio/homebrew-tap/")
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting homebrew installation path info: %w", err)
|
||||
}
|
||||
|
||||
fileSysInfo, ok := fileInfo.Sys().(*syscall.Stat_t)
|
||||
if !ok {
|
||||
return fmt.Errorf("error checking file owner, sysInfo type is %T not *syscall.Stat_t", fileInfo.Sys())
|
||||
}
|
||||
|
||||
// Get username from UID
|
||||
installer, err := user.LookupId(fmt.Sprintf("%d", fileSysInfo.Uid))
|
||||
if err != nil {
|
||||
return fmt.Errorf("error looking up brew installer user: %w", err)
|
||||
}
|
||||
userName := installer.Name
|
||||
// Get user HOME, required for brew to run correctly
|
||||
// https://github.com/Homebrew/brew/issues/15833
|
||||
homeDir := installer.HomeDir
|
||||
// Homebrew does not support installing specific versions
|
||||
// Thus it will always update to latest and ignore targetVersion
|
||||
upgradeArgs := []string{"-u", userName, "/opt/homebrew/bin/brew", "upgrade", "netbirdio/tap/netbird"}
|
||||
// Check if netbird-ui is installed
|
||||
cmd := exec.CommandContext(ctx, "brew", "info", "--json", "netbirdio/tap/netbird-ui")
|
||||
err = cmd.Run()
|
||||
if err == nil {
|
||||
// netbird-ui is installed
|
||||
upgradeArgs = append(upgradeArgs, "netbirdio/tap/netbird-ui")
|
||||
}
|
||||
cmd = exec.CommandContext(ctx, "sudo", upgradeArgs...)
|
||||
cmd.Env = append(cmd.Env, "HOME="+homeDir)
|
||||
|
||||
// Homebrew upgrade doesn't restart the client on its own
|
||||
// So we have to wait for it to finish running and ensure it's done
|
||||
// And then basically restart the netbird service
|
||||
err = cmd.Run()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error running brew upgrade: %w", err)
|
||||
}
|
||||
|
||||
currentPID := os.Getpid()
|
||||
|
||||
// Restart netbird service after the fact
|
||||
// This is a workaround since attempting to restart using launchctl will kill the service and die before starting
|
||||
// the service again as it's a child process
|
||||
// using SIGTERM should ensure a clean shutdown
|
||||
process, err := os.FindProcess(currentPID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error finding current process: %w", err)
|
||||
}
|
||||
err = process.Signal(syscall.SIGTERM)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error sending SIGTERM to current process: %w", err)
|
||||
}
|
||||
// We're dying now, which should restart us
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func urlWithVersionArch(version string) string {
|
||||
url := strings.ReplaceAll(pkgDownloadURL, "%version", version)
|
||||
return strings.ReplaceAll(url, "%arch", runtime.GOARCH)
|
||||
}
|
||||
15
client/internal/updatemanager/update_freebsd.go
Normal file
15
client/internal/updatemanager/update_freebsd.go
Normal file
@@ -0,0 +1,15 @@
|
||||
//go:build freebsd
|
||||
|
||||
package updatemanager
|
||||
|
||||
import "context"
|
||||
|
||||
func (u *UpdateManager) triggerUpdate(ctx context.Context, targetVersion string) error {
|
||||
// Use test function if set (for testing purposes)
|
||||
if u.updateFunc != nil {
|
||||
return u.updateFunc(ctx, targetVersion)
|
||||
}
|
||||
|
||||
// TODO: Implement
|
||||
return nil
|
||||
}
|
||||
15
client/internal/updatemanager/update_js.go
Normal file
15
client/internal/updatemanager/update_js.go
Normal file
@@ -0,0 +1,15 @@
|
||||
//go:build js
|
||||
|
||||
package updatemanager
|
||||
|
||||
import "context"
|
||||
|
||||
func (u *UpdateManager) triggerUpdate(ctx context.Context, targetVersion string) error {
|
||||
// Use test function if set (for testing purposes)
|
||||
if u.updateFunc != nil {
|
||||
return u.updateFunc(ctx, targetVersion)
|
||||
}
|
||||
|
||||
// TODO: Implement
|
||||
return nil
|
||||
}
|
||||
15
client/internal/updatemanager/update_linux.go
Normal file
15
client/internal/updatemanager/update_linux.go
Normal file
@@ -0,0 +1,15 @@
|
||||
//go:build linux
|
||||
|
||||
package updatemanager
|
||||
|
||||
import "context"
|
||||
|
||||
func (u *UpdateManager) triggerUpdate(ctx context.Context, targetVersion string) error {
|
||||
// Use test function if set (for testing purposes)
|
||||
if u.updateFunc != nil {
|
||||
return u.updateFunc(ctx, targetVersion)
|
||||
}
|
||||
|
||||
// TODO: Implement
|
||||
return nil
|
||||
}
|
||||
101
client/internal/updatemanager/update_windows.go
Normal file
101
client/internal/updatemanager/update_windows.go
Normal file
@@ -0,0 +1,101 @@
|
||||
//go:build windows
|
||||
|
||||
package updatemanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/windows/registry"
|
||||
)
|
||||
|
||||
const (
|
||||
msiDownloadURL = "https://github.com/netbirdio/netbird/releases/download/v%version/netbird_installer_%version_windows_%arch.msi"
|
||||
exeDownloadURL = "https://github.com/netbirdio/netbird/releases/download/v%version/netbird_installer_%version_windows_%arch.exe"
|
||||
uninstallKeyPath64 = `SOFTWARE\WOW6432Node\Microsoft\Windows\CurrentVersion\Uninstall\Netbird`
|
||||
uninstallKeyPath32 = `SOFTWARE\Microsoft\Windows\CurrentVersion\Uninstall\Netbird`
|
||||
|
||||
installerEXE installerType = "EXE"
|
||||
installerMSI installerType = "MSI"
|
||||
)
|
||||
|
||||
type installerType string
|
||||
|
||||
func (u *UpdateManager) triggerUpdate(ctx context.Context, targetVersion string) error {
|
||||
// Use test function if set (for testing purposes)
|
||||
if u.updateFunc != nil {
|
||||
return u.updateFunc(ctx, targetVersion)
|
||||
}
|
||||
|
||||
method := installation()
|
||||
return install(ctx, method, targetVersion)
|
||||
}
|
||||
|
||||
func installation() installerType {
|
||||
k, err := registry.OpenKey(registry.LOCAL_MACHINE, uninstallKeyPath64, registry.QUERY_VALUE)
|
||||
if err != nil {
|
||||
k, err = registry.OpenKey(registry.LOCAL_MACHINE, uninstallKeyPath32, registry.QUERY_VALUE)
|
||||
if err != nil {
|
||||
return installerMSI
|
||||
} else {
|
||||
err = k.Close()
|
||||
if err != nil {
|
||||
log.Warnf("Error closing registry key: %v", err)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
err = k.Close()
|
||||
if err != nil {
|
||||
log.Warnf("Error closing registry key: %v", err)
|
||||
}
|
||||
}
|
||||
return installerEXE
|
||||
}
|
||||
|
||||
func install(ctx context.Context, installerType installerType, targetVersion string) error {
|
||||
path, err := downloadFileToTemporaryDir(ctx, urlWithVersionArch(installerType, targetVersion))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.Infof("start installation %s", path)
|
||||
|
||||
var cmd *exec.Cmd
|
||||
if installerType == installerEXE {
|
||||
cmd = exec.CommandContext(ctx, path, "/S")
|
||||
} else {
|
||||
cmd = exec.CommandContext(ctx, "msiexec", "/quiet", "/i", path)
|
||||
}
|
||||
|
||||
// Detach the process from the parent
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{
|
||||
CreationFlags: syscall.CREATE_NEW_PROCESS_GROUP | 0x00000008, // 0x00000008 is DETACHED_PROCESS
|
||||
}
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
log.Errorf("error starting installer: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
if err := cmd.Process.Release(); err != nil {
|
||||
log.Errorf("error releasing installer process: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
log.Infof("installer started successfully: %s", path)
|
||||
return nil
|
||||
}
|
||||
|
||||
func urlWithVersionArch(it installerType, version string) string {
|
||||
var url string
|
||||
if it == installerEXE {
|
||||
url = exeDownloadURL
|
||||
} else {
|
||||
url = msiDownloadURL
|
||||
}
|
||||
url = strings.ReplaceAll(url, "%version", version)
|
||||
return strings.ReplaceAll(url, "%arch", runtime.GOARCH)
|
||||
}
|
||||
@@ -93,13 +93,14 @@ func main() {
|
||||
showLoginURL: flags.showLoginURL,
|
||||
showDebug: flags.showDebug,
|
||||
showProfiles: flags.showProfiles,
|
||||
showUpdate: flags.showUpdate,
|
||||
})
|
||||
|
||||
// Watch for theme/settings changes to update the icon.
|
||||
go watchSettingsChanges(a, client)
|
||||
|
||||
// Run in window mode if any UI flag was set.
|
||||
if flags.showSettings || flags.showNetworks || flags.showDebug || flags.showLoginURL || flags.showProfiles {
|
||||
if flags.showSettings || flags.showNetworks || flags.showDebug || flags.showLoginURL || flags.showProfiles || flags.showUpdate {
|
||||
a.Run()
|
||||
return
|
||||
}
|
||||
@@ -127,6 +128,7 @@ type cliFlags struct {
|
||||
showDebug bool
|
||||
showLoginURL bool
|
||||
errorMsg string
|
||||
showUpdate bool
|
||||
saveLogsInFile bool
|
||||
}
|
||||
|
||||
@@ -146,6 +148,7 @@ func parseFlags() *cliFlags {
|
||||
flag.StringVar(&flags.errorMsg, "error-msg", "", "displays an error message window")
|
||||
flag.BoolVar(&flags.saveLogsInFile, "use-log-file", false, fmt.Sprintf("save logs in a file: %s/netbird-ui-PID.log", os.TempDir()))
|
||||
flag.BoolVar(&flags.showLoginURL, "login-url", false, "show login URL in a popup window")
|
||||
flag.BoolVar(&flags.showUpdate, "update", false, "show update progress window")
|
||||
flag.Parse()
|
||||
return &flags
|
||||
}
|
||||
@@ -296,6 +299,8 @@ type serviceClient struct {
|
||||
mExitNodeDeselectAll *systray.MenuItem
|
||||
logFile string
|
||||
wLoginURL fyne.Window
|
||||
wUpdateProgress fyne.Window
|
||||
updateContextCancel context.CancelFunc
|
||||
}
|
||||
|
||||
type menuHandler struct {
|
||||
@@ -312,6 +317,7 @@ type newServiceClientArgs struct {
|
||||
showDebug bool
|
||||
showLoginURL bool
|
||||
showProfiles bool
|
||||
showUpdate bool
|
||||
}
|
||||
|
||||
// newServiceClient instance constructor
|
||||
@@ -329,7 +335,7 @@ func newServiceClient(args *newServiceClientArgs) *serviceClient {
|
||||
|
||||
showAdvancedSettings: args.showSettings,
|
||||
showNetworks: args.showNetworks,
|
||||
update: version.NewUpdate("nb/client-ui"),
|
||||
update: version.NewUpdateAndStart("nb/client-ui"),
|
||||
}
|
||||
|
||||
s.eventHandler = newEventHandler(s)
|
||||
@@ -347,6 +353,8 @@ func newServiceClient(args *newServiceClientArgs) *serviceClient {
|
||||
s.showDebugUI()
|
||||
case args.showProfiles:
|
||||
s.showProfilesUI()
|
||||
case args.showUpdate:
|
||||
s.showUpdateProgress(ctx)
|
||||
}
|
||||
|
||||
return s
|
||||
@@ -392,6 +400,30 @@ func (s *serviceClient) updateIcon() {
|
||||
s.updateIndicationLock.Unlock()
|
||||
}
|
||||
|
||||
func (s *serviceClient) showUpdateProgress(ctx context.Context) {
|
||||
s.wUpdateProgress = s.app.NewWindow("Automatically updating client")
|
||||
loadingLabel := widget.NewLabel("Updating")
|
||||
s.wUpdateProgress.SetContent(container.NewGridWithRows(2, widget.NewLabel("Your client version is older than auto-update version set in Management, updating client now."), loadingLabel))
|
||||
s.wUpdateProgress.Show()
|
||||
go func() {
|
||||
dotCount := 0
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-time.After(time.Second):
|
||||
dotCount++
|
||||
dotCount %= 4
|
||||
loadingLabel.SetText(fmt.Sprintf("Updating%s", strings.Repeat(".", dotCount)))
|
||||
}
|
||||
}
|
||||
}()
|
||||
s.wUpdateProgress.CenterOnScreen()
|
||||
s.wUpdateProgress.SetFixedSize(true)
|
||||
s.wUpdateProgress.SetCloseIntercept(func() {})
|
||||
s.wUpdateProgress.RequestFocus()
|
||||
}
|
||||
|
||||
func (s *serviceClient) showSettingsUI() {
|
||||
// Check if update settings are disabled by daemon
|
||||
features, err := s.getFeatures()
|
||||
@@ -950,6 +982,29 @@ func (s *serviceClient) onTrayReady() {
|
||||
s.updateExitNodes()
|
||||
}
|
||||
})
|
||||
s.eventManager.AddHandler(func(event *proto.SystemEvent) {
|
||||
if windowAction, ok := event.Metadata["progress_window"]; ok {
|
||||
log.Debugf("window action: %v", windowAction)
|
||||
if windowAction == "show" {
|
||||
log.Debugf("Inside show")
|
||||
if s.updateContextCancel != nil {
|
||||
s.updateContextCancel()
|
||||
s.updateContextCancel = nil
|
||||
}
|
||||
|
||||
subCtx, cancel := context.WithCancel(s.ctx)
|
||||
go s.eventHandler.runSelfCommand(subCtx, "update", "true")
|
||||
s.updateContextCancel = cancel
|
||||
}
|
||||
if windowAction == "hide" {
|
||||
log.Debugf("Inside hide")
|
||||
if s.updateContextCancel != nil {
|
||||
s.updateContextCancel()
|
||||
s.updateContextCancel = nil
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
go s.eventManager.Start(s.ctx)
|
||||
go s.eventHandler.listen(s.ctx)
|
||||
|
||||
@@ -183,7 +183,7 @@ func (s *BaseServer) Start(ctx context.Context) error {
|
||||
log.WithContext(ctx).Infof("running HTTP server and gRPC server on the same port: %s", s.listener.Addr().String())
|
||||
s.serveGRPCWithHTTP(ctx, s.listener, rootHandler, tlsEnabled)
|
||||
|
||||
s.update = version.NewUpdate("nb/management")
|
||||
s.update = version.NewUpdateAndStart("nb/management")
|
||||
s.update.SetDaemonVersion(version.NetbirdVersion())
|
||||
s.update.SetOnUpdateListener(func() {
|
||||
log.WithContext(ctx).Infof("your management version, \"%s\", is outdated, a new management version is available. Learn more here: https://github.com/netbirdio/netbird/releases", version.NetbirdVersion())
|
||||
|
||||
@@ -340,7 +340,8 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
||||
|
||||
if oldSettings.RoutingPeerDNSResolutionEnabled != newSettings.RoutingPeerDNSResolutionEnabled ||
|
||||
oldSettings.LazyConnectionEnabled != newSettings.LazyConnectionEnabled ||
|
||||
oldSettings.DNSDomain != newSettings.DNSDomain {
|
||||
oldSettings.DNSDomain != newSettings.DNSDomain ||
|
||||
oldSettings.AutoUpdateVersion != newSettings.AutoUpdateVersion {
|
||||
updateAccountPeers = true
|
||||
}
|
||||
|
||||
@@ -376,6 +377,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
||||
am.handleLazyConnectionSettings(ctx, oldSettings, newSettings, userID, accountID)
|
||||
am.handlePeerLoginExpirationSettings(ctx, oldSettings, newSettings, userID, accountID)
|
||||
am.handleGroupsPropagationSettings(ctx, oldSettings, newSettings, userID, accountID)
|
||||
am.handleAutoUpdateVersionSettings(ctx, oldSettings, newSettings, userID, accountID)
|
||||
if err = am.handleInactivityExpirationSettings(ctx, oldSettings, newSettings, userID, accountID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -477,6 +479,14 @@ func (am *DefaultAccountManager) handleGroupsPropagationSettings(ctx context.Con
|
||||
}
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) handleAutoUpdateVersionSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) {
|
||||
if oldSettings.AutoUpdateVersion != newSettings.AutoUpdateVersion {
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountAutoUpdateVersionUpdated, map[string]any{
|
||||
"version": newSettings.AutoUpdateVersion,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) error {
|
||||
if newSettings.PeerInactivityExpirationEnabled {
|
||||
if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration {
|
||||
|
||||
@@ -180,6 +180,8 @@ const (
|
||||
UserApproved Activity = 89
|
||||
UserRejected Activity = 90
|
||||
|
||||
AccountAutoUpdateVersionUpdated Activity = 91
|
||||
|
||||
AccountDeleted Activity = 99999
|
||||
)
|
||||
|
||||
@@ -286,8 +288,11 @@ var activityMap = map[Activity]Code{
|
||||
AccountNetworkRangeUpdated: {"Account network range updated", "account.network.range.update"},
|
||||
|
||||
PeerIPUpdated: {"Peer IP updated", "peer.ip.update"},
|
||||
|
||||
UserApproved: {"User approved", "user.approve"},
|
||||
UserRejected: {"User rejected", "user.reject"},
|
||||
|
||||
AccountAutoUpdateVersionUpdated: {"Account AutoUpdate Version updated", "account.settings.auto.version.update"},
|
||||
}
|
||||
|
||||
// StringCode returns a string code of the activity
|
||||
|
||||
@@ -712,6 +712,9 @@ func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, set
|
||||
Fqdn: fqdn,
|
||||
RoutingPeerDnsResolutionEnabled: settings.RoutingPeerDNSResolutionEnabled,
|
||||
LazyConnectionEnabled: settings.LazyConnectionEnabled,
|
||||
AutoUpdate: &proto.AutoUpdateSettings{
|
||||
Version: settings.AutoUpdateVersion,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -719,9 +722,10 @@ func toSyncResponse(ctx context.Context, config *nbconfig.Config, peer *nbpeer.P
|
||||
response := &proto.SyncResponse{
|
||||
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings),
|
||||
NetworkMap: &proto.NetworkMap{
|
||||
Serial: networkMap.Network.CurrentSerial(),
|
||||
Routes: toProtocolRoutes(networkMap.Routes),
|
||||
DNSConfig: toProtocolDNSConfig(networkMap.DNSConfig, dnsCache, dnsFwdPort),
|
||||
Serial: networkMap.Network.CurrentSerial(),
|
||||
Routes: toProtocolRoutes(networkMap.Routes),
|
||||
DNSConfig: toProtocolDNSConfig(networkMap.DNSConfig, dnsCache, dnsFwdPort),
|
||||
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings),
|
||||
},
|
||||
Checks: toProtocolChecks(ctx, checks),
|
||||
}
|
||||
|
||||
@@ -3,12 +3,14 @@ package accounts
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
goversion "github.com/hashicorp/go-version"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
@@ -26,7 +28,9 @@ const (
|
||||
// MinNetworkBits is the minimum prefix length for IPv4 network ranges (e.g., /29 gives 8 addresses, /28 gives 16)
|
||||
MinNetworkBitsIPv4 = 28
|
||||
// MinNetworkBitsIPv6 is the minimum prefix length for IPv6 network ranges
|
||||
MinNetworkBitsIPv6 = 120
|
||||
MinNetworkBitsIPv6 = 120
|
||||
disableAutoUpdate = "disabled"
|
||||
autoUpdateLatestVersion = "latest"
|
||||
)
|
||||
|
||||
// handler is a handler that handles the server.Account HTTP endpoints
|
||||
@@ -162,6 +166,61 @@ func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteJSONObject(r.Context(), w, []*api.Account{resp})
|
||||
}
|
||||
|
||||
func (h *handler) updateAccountRequestSettings(req api.PutApiAccountsAccountIdJSONRequestBody) (*types.Settings, error) {
|
||||
returnSettings := &types.Settings{
|
||||
PeerLoginExpirationEnabled: req.Settings.PeerLoginExpirationEnabled,
|
||||
PeerLoginExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerLoginExpiration)),
|
||||
RegularUsersViewBlocked: req.Settings.RegularUsersViewBlocked,
|
||||
|
||||
PeerInactivityExpirationEnabled: req.Settings.PeerInactivityExpirationEnabled,
|
||||
PeerInactivityExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerInactivityExpiration)),
|
||||
}
|
||||
|
||||
if req.Settings.Extra != nil {
|
||||
returnSettings.Extra = &types.ExtraSettings{
|
||||
PeerApprovalEnabled: req.Settings.Extra.PeerApprovalEnabled,
|
||||
UserApprovalRequired: req.Settings.Extra.UserApprovalRequired,
|
||||
FlowEnabled: req.Settings.Extra.NetworkTrafficLogsEnabled,
|
||||
FlowGroups: req.Settings.Extra.NetworkTrafficLogsGroups,
|
||||
FlowPacketCounterEnabled: req.Settings.Extra.NetworkTrafficPacketCounterEnabled,
|
||||
}
|
||||
}
|
||||
|
||||
if req.Settings.JwtGroupsEnabled != nil {
|
||||
returnSettings.JWTGroupsEnabled = *req.Settings.JwtGroupsEnabled
|
||||
}
|
||||
if req.Settings.GroupsPropagationEnabled != nil {
|
||||
returnSettings.GroupsPropagationEnabled = *req.Settings.GroupsPropagationEnabled
|
||||
}
|
||||
if req.Settings.JwtGroupsClaimName != nil {
|
||||
returnSettings.JWTGroupsClaimName = *req.Settings.JwtGroupsClaimName
|
||||
}
|
||||
if req.Settings.JwtAllowGroups != nil {
|
||||
returnSettings.JWTAllowGroups = *req.Settings.JwtAllowGroups
|
||||
}
|
||||
if req.Settings.RoutingPeerDnsResolutionEnabled != nil {
|
||||
returnSettings.RoutingPeerDNSResolutionEnabled = *req.Settings.RoutingPeerDnsResolutionEnabled
|
||||
}
|
||||
if req.Settings.DnsDomain != nil {
|
||||
returnSettings.DNSDomain = *req.Settings.DnsDomain
|
||||
}
|
||||
if req.Settings.LazyConnectionEnabled != nil {
|
||||
returnSettings.LazyConnectionEnabled = *req.Settings.LazyConnectionEnabled
|
||||
}
|
||||
if req.Settings.AutoUpdateVersion != nil {
|
||||
_, err := goversion.NewSemver(*req.Settings.AutoUpdateVersion)
|
||||
if *req.Settings.AutoUpdateVersion == autoUpdateLatestVersion ||
|
||||
*req.Settings.AutoUpdateVersion == disableAutoUpdate ||
|
||||
err == nil {
|
||||
returnSettings.AutoUpdateVersion = *req.Settings.AutoUpdateVersion
|
||||
} else if *req.Settings.AutoUpdateVersion != "" {
|
||||
return nil, fmt.Errorf("invalid AutoUpdateVersion")
|
||||
}
|
||||
}
|
||||
|
||||
return returnSettings, nil
|
||||
}
|
||||
|
||||
// updateAccount is HTTP PUT handler that updates the provided account. Updates only account settings (server.Settings)
|
||||
func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
@@ -186,45 +245,9 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
settings := &types.Settings{
|
||||
PeerLoginExpirationEnabled: req.Settings.PeerLoginExpirationEnabled,
|
||||
PeerLoginExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerLoginExpiration)),
|
||||
RegularUsersViewBlocked: req.Settings.RegularUsersViewBlocked,
|
||||
|
||||
PeerInactivityExpirationEnabled: req.Settings.PeerInactivityExpirationEnabled,
|
||||
PeerInactivityExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerInactivityExpiration)),
|
||||
}
|
||||
|
||||
if req.Settings.Extra != nil {
|
||||
settings.Extra = &types.ExtraSettings{
|
||||
PeerApprovalEnabled: req.Settings.Extra.PeerApprovalEnabled,
|
||||
UserApprovalRequired: req.Settings.Extra.UserApprovalRequired,
|
||||
FlowEnabled: req.Settings.Extra.NetworkTrafficLogsEnabled,
|
||||
FlowGroups: req.Settings.Extra.NetworkTrafficLogsGroups,
|
||||
FlowPacketCounterEnabled: req.Settings.Extra.NetworkTrafficPacketCounterEnabled,
|
||||
}
|
||||
}
|
||||
|
||||
if req.Settings.JwtGroupsEnabled != nil {
|
||||
settings.JWTGroupsEnabled = *req.Settings.JwtGroupsEnabled
|
||||
}
|
||||
if req.Settings.GroupsPropagationEnabled != nil {
|
||||
settings.GroupsPropagationEnabled = *req.Settings.GroupsPropagationEnabled
|
||||
}
|
||||
if req.Settings.JwtGroupsClaimName != nil {
|
||||
settings.JWTGroupsClaimName = *req.Settings.JwtGroupsClaimName
|
||||
}
|
||||
if req.Settings.JwtAllowGroups != nil {
|
||||
settings.JWTAllowGroups = *req.Settings.JwtAllowGroups
|
||||
}
|
||||
if req.Settings.RoutingPeerDnsResolutionEnabled != nil {
|
||||
settings.RoutingPeerDNSResolutionEnabled = *req.Settings.RoutingPeerDnsResolutionEnabled
|
||||
}
|
||||
if req.Settings.DnsDomain != nil {
|
||||
settings.DNSDomain = *req.Settings.DnsDomain
|
||||
}
|
||||
if req.Settings.LazyConnectionEnabled != nil {
|
||||
settings.LazyConnectionEnabled = *req.Settings.LazyConnectionEnabled
|
||||
settings, err := h.updateAccountRequestSettings(req)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
}
|
||||
if req.Settings.NetworkRange != nil && *req.Settings.NetworkRange != "" {
|
||||
prefix, err := netip.ParsePrefix(*req.Settings.NetworkRange)
|
||||
@@ -313,6 +336,7 @@ func toAccountResponse(accountID string, settings *types.Settings, meta *types.A
|
||||
RoutingPeerDnsResolutionEnabled: &settings.RoutingPeerDNSResolutionEnabled,
|
||||
LazyConnectionEnabled: &settings.LazyConnectionEnabled,
|
||||
DnsDomain: &settings.DNSDomain,
|
||||
AutoUpdateVersion: &settings.AutoUpdateVersion,
|
||||
}
|
||||
|
||||
if settings.NetworkRange.IsValid() {
|
||||
|
||||
@@ -120,6 +120,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
||||
RoutingPeerDnsResolutionEnabled: br(false),
|
||||
LazyConnectionEnabled: br(false),
|
||||
DnsDomain: sr(""),
|
||||
AutoUpdateVersion: sr(""),
|
||||
},
|
||||
expectedArray: true,
|
||||
expectedID: accountID,
|
||||
@@ -142,6 +143,30 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
||||
RoutingPeerDnsResolutionEnabled: br(false),
|
||||
LazyConnectionEnabled: br(false),
|
||||
DnsDomain: sr(""),
|
||||
AutoUpdateVersion: sr(""),
|
||||
},
|
||||
expectedArray: false,
|
||||
expectedID: accountID,
|
||||
},
|
||||
{
|
||||
name: "PutAccount OK with autoUpdateVersion",
|
||||
expectedBody: true,
|
||||
requestType: http.MethodPut,
|
||||
requestPath: "/api/accounts/" + accountID,
|
||||
requestBody: bytes.NewBufferString("{\"settings\": {\"auto_update_version\": \"latest\", \"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": true},\"onboarding\": {\"onboarding_flow_pending\": true,\"signup_form_pending\": true}}"),
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedSettings: api.AccountSettings{
|
||||
PeerLoginExpiration: 15552000,
|
||||
PeerLoginExpirationEnabled: true,
|
||||
GroupsPropagationEnabled: br(false),
|
||||
JwtGroupsClaimName: sr(""),
|
||||
JwtGroupsEnabled: br(false),
|
||||
JwtAllowGroups: &[]string{},
|
||||
RegularUsersViewBlocked: false,
|
||||
RoutingPeerDnsResolutionEnabled: br(false),
|
||||
LazyConnectionEnabled: br(false),
|
||||
DnsDomain: sr(""),
|
||||
AutoUpdateVersion: sr("latest"),
|
||||
},
|
||||
expectedArray: false,
|
||||
expectedID: accountID,
|
||||
@@ -164,6 +189,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
||||
RoutingPeerDnsResolutionEnabled: br(false),
|
||||
LazyConnectionEnabled: br(false),
|
||||
DnsDomain: sr(""),
|
||||
AutoUpdateVersion: sr(""),
|
||||
},
|
||||
expectedArray: false,
|
||||
expectedID: accountID,
|
||||
@@ -186,6 +212,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
||||
RoutingPeerDnsResolutionEnabled: br(false),
|
||||
LazyConnectionEnabled: br(false),
|
||||
DnsDomain: sr(""),
|
||||
AutoUpdateVersion: sr(""),
|
||||
},
|
||||
expectedArray: false,
|
||||
expectedID: accountID,
|
||||
@@ -208,6 +235,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
||||
RoutingPeerDnsResolutionEnabled: br(false),
|
||||
LazyConnectionEnabled: br(false),
|
||||
DnsDomain: sr(""),
|
||||
AutoUpdateVersion: sr(""),
|
||||
},
|
||||
expectedArray: false,
|
||||
expectedID: accountID,
|
||||
|
||||
@@ -1575,6 +1575,7 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto
|
||||
RemotePeersIsEmpty: true,
|
||||
FirewallRules: []*proto.FirewallRule{},
|
||||
FirewallRulesIsEmpty: true,
|
||||
PeerConfig: toPeerConfig(peer, network, dnsDomain, settings),
|
||||
DNSConfig: &proto.DNSConfig{
|
||||
ForwarderPort: dnsFwdPort,
|
||||
},
|
||||
|
||||
@@ -52,6 +52,9 @@ type Settings struct {
|
||||
|
||||
// LazyConnectionEnabled indicates if the experimental feature is enabled or disabled
|
||||
LazyConnectionEnabled bool `gorm:"default:false"`
|
||||
|
||||
// AutoUpdateVersion client auto-update version
|
||||
AutoUpdateVersion string `gorm:"default:'latest'"`
|
||||
}
|
||||
|
||||
// Copy copies the Settings struct
|
||||
@@ -72,6 +75,7 @@ func (s *Settings) Copy() *Settings {
|
||||
LazyConnectionEnabled: s.LazyConnectionEnabled,
|
||||
DNSDomain: s.DNSDomain,
|
||||
NetworkRange: s.NetworkRange,
|
||||
AutoUpdateVersion: s.AutoUpdateVersion,
|
||||
}
|
||||
if s.Extra != nil {
|
||||
settings.Extra = s.Extra.Copy()
|
||||
|
||||
@@ -145,6 +145,10 @@ components:
|
||||
description: Enables or disables experimental lazy connection
|
||||
type: boolean
|
||||
example: true
|
||||
auto_update_version:
|
||||
description: Set Clients auto-update version. "latest", "disabled", or a specific version (e.g "0.50.1")
|
||||
type: string
|
||||
example: "0.51.2"
|
||||
required:
|
||||
- peer_login_expiration_enabled
|
||||
- peer_login_expiration
|
||||
|
||||
@@ -291,6 +291,9 @@ type AccountRequest struct {
|
||||
|
||||
// AccountSettings defines model for AccountSettings.
|
||||
type AccountSettings struct {
|
||||
// AutoUpdateVersion Set Clients auto-update version. "latest", "disabled", or a specific version (e.g "0.50.1")
|
||||
AutoUpdateVersion *string `json:"auto_update_version,omitempty"`
|
||||
|
||||
// DnsDomain Allows to define a custom dns domain for the account
|
||||
DnsDomain *string `json:"dns_domain,omitempty"`
|
||||
Extra *AccountExtraSettings `json:"extra,omitempty"`
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -266,6 +266,18 @@ message PeerConfig {
|
||||
bool LazyConnectionEnabled = 6;
|
||||
|
||||
int32 mtu = 7;
|
||||
|
||||
// Auto-update config
|
||||
AutoUpdateSettings autoUpdate = 8;
|
||||
}
|
||||
|
||||
message AutoUpdateSettings {
|
||||
string version = 1;
|
||||
/*
|
||||
alwaysUpdate = true → Updates happen automatically in the background
|
||||
alwaysUpdate = false → Updates only happen when triggered by a peer connection
|
||||
*/
|
||||
bool alwaysUpdate = 2;
|
||||
}
|
||||
|
||||
// NetworkMap represents a network state of the peer with the corresponding configuration parameters to establish peer-to-peer connections
|
||||
|
||||
@@ -41,21 +41,28 @@ func NewUpdate(httpAgent string) *Update {
|
||||
currentVersion, _ = goversion.NewVersion("0.0.0")
|
||||
}
|
||||
|
||||
latestAvailable, _ := goversion.NewVersion("0.0.0")
|
||||
|
||||
u := &Update{
|
||||
httpAgent: httpAgent,
|
||||
latestAvailable: latestAvailable,
|
||||
uiVersion: currentVersion,
|
||||
fetchTicker: time.NewTicker(fetchPeriod),
|
||||
fetchDone: make(chan struct{}),
|
||||
httpAgent: httpAgent,
|
||||
uiVersion: currentVersion,
|
||||
fetchDone: make(chan struct{}),
|
||||
}
|
||||
go u.startFetcher()
|
||||
|
||||
return u
|
||||
}
|
||||
|
||||
func NewUpdateAndStart(httpAgent string) *Update {
|
||||
u := NewUpdate(httpAgent)
|
||||
go u.StartFetcher()
|
||||
|
||||
return u
|
||||
}
|
||||
|
||||
// StopWatch stop the version info fetch loop
|
||||
func (u *Update) StopWatch() {
|
||||
if u.fetchTicker == nil {
|
||||
return
|
||||
}
|
||||
|
||||
u.fetchTicker.Stop()
|
||||
|
||||
select {
|
||||
@@ -94,7 +101,18 @@ func (u *Update) SetOnUpdateListener(updateFn func()) {
|
||||
}
|
||||
}
|
||||
|
||||
func (u *Update) startFetcher() {
|
||||
func (u *Update) LatestVersion() *goversion.Version {
|
||||
u.versionsLock.Lock()
|
||||
defer u.versionsLock.Unlock()
|
||||
return u.latestAvailable
|
||||
}
|
||||
|
||||
func (u *Update) StartFetcher() {
|
||||
if u.fetchTicker != nil {
|
||||
return
|
||||
}
|
||||
u.fetchTicker = time.NewTicker(fetchPeriod)
|
||||
|
||||
if changed := u.fetchVersion(); changed {
|
||||
u.checkUpdate()
|
||||
}
|
||||
@@ -181,6 +199,10 @@ func (u *Update) isUpdateAvailable() bool {
|
||||
u.versionsLock.Lock()
|
||||
defer u.versionsLock.Unlock()
|
||||
|
||||
if u.latestAvailable == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if u.latestAvailable.GreaterThan(u.uiVersion) {
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -23,7 +23,7 @@ func TestNewUpdate(t *testing.T) {
|
||||
wg.Add(1)
|
||||
|
||||
onUpdate := false
|
||||
u := NewUpdate(httpAgent)
|
||||
u := NewUpdateAndStart(httpAgent)
|
||||
defer u.StopWatch()
|
||||
u.SetOnUpdateListener(func() {
|
||||
onUpdate = true
|
||||
@@ -48,7 +48,7 @@ func TestDoNotUpdate(t *testing.T) {
|
||||
wg.Add(1)
|
||||
|
||||
onUpdate := false
|
||||
u := NewUpdate(httpAgent)
|
||||
u := NewUpdateAndStart(httpAgent)
|
||||
defer u.StopWatch()
|
||||
u.SetOnUpdateListener(func() {
|
||||
onUpdate = true
|
||||
@@ -73,7 +73,7 @@ func TestDaemonUpdate(t *testing.T) {
|
||||
wg.Add(1)
|
||||
|
||||
onUpdate := false
|
||||
u := NewUpdate(httpAgent)
|
||||
u := NewUpdateAndStart(httpAgent)
|
||||
defer u.StopWatch()
|
||||
u.SetOnUpdateListener(func() {
|
||||
onUpdate = true
|
||||
|
||||
Reference in New Issue
Block a user