mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-20 08:22:27 -04:00
Compare commits
39 Commits
prototype/
...
feat/auto-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
74eebeb95a | ||
|
|
8aa1b23a22 | ||
|
|
030ddae51e | ||
|
|
709e24eb6f | ||
|
|
6654e2dbf7 | ||
|
|
d80d47a469 | ||
|
|
6eee52b56e | ||
|
|
9313b49625 | ||
|
|
18f884f769 | ||
|
|
1354096c4d | ||
|
|
cd19f4d910 | ||
|
|
bab5cd4b41 | ||
|
|
7d846bf9ba | ||
|
|
6200aaf0b0 | ||
|
|
7fa926d397 | ||
|
|
9ae48a062a | ||
|
|
582ff1ff8c | ||
|
|
5556ff36af | ||
|
|
d5ea408cb3 | ||
|
|
436d74094b | ||
|
|
b37ba44015 | ||
|
|
0d2ce56e12 | ||
|
|
723c418966 | ||
|
|
e04b989a12 | ||
|
|
b070304d46 | ||
|
|
ad3985ac63 | ||
|
|
50423399f2 | ||
|
|
02afd4e849 | ||
|
|
d19f829f65 | ||
|
|
ec47a84afe | ||
|
|
ecf1e9013e | ||
|
|
6025eb1962 | ||
|
|
59ae92cf8f | ||
|
|
d2e198bd76 | ||
|
|
58d48127e0 | ||
|
|
84501a3f56 | ||
|
|
762b9b7b56 | ||
|
|
c6328788ca | ||
|
|
bc59749859 |
@@ -307,8 +307,14 @@ func getStatusOutput(cmd *cobra.Command, anon bool) string {
|
||||
if err != nil {
|
||||
cmd.PrintErrf("Failed to get status: %v\n", err)
|
||||
} else {
|
||||
pm := profilemanager.NewProfileManager()
|
||||
var profName string
|
||||
if activeProf, err := pm.GetActiveProfile(); err == nil {
|
||||
profName = activeProf.Name
|
||||
}
|
||||
|
||||
statusOutputString = nbstatus.ParseToFullDetailSummary(
|
||||
nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", ""),
|
||||
nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", profName),
|
||||
)
|
||||
}
|
||||
return statusOutputString
|
||||
|
||||
@@ -25,6 +25,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
cProto "github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
@@ -34,7 +35,6 @@ import (
|
||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||
signal "github.com/netbirdio/netbird/shared/signal/client"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
@@ -280,6 +280,10 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
return wrapErr(err)
|
||||
}
|
||||
|
||||
if loginResp.PeerConfig != nil && loginResp.PeerConfig.AutoUpdate != nil {
|
||||
c.engine.InitialUpdateHandling(loginResp.PeerConfig.AutoUpdate)
|
||||
}
|
||||
|
||||
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
|
||||
state.Set(StatusConnected)
|
||||
|
||||
|
||||
@@ -47,7 +47,7 @@ nftables.txt: Anonymized nftables rules with packet counters, if --system-info f
|
||||
resolved_domains.txt: Anonymized resolved domain IP addresses from the status recorder.
|
||||
config.txt: Anonymized configuration information of the NetBird client.
|
||||
network_map.json: Anonymized sync response containing peer configurations, routes, DNS settings, and firewall rules.
|
||||
state.json: Anonymized client state dump containing netbird states.
|
||||
state.json: Anonymized client state dump containing netbird states for the active profile.
|
||||
mutex.prof: Mutex profiling information.
|
||||
goroutine.prof: Goroutine profiling information.
|
||||
block.prof: Block profiling information.
|
||||
@@ -564,6 +564,8 @@ func (g *BundleGenerator) addStateFile() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Debugf("Adding state file from: %s", path)
|
||||
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
|
||||
@@ -50,6 +50,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/updatemanager"
|
||||
cProto "github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
|
||||
@@ -75,6 +76,7 @@ const (
|
||||
PeerConnectionTimeoutMax = 45000 // ms
|
||||
PeerConnectionTimeoutMin = 30000 // ms
|
||||
connInitLimit = 200
|
||||
disableAutoUpdate = "disabled"
|
||||
)
|
||||
|
||||
var ErrResetConnection = fmt.Errorf("reset connection")
|
||||
@@ -199,6 +201,9 @@ type Engine struct {
|
||||
connSemaphore *semaphoregroup.SemaphoreGroup
|
||||
flowManager nftypes.FlowManager
|
||||
|
||||
// auto-update
|
||||
updateManager *updatemanager.UpdateManager
|
||||
|
||||
// WireGuard interface monitor
|
||||
wgIfaceMonitor *WGIfaceMonitor
|
||||
wgIfaceMonitorWg sync.WaitGroup
|
||||
@@ -314,6 +319,10 @@ func (e *Engine) Stop() error {
|
||||
e.srWatcher.Close()
|
||||
}
|
||||
|
||||
if e.updateManager != nil {
|
||||
e.updateManager.Stop()
|
||||
}
|
||||
|
||||
e.statusRecorder.ReplaceOfflinePeers([]peer.State{})
|
||||
e.statusRecorder.UpdateDNSStates([]peer.NSGroupState{})
|
||||
e.statusRecorder.UpdateRelayStates([]relay.ProbeResult{})
|
||||
@@ -500,6 +509,19 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Engine) InitialUpdateHandling(autoUpdateSettings *mgmProto.AutoUpdateSettings) {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
|
||||
if e.updateManager == nil {
|
||||
e.updateManager = updatemanager.NewUpdateManager(e.statusRecorder, e.stateManager)
|
||||
}
|
||||
|
||||
e.updateManager.CheckUpdateSuccess(e.ctx)
|
||||
|
||||
e.handleAutoUpdateVersion(autoUpdateSettings, true)
|
||||
}
|
||||
|
||||
func (e *Engine) createFirewall() error {
|
||||
if e.config.DisableFirewall {
|
||||
log.Infof("firewall is disabled")
|
||||
@@ -712,10 +734,44 @@ func (e *Engine) PopulateNetbirdConfig(netbirdConfig *mgmProto.NetbirdConfig, mg
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Engine) handleAutoUpdateVersion(autoUpdateSettings *mgmProto.AutoUpdateSettings, initialCheck bool) {
|
||||
if autoUpdateSettings == nil {
|
||||
return
|
||||
}
|
||||
|
||||
disabled := autoUpdateSettings.Version == disableAutoUpdate
|
||||
|
||||
// Stop and cleanup if disabled
|
||||
if e.updateManager != nil && disabled {
|
||||
log.Infof("auto-update is disabled, stopping update manager")
|
||||
e.updateManager.Stop()
|
||||
e.updateManager = nil
|
||||
return
|
||||
}
|
||||
|
||||
// Skip check unless AlwaysUpdate is enabled or this is the initial check at startup
|
||||
if !autoUpdateSettings.AlwaysUpdate && !initialCheck {
|
||||
log.Debugf("skipping auto-update check, AlwaysUpdate is false and this is not the initial check")
|
||||
return
|
||||
}
|
||||
|
||||
// Start manager if needed
|
||||
if e.updateManager == nil {
|
||||
log.Infof("starting auto-update manager")
|
||||
e.updateManager = updatemanager.NewUpdateManager(e.statusRecorder, e.stateManager)
|
||||
}
|
||||
e.updateManager.Start(e.ctx)
|
||||
log.Infof("handling auto-update version: %s", autoUpdateSettings.Version)
|
||||
e.updateManager.SetVersion(autoUpdateSettings.Version)
|
||||
}
|
||||
|
||||
func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
|
||||
if update.NetworkMap != nil && update.NetworkMap.PeerConfig != nil {
|
||||
e.handleAutoUpdateVersion(update.NetworkMap.PeerConfig.AutoUpdate, false)
|
||||
}
|
||||
if update.GetNetbirdConfig() != nil {
|
||||
wCfg := update.GetNetbirdConfig()
|
||||
err := e.updateTURNs(wCfg.GetTurns())
|
||||
@@ -1386,16 +1442,9 @@ func (e *Engine) receiveSignalEvents() {
|
||||
|
||||
switch msg.GetBody().Type {
|
||||
case sProto.Body_OFFER, sProto.Body_ANSWER:
|
||||
offerAnswer, err := convertToOfferAnswer(msg)
|
||||
if err != nil {
|
||||
if err := e.handleOfferAnswer(msg, conn); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if msg.Body.Type == sProto.Body_OFFER {
|
||||
conn.OnRemoteOffer(*offerAnswer)
|
||||
} else {
|
||||
conn.OnRemoteAnswer(*offerAnswer)
|
||||
}
|
||||
case sProto.Body_CANDIDATE:
|
||||
candidate, err := ice.UnmarshalCandidate(msg.GetBody().Payload)
|
||||
if err != nil {
|
||||
@@ -1423,6 +1472,20 @@ func (e *Engine) receiveSignalEvents() {
|
||||
e.signal.WaitStreamConnected()
|
||||
}
|
||||
|
||||
func (e *Engine) handleOfferAnswer(msg *sProto.Message, conn *peer.Conn) error {
|
||||
offerAnswer, err := convertToOfferAnswer(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if msg.Body.Type == sProto.Body_OFFER {
|
||||
conn.OnRemoteOffer(*offerAnswer)
|
||||
} else {
|
||||
conn.OnRemoteAnswer(*offerAnswer)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Engine) parseNATExternalIPMappings() []string {
|
||||
var mappedIPs []string
|
||||
var ignoredIFaces = make(map[string]interface{})
|
||||
|
||||
388
client/internal/updatemanager/manager.go
Normal file
388
client/internal/updatemanager/manager.go
Normal file
@@ -0,0 +1,388 @@
|
||||
package updatemanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
v "github.com/hashicorp/go-version"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
cProto "github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
const (
|
||||
latestVersion = "latest"
|
||||
// this version will be ignored
|
||||
developmentVersion = "development"
|
||||
)
|
||||
|
||||
type UpdateInterface interface {
|
||||
StopWatch()
|
||||
SetDaemonVersion(newVersion string) bool
|
||||
SetOnUpdateListener(updateFn func())
|
||||
LatestVersion() *v.Version
|
||||
StartFetcher()
|
||||
}
|
||||
|
||||
type UpdateState struct {
|
||||
PreUpdateVersion string
|
||||
TargetVersion string
|
||||
}
|
||||
|
||||
func (u UpdateState) Name() string {
|
||||
return "autoUpdate"
|
||||
}
|
||||
|
||||
type UpdateManager struct {
|
||||
statusRecorder *peer.Status
|
||||
stateManager *statemanager.Manager
|
||||
|
||||
lastTrigger time.Time
|
||||
mgmUpdateChan chan struct{}
|
||||
updateChannel chan struct{}
|
||||
currentVersion string
|
||||
update UpdateInterface
|
||||
wg sync.WaitGroup
|
||||
|
||||
cancel context.CancelFunc
|
||||
|
||||
expectedVersion *v.Version
|
||||
updateToLatestVersion bool
|
||||
|
||||
// updateMutex protect update and expectedVersion fields
|
||||
updateMutex sync.Mutex
|
||||
|
||||
// updateFunc is used for testing to mock the triggerUpdate behavior
|
||||
updateFunc func(ctx context.Context, targetVersion string) error
|
||||
}
|
||||
|
||||
func NewUpdateManager(statusRecorder *peer.Status, stateManager *statemanager.Manager) *UpdateManager {
|
||||
manager := &UpdateManager{
|
||||
statusRecorder: statusRecorder,
|
||||
stateManager: stateManager,
|
||||
mgmUpdateChan: make(chan struct{}, 1),
|
||||
updateChannel: make(chan struct{}, 1),
|
||||
currentVersion: version.NetbirdVersion(),
|
||||
update: version.NewUpdate("nb/client"),
|
||||
}
|
||||
|
||||
return manager
|
||||
}
|
||||
|
||||
// CheckUpdateSuccess checks if the update was successful. It works without to start the update manager.
|
||||
func (u *UpdateManager) CheckUpdateSuccess(ctx context.Context) {
|
||||
u.updateStateManager(ctx)
|
||||
}
|
||||
|
||||
func (u *UpdateManager) Start(ctx context.Context) {
|
||||
if u.cancel != nil {
|
||||
log.Errorf("UpdateManager already started")
|
||||
return
|
||||
}
|
||||
|
||||
u.update.SetDaemonVersion(u.currentVersion)
|
||||
u.update.SetOnUpdateListener(func() {
|
||||
select {
|
||||
case u.updateChannel <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
})
|
||||
go u.update.StartFetcher()
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
u.cancel = cancel
|
||||
|
||||
u.wg.Add(1)
|
||||
go u.updateLoop(ctx)
|
||||
}
|
||||
|
||||
func (u *UpdateManager) SetVersion(expectedVersion string) {
|
||||
log.Infof("set expected agent version for upgrade: %s", expectedVersion)
|
||||
if u.cancel == nil {
|
||||
log.Errorf("UpdateManager not started")
|
||||
return
|
||||
}
|
||||
|
||||
u.updateMutex.Lock()
|
||||
defer u.updateMutex.Unlock()
|
||||
if expectedVersion == latestVersion {
|
||||
u.updateToLatestVersion = true
|
||||
u.expectedVersion = nil
|
||||
} else {
|
||||
expectedSemVer, err := v.NewVersion(expectedVersion)
|
||||
if err != nil {
|
||||
log.Errorf("Error parsing version: %v", err)
|
||||
return
|
||||
}
|
||||
if u.expectedVersion != nil && u.expectedVersion.Equal(expectedSemVer) {
|
||||
return
|
||||
}
|
||||
u.expectedVersion = expectedSemVer
|
||||
u.updateToLatestVersion = false
|
||||
}
|
||||
|
||||
select {
|
||||
case u.mgmUpdateChan <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (u *UpdateManager) Stop() {
|
||||
if u.cancel == nil {
|
||||
return
|
||||
}
|
||||
|
||||
u.cancel()
|
||||
u.updateMutex.Lock()
|
||||
if u.update != nil {
|
||||
u.update.StopWatch()
|
||||
u.update = nil
|
||||
}
|
||||
u.updateMutex.Unlock()
|
||||
|
||||
u.wg.Wait()
|
||||
}
|
||||
|
||||
func (u *UpdateManager) onContextCancel() {
|
||||
if u.cancel == nil {
|
||||
return
|
||||
}
|
||||
|
||||
u.updateMutex.Lock()
|
||||
defer u.updateMutex.Unlock()
|
||||
if u.update != nil {
|
||||
u.update.StopWatch()
|
||||
u.update = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (u *UpdateManager) updateLoop(ctx context.Context) {
|
||||
defer u.wg.Done()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
u.onContextCancel()
|
||||
return
|
||||
case <-u.mgmUpdateChan:
|
||||
case <-u.updateChannel:
|
||||
log.Infof("fetched new version info")
|
||||
}
|
||||
|
||||
u.handleUpdate(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
func (u *UpdateManager) handleUpdate(ctx context.Context) {
|
||||
var updateVersion *v.Version
|
||||
|
||||
u.updateMutex.Lock()
|
||||
if u.update == nil {
|
||||
u.updateMutex.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
expectedVersion := u.expectedVersion
|
||||
useLatest := u.updateToLatestVersion
|
||||
curLatestVersion := u.update.LatestVersion()
|
||||
u.updateMutex.Unlock()
|
||||
|
||||
switch {
|
||||
// Resolve "latest" to actual version
|
||||
case useLatest:
|
||||
if curLatestVersion == nil {
|
||||
log.Tracef("latest version not fetched yet")
|
||||
return
|
||||
}
|
||||
updateVersion = curLatestVersion
|
||||
// Update to specific version
|
||||
case expectedVersion != nil:
|
||||
updateVersion = expectedVersion
|
||||
default:
|
||||
log.Debugf("no expected version information set")
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("checking update option, current version: %s, target version: %s", u.currentVersion, updateVersion)
|
||||
if !u.shouldUpdate(updateVersion) {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, time.Minute)
|
||||
defer cancel()
|
||||
|
||||
u.lastTrigger = time.Now()
|
||||
log.Debugf("Auto-update triggered, current version: %s, target version: %s", u.currentVersion, updateVersion)
|
||||
u.statusRecorder.PublishEvent(
|
||||
cProto.SystemEvent_INFO,
|
||||
cProto.SystemEvent_SYSTEM,
|
||||
"Automatically updating client",
|
||||
"Your client version is older than auto-update version set in Management, updating client now.",
|
||||
nil,
|
||||
)
|
||||
|
||||
u.statusRecorder.PublishEvent(
|
||||
cProto.SystemEvent_INFO,
|
||||
cProto.SystemEvent_SYSTEM,
|
||||
"",
|
||||
"",
|
||||
map[string]string{"progress_window": "show"},
|
||||
)
|
||||
|
||||
updateState := UpdateState{
|
||||
PreUpdateVersion: u.currentVersion,
|
||||
TargetVersion: updateVersion.String(),
|
||||
}
|
||||
|
||||
if err := u.stateManager.UpdateState(updateState); err != nil {
|
||||
log.Warnf("failed to update state: %v", err)
|
||||
} else {
|
||||
if err = u.stateManager.PersistState(ctx); err != nil {
|
||||
log.Warnf("failed to persist state: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := u.triggerUpdate(ctx, updateVersion.String()); err != nil {
|
||||
log.Errorf("Error triggering auto-update: %v", err)
|
||||
u.statusRecorder.PublishEvent(
|
||||
cProto.SystemEvent_ERROR,
|
||||
cProto.SystemEvent_SYSTEM,
|
||||
"Auto-update failed",
|
||||
fmt.Sprintf("Auto-update failed: %v", err),
|
||||
nil,
|
||||
)
|
||||
u.statusRecorder.PublishEvent(
|
||||
cProto.SystemEvent_INFO,
|
||||
cProto.SystemEvent_SYSTEM,
|
||||
"",
|
||||
"",
|
||||
map[string]string{"progress_window": "hide"},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func (u *UpdateManager) updateStateManager(ctx context.Context) {
|
||||
stateType := &UpdateState{}
|
||||
|
||||
u.stateManager.RegisterState(stateType)
|
||||
if err := u.stateManager.LoadState(stateType); err != nil {
|
||||
log.Errorf("failed to load state: %v", err)
|
||||
return
|
||||
}
|
||||
state := u.stateManager.GetState(stateType)
|
||||
if state == nil {
|
||||
return
|
||||
}
|
||||
|
||||
updateState, ok := state.(*UpdateState)
|
||||
if !ok {
|
||||
log.Errorf("failed to cast state to UpdateState")
|
||||
return
|
||||
}
|
||||
log.Debugf("autoUpdate state loaded, %v", *updateState)
|
||||
if updateState.TargetVersion == u.currentVersion {
|
||||
log.Infof("published notification event")
|
||||
u.statusRecorder.PublishEvent(
|
||||
cProto.SystemEvent_INFO,
|
||||
cProto.SystemEvent_SYSTEM,
|
||||
"Auto-update completed",
|
||||
fmt.Sprintf("Your NetBird Client was auto-updated to version %s", u.currentVersion),
|
||||
nil,
|
||||
)
|
||||
}
|
||||
if err := u.stateManager.DeleteState(updateState); err != nil {
|
||||
log.Errorf("failed to delete state: %v", err)
|
||||
} else if err = u.stateManager.PersistState(ctx); err != nil {
|
||||
log.Errorf("failed to persist state: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (u *UpdateManager) shouldUpdate(updateVersion *v.Version) bool {
|
||||
if u.currentVersion == developmentVersion {
|
||||
log.Debugf("skipping auto-update, running development version")
|
||||
return false
|
||||
}
|
||||
currentVersion, err := v.NewVersion(u.currentVersion)
|
||||
if err != nil {
|
||||
log.Errorf("error checking for update, error parsing version `%s`: %v", u.currentVersion, err)
|
||||
return false
|
||||
}
|
||||
if currentVersion.GreaterThanOrEqual(updateVersion) {
|
||||
log.Infof("current version (%s) is equal to or higher than auto-update version (%s)", u.currentVersion, updateVersion)
|
||||
return false
|
||||
}
|
||||
|
||||
if time.Since(u.lastTrigger) < 5*time.Minute {
|
||||
log.Debugf("skipping auto-update, last update was %s ago", time.Since(u.lastTrigger))
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func downloadFileToTemporaryDir(ctx context.Context, fileURL string) (string, error) { //nolint:unused
|
||||
tempDir, err := os.MkdirTemp("", "netbird-installer-*")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error creating temporary directory: %w", err)
|
||||
}
|
||||
|
||||
// Clean up temp directory on error
|
||||
var success bool
|
||||
defer func() {
|
||||
if !success {
|
||||
if err := os.RemoveAll(tempDir); err != nil {
|
||||
log.Errorf("error cleaning up temporary directory: %v", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
fileNameParts := strings.Split(fileURL, "/")
|
||||
out, err := os.Create(filepath.Join(tempDir, fileNameParts[len(fileNameParts)-1]))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error creating temporary file: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := out.Close(); err != nil {
|
||||
log.Errorf("error closing temporary file: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", fileURL, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error creating file download request: %w", err)
|
||||
}
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error downloading file: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := resp.Body.Close(); err != nil {
|
||||
log.Errorf("Error closing response body: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
log.Errorf("error downloading update file, received status code: %d", resp.StatusCode)
|
||||
return "", fmt.Errorf("error downloading file, received status code: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
_, err = io.Copy(out, resp.Body)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error downloading file: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("downloaded update file to %s", out.Name())
|
||||
|
||||
success = true // Mark success to prevent cleanup
|
||||
return out.Name(), nil
|
||||
}
|
||||
213
client/internal/updatemanager/manager_test.go
Normal file
213
client/internal/updatemanager/manager_test.go
Normal file
@@ -0,0 +1,213 @@
|
||||
package updatemanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
v "github.com/hashicorp/go-version"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
"path"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func (u *UpdateManager) WithCustomVersionUpdate(versionUpdate UpdateInterface) *UpdateManager {
|
||||
u.update = versionUpdate
|
||||
return u
|
||||
}
|
||||
|
||||
type versionUpdateMock struct {
|
||||
latestVersion *v.Version
|
||||
onUpdate func()
|
||||
}
|
||||
|
||||
func (v versionUpdateMock) StopWatch() {}
|
||||
|
||||
func (v versionUpdateMock) SetDaemonVersion(newVersion string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (v *versionUpdateMock) SetOnUpdateListener(updateFn func()) {
|
||||
v.onUpdate = updateFn
|
||||
}
|
||||
|
||||
func (v versionUpdateMock) LatestVersion() *v.Version {
|
||||
return v.latestVersion
|
||||
}
|
||||
|
||||
func (v versionUpdateMock) StartFetcher() {}
|
||||
|
||||
func Test_LatestVersion(t *testing.T) {
|
||||
testMatrix := []struct {
|
||||
name string
|
||||
daemonVersion string
|
||||
initialLatestVersion *v.Version
|
||||
latestVersion *v.Version
|
||||
shouldUpdateInit bool
|
||||
shouldUpdateLater bool
|
||||
}{
|
||||
{
|
||||
name: "Should only trigger update once due to time between triggers being < 5 Minutes",
|
||||
daemonVersion: "1.0.0",
|
||||
initialLatestVersion: v.Must(v.NewSemver("1.0.1")),
|
||||
latestVersion: v.Must(v.NewSemver("1.0.2")),
|
||||
shouldUpdateInit: true,
|
||||
shouldUpdateLater: false,
|
||||
},
|
||||
{
|
||||
name: "Shouldn't update initially, but should update as soon as latest version is fetched",
|
||||
daemonVersion: "1.0.0",
|
||||
initialLatestVersion: nil,
|
||||
latestVersion: v.Must(v.NewSemver("1.0.1")),
|
||||
shouldUpdateInit: false,
|
||||
shouldUpdateLater: true,
|
||||
},
|
||||
}
|
||||
|
||||
for idx, c := range testMatrix {
|
||||
mockUpdate := &versionUpdateMock{latestVersion: c.initialLatestVersion}
|
||||
tmpFile := path.Join(t.TempDir(), fmt.Sprintf("update-test-%d.json", idx))
|
||||
m := NewUpdateManager(peer.NewRecorder(""), statemanager.New(tmpFile)).WithCustomVersionUpdate(mockUpdate)
|
||||
|
||||
targetVersionChan := make(chan string, 1)
|
||||
|
||||
m.updateFunc = func(ctx context.Context, targetVersion string) error {
|
||||
targetVersionChan <- targetVersion
|
||||
return nil
|
||||
}
|
||||
m.currentVersion = c.daemonVersion
|
||||
m.Start(context.Background())
|
||||
m.SetVersion("latest")
|
||||
var triggeredInit bool
|
||||
select {
|
||||
case targetVersion := <-targetVersionChan:
|
||||
if targetVersion != c.initialLatestVersion.String() {
|
||||
t.Errorf("%s: Initial update version mismatch, expected %v, got %v", c.name, c.initialLatestVersion.String(), targetVersion)
|
||||
}
|
||||
triggeredInit = true
|
||||
case <-time.After(10 * time.Millisecond):
|
||||
triggeredInit = false
|
||||
}
|
||||
if triggeredInit != c.shouldUpdateInit {
|
||||
t.Errorf("%s: Initial update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdateInit, triggeredInit)
|
||||
}
|
||||
|
||||
mockUpdate.latestVersion = c.latestVersion
|
||||
mockUpdate.onUpdate()
|
||||
|
||||
var triggeredLater bool
|
||||
select {
|
||||
case targetVersion := <-targetVersionChan:
|
||||
if targetVersion != c.latestVersion.String() {
|
||||
t.Errorf("%s: Update version mismatch, expected %v, got %v", c.name, c.latestVersion.String(), targetVersion)
|
||||
}
|
||||
triggeredLater = true
|
||||
case <-time.After(10 * time.Millisecond):
|
||||
triggeredLater = false
|
||||
}
|
||||
if triggeredLater != c.shouldUpdateLater {
|
||||
t.Errorf("%s: Update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdateLater, triggeredLater)
|
||||
}
|
||||
|
||||
m.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
func Test_HandleUpdate(t *testing.T) {
|
||||
testMatrix := []struct {
|
||||
name string
|
||||
daemonVersion string
|
||||
latestVersion *v.Version
|
||||
expectedVersion string
|
||||
shouldUpdate bool
|
||||
}{
|
||||
{
|
||||
name: "Update to a specific version should update regardless of if latestVersion is available yet",
|
||||
daemonVersion: "0.55.0",
|
||||
latestVersion: nil,
|
||||
expectedVersion: "0.56.0",
|
||||
shouldUpdate: true,
|
||||
},
|
||||
{
|
||||
name: "Update to specific version should not update if version matches",
|
||||
daemonVersion: "0.55.0",
|
||||
latestVersion: nil,
|
||||
expectedVersion: "0.55.0",
|
||||
shouldUpdate: false,
|
||||
},
|
||||
{
|
||||
name: "Update to specific version should not update if current version is newer",
|
||||
daemonVersion: "0.55.0",
|
||||
latestVersion: nil,
|
||||
expectedVersion: "0.54.0",
|
||||
shouldUpdate: false,
|
||||
},
|
||||
{
|
||||
name: "Update to latest version should update if latest is newer",
|
||||
daemonVersion: "0.55.0",
|
||||
latestVersion: v.Must(v.NewSemver("0.56.0")),
|
||||
expectedVersion: "latest",
|
||||
shouldUpdate: true,
|
||||
},
|
||||
{
|
||||
name: "Update to latest version should not update if latest == current",
|
||||
daemonVersion: "0.56.0",
|
||||
latestVersion: v.Must(v.NewSemver("0.56.0")),
|
||||
expectedVersion: "latest",
|
||||
shouldUpdate: false,
|
||||
},
|
||||
{
|
||||
name: "Should not update if daemon version is invalid",
|
||||
daemonVersion: "development",
|
||||
latestVersion: v.Must(v.NewSemver("1.0.0")),
|
||||
expectedVersion: "latest",
|
||||
shouldUpdate: false,
|
||||
},
|
||||
{
|
||||
name: "Should not update if expecting latest and latest version is unavailable",
|
||||
daemonVersion: "0.55.0",
|
||||
latestVersion: nil,
|
||||
expectedVersion: "latest",
|
||||
shouldUpdate: false,
|
||||
},
|
||||
{
|
||||
name: "Should not update if expected version is invalid",
|
||||
daemonVersion: "0.55.0",
|
||||
latestVersion: nil,
|
||||
expectedVersion: "development",
|
||||
shouldUpdate: false,
|
||||
},
|
||||
}
|
||||
for idx, c := range testMatrix {
|
||||
tmpFile := path.Join(t.TempDir(), fmt.Sprintf("update-test-%d.json", idx))
|
||||
m := NewUpdateManager(peer.NewRecorder(""), statemanager.New(tmpFile)).WithCustomVersionUpdate(&versionUpdateMock{latestVersion: c.latestVersion})
|
||||
targetVersionChan := make(chan string, 1)
|
||||
|
||||
m.updateFunc = func(ctx context.Context, targetVersion string) error {
|
||||
targetVersionChan <- targetVersion
|
||||
return nil
|
||||
}
|
||||
|
||||
m.currentVersion = c.daemonVersion
|
||||
m.Start(context.Background())
|
||||
m.SetVersion(c.expectedVersion)
|
||||
|
||||
var updateTriggered bool
|
||||
select {
|
||||
case targetVersion := <-targetVersionChan:
|
||||
if c.expectedVersion == "latest" && targetVersion != c.latestVersion.String() {
|
||||
t.Errorf("%s: Update version mismatch, expected %v, got %v", c.name, c.latestVersion.String(), targetVersion)
|
||||
} else if c.expectedVersion != "latest" && targetVersion != c.expectedVersion {
|
||||
t.Errorf("%s: Update version mismatch, expected %v, got %v", c.name, c.expectedVersion, targetVersion)
|
||||
}
|
||||
updateTriggered = true
|
||||
case <-time.After(10 * time.Millisecond):
|
||||
updateTriggered = false
|
||||
}
|
||||
|
||||
if updateTriggered != c.shouldUpdate {
|
||||
t.Errorf("%s: Update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdate, updateTriggered)
|
||||
}
|
||||
m.Stop()
|
||||
}
|
||||
}
|
||||
123
client/internal/updatemanager/update_darwin.go
Normal file
123
client/internal/updatemanager/update_darwin.go
Normal file
@@ -0,0 +1,123 @@
|
||||
//go:build darwin
|
||||
|
||||
package updatemanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"runtime"
|
||||
"strings"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
const (
|
||||
pkgDownloadURL = "https://github.com/netbirdio/netbird/releases/download/v%version/netbird_%version_darwin_%arch.pkg"
|
||||
)
|
||||
|
||||
func (u *UpdateManager) triggerUpdate(ctx context.Context, targetVersion string) error {
|
||||
// Use test function if set (for testing only)
|
||||
if u.updateFunc != nil {
|
||||
return u.updateFunc(ctx, targetVersion)
|
||||
}
|
||||
|
||||
cmd := exec.CommandContext(ctx, "pkgutil", "--pkg-info", "io.netbird.client")
|
||||
outBytes, err := cmd.Output()
|
||||
if err != nil && cmd.ProcessState.ExitCode() == 1 {
|
||||
// Not installed using pkg file, thus installed using Homebrew
|
||||
|
||||
return updateHomeBrew(ctx)
|
||||
}
|
||||
// Installed using pkg file
|
||||
path, err := downloadFileToTemporaryDir(ctx, urlWithVersionArch(targetVersion))
|
||||
if err != nil {
|
||||
return fmt.Errorf("error downloading update file: %w", err)
|
||||
}
|
||||
|
||||
volume := "/"
|
||||
for _, v := range strings.Split(string(outBytes), "\n") {
|
||||
trimmed := strings.TrimSpace(v)
|
||||
if strings.HasPrefix(trimmed, "volume: ") {
|
||||
volume = strings.Split(trimmed, ": ")[1]
|
||||
}
|
||||
}
|
||||
|
||||
cmd = exec.CommandContext(ctx, "installer", "-pkg", path, "-target", volume)
|
||||
|
||||
err = cmd.Start()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error running pkg file: %w", err)
|
||||
}
|
||||
err = cmd.Process.Release()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func updateHomeBrew(ctx context.Context) error {
|
||||
// Homebrew must be run as a non-root user
|
||||
// To find out which user installed NetBird using HomeBrew we can check the owner of our brew tap directory
|
||||
fileInfo, err := os.Stat("/opt/homebrew/Library/Taps/netbirdio/homebrew-tap/")
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting homebrew installation path info: %w", err)
|
||||
}
|
||||
|
||||
fileSysInfo, ok := fileInfo.Sys().(*syscall.Stat_t)
|
||||
if !ok {
|
||||
return fmt.Errorf("error checking file owner, sysInfo type is %T not *syscall.Stat_t", fileInfo.Sys())
|
||||
}
|
||||
|
||||
// Get username from UID
|
||||
installer, err := user.LookupId(fmt.Sprintf("%d", fileSysInfo.Uid))
|
||||
if err != nil {
|
||||
return fmt.Errorf("error looking up brew installer user: %w", err)
|
||||
}
|
||||
userName := installer.Name
|
||||
// Get user HOME, required for brew to run correctly
|
||||
// https://github.com/Homebrew/brew/issues/15833
|
||||
homeDir := installer.HomeDir
|
||||
// Homebrew does not support installing specific versions
|
||||
// Thus it will always update to latest and ignore targetVersion
|
||||
upgradeArgs := []string{"-u", userName, "/opt/homebrew/bin/brew", "upgrade", "netbirdio/tap/netbird"}
|
||||
// Check if netbird-ui is installed
|
||||
cmd := exec.CommandContext(ctx, "brew", "info", "--json", "netbirdio/tap/netbird-ui")
|
||||
err = cmd.Run()
|
||||
if err == nil {
|
||||
// netbird-ui is installed
|
||||
upgradeArgs = append(upgradeArgs, "netbirdio/tap/netbird-ui")
|
||||
}
|
||||
cmd = exec.CommandContext(ctx, "sudo", upgradeArgs...)
|
||||
cmd.Env = append(cmd.Env, "HOME="+homeDir)
|
||||
|
||||
// Homebrew upgrade doesn't restart the client on its own
|
||||
// So we have to wait for it to finish running and ensure it's done
|
||||
// And then basically restart the netbird service
|
||||
err = cmd.Run()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error running brew upgrade: %w", err)
|
||||
}
|
||||
|
||||
currentPID := os.Getpid()
|
||||
|
||||
// Restart netbird service after the fact
|
||||
// This is a workaround since attempting to restart using launchctl will kill the service and die before starting
|
||||
// the service again as it's a child process
|
||||
// using SIGTERM should ensure a clean shutdown
|
||||
process, err := os.FindProcess(currentPID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error finding current process: %w", err)
|
||||
}
|
||||
err = process.Signal(syscall.SIGTERM)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error sending SIGTERM to current process: %w", err)
|
||||
}
|
||||
// We're dying now, which should restart us
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func urlWithVersionArch(version string) string {
|
||||
url := strings.ReplaceAll(pkgDownloadURL, "%version", version)
|
||||
return strings.ReplaceAll(url, "%arch", runtime.GOARCH)
|
||||
}
|
||||
15
client/internal/updatemanager/update_freebsd.go
Normal file
15
client/internal/updatemanager/update_freebsd.go
Normal file
@@ -0,0 +1,15 @@
|
||||
//go:build freebsd
|
||||
|
||||
package updatemanager
|
||||
|
||||
import "context"
|
||||
|
||||
func (u *UpdateManager) triggerUpdate(ctx context.Context, targetVersion string) error {
|
||||
// Use test function if set (for testing purposes)
|
||||
if u.updateFunc != nil {
|
||||
return u.updateFunc(ctx, targetVersion)
|
||||
}
|
||||
|
||||
// TODO: Implement
|
||||
return nil
|
||||
}
|
||||
15
client/internal/updatemanager/update_js.go
Normal file
15
client/internal/updatemanager/update_js.go
Normal file
@@ -0,0 +1,15 @@
|
||||
//go:build js
|
||||
|
||||
package updatemanager
|
||||
|
||||
import "context"
|
||||
|
||||
func (u *UpdateManager) triggerUpdate(ctx context.Context, targetVersion string) error {
|
||||
// Use test function if set (for testing purposes)
|
||||
if u.updateFunc != nil {
|
||||
return u.updateFunc(ctx, targetVersion)
|
||||
}
|
||||
|
||||
// TODO: Implement
|
||||
return nil
|
||||
}
|
||||
15
client/internal/updatemanager/update_linux.go
Normal file
15
client/internal/updatemanager/update_linux.go
Normal file
@@ -0,0 +1,15 @@
|
||||
//go:build linux
|
||||
|
||||
package updatemanager
|
||||
|
||||
import "context"
|
||||
|
||||
func (u *UpdateManager) triggerUpdate(ctx context.Context, targetVersion string) error {
|
||||
// Use test function if set (for testing purposes)
|
||||
if u.updateFunc != nil {
|
||||
return u.updateFunc(ctx, targetVersion)
|
||||
}
|
||||
|
||||
// TODO: Implement
|
||||
return nil
|
||||
}
|
||||
101
client/internal/updatemanager/update_windows.go
Normal file
101
client/internal/updatemanager/update_windows.go
Normal file
@@ -0,0 +1,101 @@
|
||||
//go:build windows
|
||||
|
||||
package updatemanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/windows/registry"
|
||||
)
|
||||
|
||||
const (
|
||||
msiDownloadURL = "https://github.com/netbirdio/netbird/releases/download/v%version/netbird_installer_%version_windows_%arch.msi"
|
||||
exeDownloadURL = "https://github.com/netbirdio/netbird/releases/download/v%version/netbird_installer_%version_windows_%arch.exe"
|
||||
uninstallKeyPath64 = `SOFTWARE\WOW6432Node\Microsoft\Windows\CurrentVersion\Uninstall\Netbird`
|
||||
uninstallKeyPath32 = `SOFTWARE\Microsoft\Windows\CurrentVersion\Uninstall\Netbird`
|
||||
|
||||
installerEXE installerType = "EXE"
|
||||
installerMSI installerType = "MSI"
|
||||
)
|
||||
|
||||
type installerType string
|
||||
|
||||
func (u *UpdateManager) triggerUpdate(ctx context.Context, targetVersion string) error {
|
||||
// Use test function if set (for testing purposes)
|
||||
if u.updateFunc != nil {
|
||||
return u.updateFunc(ctx, targetVersion)
|
||||
}
|
||||
|
||||
method := installation()
|
||||
return install(ctx, method, targetVersion)
|
||||
}
|
||||
|
||||
func installation() installerType {
|
||||
k, err := registry.OpenKey(registry.LOCAL_MACHINE, uninstallKeyPath64, registry.QUERY_VALUE)
|
||||
if err != nil {
|
||||
k, err = registry.OpenKey(registry.LOCAL_MACHINE, uninstallKeyPath32, registry.QUERY_VALUE)
|
||||
if err != nil {
|
||||
return installerMSI
|
||||
} else {
|
||||
err = k.Close()
|
||||
if err != nil {
|
||||
log.Warnf("Error closing registry key: %v", err)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
err = k.Close()
|
||||
if err != nil {
|
||||
log.Warnf("Error closing registry key: %v", err)
|
||||
}
|
||||
}
|
||||
return installerEXE
|
||||
}
|
||||
|
||||
func install(ctx context.Context, installerType installerType, targetVersion string) error {
|
||||
path, err := downloadFileToTemporaryDir(ctx, urlWithVersionArch(installerType, targetVersion))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.Infof("start installation %s", path)
|
||||
|
||||
var cmd *exec.Cmd
|
||||
if installerType == installerEXE {
|
||||
cmd = exec.CommandContext(ctx, path, "/S")
|
||||
} else {
|
||||
cmd = exec.CommandContext(ctx, "msiexec", "/quiet", "/i", path)
|
||||
}
|
||||
|
||||
// Detach the process from the parent
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{
|
||||
CreationFlags: syscall.CREATE_NEW_PROCESS_GROUP | 0x00000008, // 0x00000008 is DETACHED_PROCESS
|
||||
}
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
log.Errorf("error starting installer: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
if err := cmd.Process.Release(); err != nil {
|
||||
log.Errorf("error releasing installer process: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
log.Infof("installer started successfully: %s", path)
|
||||
return nil
|
||||
}
|
||||
|
||||
func urlWithVersionArch(it installerType, version string) string {
|
||||
var url string
|
||||
if it == installerEXE {
|
||||
url = exeDownloadURL
|
||||
} else {
|
||||
url = msiDownloadURL
|
||||
}
|
||||
url = strings.ReplaceAll(url, "%version", version)
|
||||
return strings.ReplaceAll(url, "%arch", runtime.GOARCH)
|
||||
}
|
||||
@@ -93,13 +93,14 @@ func main() {
|
||||
showLoginURL: flags.showLoginURL,
|
||||
showDebug: flags.showDebug,
|
||||
showProfiles: flags.showProfiles,
|
||||
showUpdate: flags.showUpdate,
|
||||
})
|
||||
|
||||
// Watch for theme/settings changes to update the icon.
|
||||
go watchSettingsChanges(a, client)
|
||||
|
||||
// Run in window mode if any UI flag was set.
|
||||
if flags.showSettings || flags.showNetworks || flags.showDebug || flags.showLoginURL || flags.showProfiles {
|
||||
if flags.showSettings || flags.showNetworks || flags.showDebug || flags.showLoginURL || flags.showProfiles || flags.showUpdate {
|
||||
a.Run()
|
||||
return
|
||||
}
|
||||
@@ -127,6 +128,7 @@ type cliFlags struct {
|
||||
showDebug bool
|
||||
showLoginURL bool
|
||||
errorMsg string
|
||||
showUpdate bool
|
||||
saveLogsInFile bool
|
||||
}
|
||||
|
||||
@@ -146,6 +148,7 @@ func parseFlags() *cliFlags {
|
||||
flag.StringVar(&flags.errorMsg, "error-msg", "", "displays an error message window")
|
||||
flag.BoolVar(&flags.saveLogsInFile, "use-log-file", false, fmt.Sprintf("save logs in a file: %s/netbird-ui-PID.log", os.TempDir()))
|
||||
flag.BoolVar(&flags.showLoginURL, "login-url", false, "show login URL in a popup window")
|
||||
flag.BoolVar(&flags.showUpdate, "update", false, "show update progress window")
|
||||
flag.Parse()
|
||||
return &flags
|
||||
}
|
||||
@@ -296,6 +299,8 @@ type serviceClient struct {
|
||||
mExitNodeDeselectAll *systray.MenuItem
|
||||
logFile string
|
||||
wLoginURL fyne.Window
|
||||
wUpdateProgress fyne.Window
|
||||
updateContextCancel context.CancelFunc
|
||||
}
|
||||
|
||||
type menuHandler struct {
|
||||
@@ -312,6 +317,7 @@ type newServiceClientArgs struct {
|
||||
showDebug bool
|
||||
showLoginURL bool
|
||||
showProfiles bool
|
||||
showUpdate bool
|
||||
}
|
||||
|
||||
// newServiceClient instance constructor
|
||||
@@ -329,7 +335,7 @@ func newServiceClient(args *newServiceClientArgs) *serviceClient {
|
||||
|
||||
showAdvancedSettings: args.showSettings,
|
||||
showNetworks: args.showNetworks,
|
||||
update: version.NewUpdate("nb/client-ui"),
|
||||
update: version.NewUpdateAndStart("nb/client-ui"),
|
||||
}
|
||||
|
||||
s.eventHandler = newEventHandler(s)
|
||||
@@ -347,6 +353,8 @@ func newServiceClient(args *newServiceClientArgs) *serviceClient {
|
||||
s.showDebugUI()
|
||||
case args.showProfiles:
|
||||
s.showProfilesUI()
|
||||
case args.showUpdate:
|
||||
s.showUpdateProgress(ctx)
|
||||
}
|
||||
|
||||
return s
|
||||
@@ -392,6 +400,30 @@ func (s *serviceClient) updateIcon() {
|
||||
s.updateIndicationLock.Unlock()
|
||||
}
|
||||
|
||||
func (s *serviceClient) showUpdateProgress(ctx context.Context) {
|
||||
s.wUpdateProgress = s.app.NewWindow("Automatically updating client")
|
||||
loadingLabel := widget.NewLabel("Updating")
|
||||
s.wUpdateProgress.SetContent(container.NewGridWithRows(2, widget.NewLabel("Your client version is older than auto-update version set in Management, updating client now."), loadingLabel))
|
||||
s.wUpdateProgress.Show()
|
||||
go func() {
|
||||
dotCount := 0
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-time.After(time.Second):
|
||||
dotCount++
|
||||
dotCount %= 4
|
||||
loadingLabel.SetText(fmt.Sprintf("Updating%s", strings.Repeat(".", dotCount)))
|
||||
}
|
||||
}
|
||||
}()
|
||||
s.wUpdateProgress.CenterOnScreen()
|
||||
s.wUpdateProgress.SetFixedSize(true)
|
||||
s.wUpdateProgress.SetCloseIntercept(func() {})
|
||||
s.wUpdateProgress.RequestFocus()
|
||||
}
|
||||
|
||||
func (s *serviceClient) showSettingsUI() {
|
||||
// Check if update settings are disabled by daemon
|
||||
features, err := s.getFeatures()
|
||||
@@ -950,6 +982,29 @@ func (s *serviceClient) onTrayReady() {
|
||||
s.updateExitNodes()
|
||||
}
|
||||
})
|
||||
s.eventManager.AddHandler(func(event *proto.SystemEvent) {
|
||||
if windowAction, ok := event.Metadata["progress_window"]; ok {
|
||||
log.Debugf("window action: %v", windowAction)
|
||||
if windowAction == "show" {
|
||||
log.Debugf("Inside show")
|
||||
if s.updateContextCancel != nil {
|
||||
s.updateContextCancel()
|
||||
s.updateContextCancel = nil
|
||||
}
|
||||
|
||||
subCtx, cancel := context.WithCancel(s.ctx)
|
||||
go s.eventHandler.runSelfCommand(subCtx, "update", "true")
|
||||
s.updateContextCancel = cancel
|
||||
}
|
||||
if windowAction == "hide" {
|
||||
log.Debugf("Inside hide")
|
||||
if s.updateContextCancel != nil {
|
||||
s.updateContextCancel()
|
||||
s.updateContextCancel = nil
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
go s.eventManager.Start(s.ctx)
|
||||
go s.eventHandler.listen(s.ctx)
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
"github.com/skratchdot/open-golang/open"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||
uptypes "github.com/netbirdio/netbird/upload-server/types"
|
||||
@@ -426,6 +427,12 @@ func (s *serviceClient) collectDebugData(
|
||||
return "", err
|
||||
}
|
||||
|
||||
pm := profilemanager.NewProfileManager()
|
||||
var profName string
|
||||
if activeProf, err := pm.GetActiveProfile(); err == nil {
|
||||
profName = activeProf.Name
|
||||
}
|
||||
|
||||
postUpStatus, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true})
|
||||
if err != nil {
|
||||
log.Warnf("Failed to get post-up status: %v", err)
|
||||
@@ -433,7 +440,7 @@ func (s *serviceClient) collectDebugData(
|
||||
|
||||
var postUpStatusOutput string
|
||||
if postUpStatus != nil {
|
||||
overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil, "", "")
|
||||
overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil, "", profName)
|
||||
postUpStatusOutput = nbstatus.ParseToFullDetailSummary(overview)
|
||||
}
|
||||
headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339))
|
||||
@@ -450,7 +457,7 @@ func (s *serviceClient) collectDebugData(
|
||||
|
||||
var preDownStatusOutput string
|
||||
if preDownStatus != nil {
|
||||
overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil, "", "")
|
||||
overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil, "", profName)
|
||||
preDownStatusOutput = nbstatus.ParseToFullDetailSummary(overview)
|
||||
}
|
||||
headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s",
|
||||
@@ -574,6 +581,12 @@ func (s *serviceClient) createDebugBundle(anonymize bool, systemInfo bool, uploa
|
||||
return nil, fmt.Errorf("get client: %v", err)
|
||||
}
|
||||
|
||||
pm := profilemanager.NewProfileManager()
|
||||
var profName string
|
||||
if activeProf, err := pm.GetActiveProfile(); err == nil {
|
||||
profName = activeProf.Name
|
||||
}
|
||||
|
||||
statusResp, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true})
|
||||
if err != nil {
|
||||
log.Warnf("failed to get status for debug bundle: %v", err)
|
||||
@@ -581,7 +594,7 @@ func (s *serviceClient) createDebugBundle(anonymize bool, systemInfo bool, uploa
|
||||
|
||||
var statusOutput string
|
||||
if statusResp != nil {
|
||||
overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil, "", "")
|
||||
overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil, "", profName)
|
||||
statusOutput = nbstatus.ParseToFullDetailSummary(overview)
|
||||
}
|
||||
|
||||
|
||||
2
go.mod
2
go.mod
@@ -62,7 +62,7 @@ require (
|
||||
github.com/miekg/dns v1.1.59
|
||||
github.com/mitchellh/hashstructure/v2 v2.0.2
|
||||
github.com/nadoo/ipset v0.5.0
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20251010134843-7af36217ac1f
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20251022080146-b1caade69396
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45
|
||||
github.com/okta/okta-sdk-golang/v2 v2.18.0
|
||||
github.com/oschwald/maxminddb-golang v1.12.0
|
||||
|
||||
4
go.sum
4
go.sum
@@ -503,8 +503,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
|
||||
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
|
||||
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI=
|
||||
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20251010134843-7af36217ac1f h1:XIpRDlpPz3zFUkpwaqDRHjwpQRsf2ZKHggoex1MTafs=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20251010134843-7af36217ac1f/go.mod h1:v0nUbbHbuQnqR7yKIYnKzsLBCswLtp2JctmKYmGgVhc=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20251022080146-b1caade69396 h1:aXHS63QWf0Z5fDN19Swl6npdJjGMyXthAvvgW7rbKJQ=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20251022080146-b1caade69396/go.mod h1:v0nUbbHbuQnqR7yKIYnKzsLBCswLtp2JctmKYmGgVhc=
|
||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
|
||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=
|
||||
|
||||
@@ -185,12 +185,15 @@ if [[ "$NETBIRD_DISABLE_LETSENCRYPT" == "true" ]]; then
|
||||
echo "You are also free to remove any occurrences of the Letsencrypt-volume $LETSENCRYPT_VOLUMENAME"
|
||||
echo ""
|
||||
|
||||
export NETBIRD_SIGNAL_PROTOCOL="https"
|
||||
unset NETBIRD_LETSENCRYPT_DOMAIN
|
||||
unset NETBIRD_MGMT_API_CERT_FILE
|
||||
unset NETBIRD_MGMT_API_CERT_KEY_FILE
|
||||
fi
|
||||
|
||||
if [[ -n "$NETBIRD_MGMT_API_CERT_FILE" && -n "$NETBIRD_MGMT_API_CERT_KEY_FILE" ]]; then
|
||||
export NETBIRD_SIGNAL_PROTOCOL="https"
|
||||
fi
|
||||
|
||||
# Check if management identity provider is set
|
||||
if [ -n "$NETBIRD_MGMT_IDP" ]; then
|
||||
EXTRA_CONFIG={}
|
||||
|
||||
@@ -40,13 +40,21 @@ services:
|
||||
signal:
|
||||
<<: *default
|
||||
image: netbirdio/signal:$NETBIRD_SIGNAL_TAG
|
||||
depends_on:
|
||||
- dashboard
|
||||
volumes:
|
||||
- $SIGNAL_VOLUMENAME:/var/lib/netbird
|
||||
- $LETSENCRYPT_VOLUMENAME:/etc/letsencrypt:ro
|
||||
ports:
|
||||
- $NETBIRD_SIGNAL_PORT:80
|
||||
# # port and command for Let's Encrypt validation
|
||||
# - 443:443
|
||||
# command: ["--letsencrypt-domain", "$NETBIRD_LETSENCRYPT_DOMAIN", "--log-file", "console"]
|
||||
command: [
|
||||
"--cert-file", "$NETBIRD_MGMT_API_CERT_FILE",
|
||||
"--cert-key", "$NETBIRD_MGMT_API_CERT_KEY_FILE",
|
||||
"--log-file", "console"
|
||||
]
|
||||
|
||||
# Relay
|
||||
relay:
|
||||
|
||||
@@ -183,7 +183,7 @@ func (s *BaseServer) Start(ctx context.Context) error {
|
||||
log.WithContext(ctx).Infof("running HTTP server and gRPC server on the same port: %s", s.listener.Addr().String())
|
||||
s.serveGRPCWithHTTP(ctx, s.listener, rootHandler, tlsEnabled)
|
||||
|
||||
s.update = version.NewUpdate("nb/management")
|
||||
s.update = version.NewUpdateAndStart("nb/management")
|
||||
s.update.SetDaemonVersion(version.NetbirdVersion())
|
||||
s.update.SetOnUpdateListener(func() {
|
||||
log.WithContext(ctx).Infof("your management version, \"%s\", is outdated, a new management version is available. Learn more here: https://github.com/netbirdio/netbird/releases", version.NetbirdVersion())
|
||||
|
||||
@@ -340,7 +340,8 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
||||
|
||||
if oldSettings.RoutingPeerDNSResolutionEnabled != newSettings.RoutingPeerDNSResolutionEnabled ||
|
||||
oldSettings.LazyConnectionEnabled != newSettings.LazyConnectionEnabled ||
|
||||
oldSettings.DNSDomain != newSettings.DNSDomain {
|
||||
oldSettings.DNSDomain != newSettings.DNSDomain ||
|
||||
oldSettings.AutoUpdateVersion != newSettings.AutoUpdateVersion {
|
||||
updateAccountPeers = true
|
||||
}
|
||||
|
||||
@@ -376,6 +377,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
||||
am.handleLazyConnectionSettings(ctx, oldSettings, newSettings, userID, accountID)
|
||||
am.handlePeerLoginExpirationSettings(ctx, oldSettings, newSettings, userID, accountID)
|
||||
am.handleGroupsPropagationSettings(ctx, oldSettings, newSettings, userID, accountID)
|
||||
am.handleAutoUpdateVersionSettings(ctx, oldSettings, newSettings, userID, accountID)
|
||||
if err = am.handleInactivityExpirationSettings(ctx, oldSettings, newSettings, userID, accountID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -477,6 +479,14 @@ func (am *DefaultAccountManager) handleGroupsPropagationSettings(ctx context.Con
|
||||
}
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) handleAutoUpdateVersionSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) {
|
||||
if oldSettings.AutoUpdateVersion != newSettings.AutoUpdateVersion {
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountAutoUpdateVersionUpdated, map[string]any{
|
||||
"version": newSettings.AutoUpdateVersion,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) error {
|
||||
if newSettings.PeerInactivityExpirationEnabled {
|
||||
if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration {
|
||||
|
||||
@@ -109,7 +109,7 @@ type Manager interface {
|
||||
GetIdpManager() idp.Manager
|
||||
UpdateIntegratedValidator(ctx context.Context, accountID, userID, validator string, groups []string) error
|
||||
GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error)
|
||||
GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error)
|
||||
GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error)
|
||||
SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
|
||||
OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error
|
||||
SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
|
||||
|
||||
@@ -180,6 +180,8 @@ const (
|
||||
UserApproved Activity = 89
|
||||
UserRejected Activity = 90
|
||||
|
||||
AccountAutoUpdateVersionUpdated Activity = 91
|
||||
|
||||
AccountDeleted Activity = 99999
|
||||
)
|
||||
|
||||
@@ -286,8 +288,11 @@ var activityMap = map[Activity]Code{
|
||||
AccountNetworkRangeUpdated: {"Account network range updated", "account.network.range.update"},
|
||||
|
||||
PeerIPUpdated: {"Peer IP updated", "peer.ip.update"},
|
||||
|
||||
UserApproved: {"User approved", "user.approve"},
|
||||
UserRejected: {"User rejected", "user.reject"},
|
||||
|
||||
AccountAutoUpdateVersionUpdated: {"Account AutoUpdate Version updated", "account.settings.auto.version.update"},
|
||||
}
|
||||
|
||||
// StringCode returns a string code of the activity
|
||||
|
||||
@@ -712,6 +712,9 @@ func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, set
|
||||
Fqdn: fqdn,
|
||||
RoutingPeerDnsResolutionEnabled: settings.RoutingPeerDNSResolutionEnabled,
|
||||
LazyConnectionEnabled: settings.LazyConnectionEnabled,
|
||||
AutoUpdate: &proto.AutoUpdateSettings{
|
||||
Version: settings.AutoUpdateVersion,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -719,9 +722,10 @@ func toSyncResponse(ctx context.Context, config *nbconfig.Config, peer *nbpeer.P
|
||||
response := &proto.SyncResponse{
|
||||
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings),
|
||||
NetworkMap: &proto.NetworkMap{
|
||||
Serial: networkMap.Network.CurrentSerial(),
|
||||
Routes: toProtocolRoutes(networkMap.Routes),
|
||||
DNSConfig: toProtocolDNSConfig(networkMap.DNSConfig, dnsCache, dnsFwdPort),
|
||||
Serial: networkMap.Network.CurrentSerial(),
|
||||
Routes: toProtocolRoutes(networkMap.Routes),
|
||||
DNSConfig: toProtocolDNSConfig(networkMap.DNSConfig, dnsCache, dnsFwdPort),
|
||||
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings),
|
||||
},
|
||||
Checks: toProtocolChecks(ctx, checks),
|
||||
}
|
||||
|
||||
@@ -3,12 +3,14 @@ package accounts
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
goversion "github.com/hashicorp/go-version"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
@@ -26,7 +28,9 @@ const (
|
||||
// MinNetworkBits is the minimum prefix length for IPv4 network ranges (e.g., /29 gives 8 addresses, /28 gives 16)
|
||||
MinNetworkBitsIPv4 = 28
|
||||
// MinNetworkBitsIPv6 is the minimum prefix length for IPv6 network ranges
|
||||
MinNetworkBitsIPv6 = 120
|
||||
MinNetworkBitsIPv6 = 120
|
||||
disableAutoUpdate = "disabled"
|
||||
autoUpdateLatestVersion = "latest"
|
||||
)
|
||||
|
||||
// handler is a handler that handles the server.Account HTTP endpoints
|
||||
@@ -162,6 +166,61 @@ func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteJSONObject(r.Context(), w, []*api.Account{resp})
|
||||
}
|
||||
|
||||
func (h *handler) updateAccountRequestSettings(req api.PutApiAccountsAccountIdJSONRequestBody) (*types.Settings, error) {
|
||||
returnSettings := &types.Settings{
|
||||
PeerLoginExpirationEnabled: req.Settings.PeerLoginExpirationEnabled,
|
||||
PeerLoginExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerLoginExpiration)),
|
||||
RegularUsersViewBlocked: req.Settings.RegularUsersViewBlocked,
|
||||
|
||||
PeerInactivityExpirationEnabled: req.Settings.PeerInactivityExpirationEnabled,
|
||||
PeerInactivityExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerInactivityExpiration)),
|
||||
}
|
||||
|
||||
if req.Settings.Extra != nil {
|
||||
returnSettings.Extra = &types.ExtraSettings{
|
||||
PeerApprovalEnabled: req.Settings.Extra.PeerApprovalEnabled,
|
||||
UserApprovalRequired: req.Settings.Extra.UserApprovalRequired,
|
||||
FlowEnabled: req.Settings.Extra.NetworkTrafficLogsEnabled,
|
||||
FlowGroups: req.Settings.Extra.NetworkTrafficLogsGroups,
|
||||
FlowPacketCounterEnabled: req.Settings.Extra.NetworkTrafficPacketCounterEnabled,
|
||||
}
|
||||
}
|
||||
|
||||
if req.Settings.JwtGroupsEnabled != nil {
|
||||
returnSettings.JWTGroupsEnabled = *req.Settings.JwtGroupsEnabled
|
||||
}
|
||||
if req.Settings.GroupsPropagationEnabled != nil {
|
||||
returnSettings.GroupsPropagationEnabled = *req.Settings.GroupsPropagationEnabled
|
||||
}
|
||||
if req.Settings.JwtGroupsClaimName != nil {
|
||||
returnSettings.JWTGroupsClaimName = *req.Settings.JwtGroupsClaimName
|
||||
}
|
||||
if req.Settings.JwtAllowGroups != nil {
|
||||
returnSettings.JWTAllowGroups = *req.Settings.JwtAllowGroups
|
||||
}
|
||||
if req.Settings.RoutingPeerDnsResolutionEnabled != nil {
|
||||
returnSettings.RoutingPeerDNSResolutionEnabled = *req.Settings.RoutingPeerDnsResolutionEnabled
|
||||
}
|
||||
if req.Settings.DnsDomain != nil {
|
||||
returnSettings.DNSDomain = *req.Settings.DnsDomain
|
||||
}
|
||||
if req.Settings.LazyConnectionEnabled != nil {
|
||||
returnSettings.LazyConnectionEnabled = *req.Settings.LazyConnectionEnabled
|
||||
}
|
||||
if req.Settings.AutoUpdateVersion != nil {
|
||||
_, err := goversion.NewSemver(*req.Settings.AutoUpdateVersion)
|
||||
if *req.Settings.AutoUpdateVersion == autoUpdateLatestVersion ||
|
||||
*req.Settings.AutoUpdateVersion == disableAutoUpdate ||
|
||||
err == nil {
|
||||
returnSettings.AutoUpdateVersion = *req.Settings.AutoUpdateVersion
|
||||
} else if *req.Settings.AutoUpdateVersion != "" {
|
||||
return nil, fmt.Errorf("invalid AutoUpdateVersion")
|
||||
}
|
||||
}
|
||||
|
||||
return returnSettings, nil
|
||||
}
|
||||
|
||||
// updateAccount is HTTP PUT handler that updates the provided account. Updates only account settings (server.Settings)
|
||||
func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
@@ -186,45 +245,9 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
settings := &types.Settings{
|
||||
PeerLoginExpirationEnabled: req.Settings.PeerLoginExpirationEnabled,
|
||||
PeerLoginExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerLoginExpiration)),
|
||||
RegularUsersViewBlocked: req.Settings.RegularUsersViewBlocked,
|
||||
|
||||
PeerInactivityExpirationEnabled: req.Settings.PeerInactivityExpirationEnabled,
|
||||
PeerInactivityExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerInactivityExpiration)),
|
||||
}
|
||||
|
||||
if req.Settings.Extra != nil {
|
||||
settings.Extra = &types.ExtraSettings{
|
||||
PeerApprovalEnabled: req.Settings.Extra.PeerApprovalEnabled,
|
||||
UserApprovalRequired: req.Settings.Extra.UserApprovalRequired,
|
||||
FlowEnabled: req.Settings.Extra.NetworkTrafficLogsEnabled,
|
||||
FlowGroups: req.Settings.Extra.NetworkTrafficLogsGroups,
|
||||
FlowPacketCounterEnabled: req.Settings.Extra.NetworkTrafficPacketCounterEnabled,
|
||||
}
|
||||
}
|
||||
|
||||
if req.Settings.JwtGroupsEnabled != nil {
|
||||
settings.JWTGroupsEnabled = *req.Settings.JwtGroupsEnabled
|
||||
}
|
||||
if req.Settings.GroupsPropagationEnabled != nil {
|
||||
settings.GroupsPropagationEnabled = *req.Settings.GroupsPropagationEnabled
|
||||
}
|
||||
if req.Settings.JwtGroupsClaimName != nil {
|
||||
settings.JWTGroupsClaimName = *req.Settings.JwtGroupsClaimName
|
||||
}
|
||||
if req.Settings.JwtAllowGroups != nil {
|
||||
settings.JWTAllowGroups = *req.Settings.JwtAllowGroups
|
||||
}
|
||||
if req.Settings.RoutingPeerDnsResolutionEnabled != nil {
|
||||
settings.RoutingPeerDNSResolutionEnabled = *req.Settings.RoutingPeerDnsResolutionEnabled
|
||||
}
|
||||
if req.Settings.DnsDomain != nil {
|
||||
settings.DNSDomain = *req.Settings.DnsDomain
|
||||
}
|
||||
if req.Settings.LazyConnectionEnabled != nil {
|
||||
settings.LazyConnectionEnabled = *req.Settings.LazyConnectionEnabled
|
||||
settings, err := h.updateAccountRequestSettings(req)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
}
|
||||
if req.Settings.NetworkRange != nil && *req.Settings.NetworkRange != "" {
|
||||
prefix, err := netip.ParsePrefix(*req.Settings.NetworkRange)
|
||||
@@ -313,6 +336,7 @@ func toAccountResponse(accountID string, settings *types.Settings, meta *types.A
|
||||
RoutingPeerDnsResolutionEnabled: &settings.RoutingPeerDNSResolutionEnabled,
|
||||
LazyConnectionEnabled: &settings.LazyConnectionEnabled,
|
||||
DnsDomain: &settings.DNSDomain,
|
||||
AutoUpdateVersion: &settings.AutoUpdateVersion,
|
||||
}
|
||||
|
||||
if settings.NetworkRange.IsValid() {
|
||||
|
||||
@@ -120,6 +120,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
||||
RoutingPeerDnsResolutionEnabled: br(false),
|
||||
LazyConnectionEnabled: br(false),
|
||||
DnsDomain: sr(""),
|
||||
AutoUpdateVersion: sr(""),
|
||||
},
|
||||
expectedArray: true,
|
||||
expectedID: accountID,
|
||||
@@ -142,6 +143,30 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
||||
RoutingPeerDnsResolutionEnabled: br(false),
|
||||
LazyConnectionEnabled: br(false),
|
||||
DnsDomain: sr(""),
|
||||
AutoUpdateVersion: sr(""),
|
||||
},
|
||||
expectedArray: false,
|
||||
expectedID: accountID,
|
||||
},
|
||||
{
|
||||
name: "PutAccount OK with autoUpdateVersion",
|
||||
expectedBody: true,
|
||||
requestType: http.MethodPut,
|
||||
requestPath: "/api/accounts/" + accountID,
|
||||
requestBody: bytes.NewBufferString("{\"settings\": {\"auto_update_version\": \"latest\", \"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": true},\"onboarding\": {\"onboarding_flow_pending\": true,\"signup_form_pending\": true}}"),
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedSettings: api.AccountSettings{
|
||||
PeerLoginExpiration: 15552000,
|
||||
PeerLoginExpirationEnabled: true,
|
||||
GroupsPropagationEnabled: br(false),
|
||||
JwtGroupsClaimName: sr(""),
|
||||
JwtGroupsEnabled: br(false),
|
||||
JwtAllowGroups: &[]string{},
|
||||
RegularUsersViewBlocked: false,
|
||||
RoutingPeerDnsResolutionEnabled: br(false),
|
||||
LazyConnectionEnabled: br(false),
|
||||
DnsDomain: sr(""),
|
||||
AutoUpdateVersion: sr("latest"),
|
||||
},
|
||||
expectedArray: false,
|
||||
expectedID: accountID,
|
||||
@@ -164,6 +189,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
||||
RoutingPeerDnsResolutionEnabled: br(false),
|
||||
LazyConnectionEnabled: br(false),
|
||||
DnsDomain: sr(""),
|
||||
AutoUpdateVersion: sr(""),
|
||||
},
|
||||
expectedArray: false,
|
||||
expectedID: accountID,
|
||||
@@ -186,6 +212,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
||||
RoutingPeerDnsResolutionEnabled: br(false),
|
||||
LazyConnectionEnabled: br(false),
|
||||
DnsDomain: sr(""),
|
||||
AutoUpdateVersion: sr(""),
|
||||
},
|
||||
expectedArray: false,
|
||||
expectedID: accountID,
|
||||
@@ -208,6 +235,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
||||
RoutingPeerDnsResolutionEnabled: br(false),
|
||||
LazyConnectionEnabled: br(false),
|
||||
DnsDomain: sr(""),
|
||||
AutoUpdateVersion: sr(""),
|
||||
},
|
||||
expectedArray: false,
|
||||
expectedID: accountID,
|
||||
|
||||
@@ -78,7 +78,7 @@ func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string,
|
||||
grps, _ := h.accountManager.GetPeerGroups(ctx, accountID, peerID)
|
||||
grpsInfoMap := groups.ToGroupsInfoMap(grps, 0)
|
||||
|
||||
validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID)
|
||||
validPeers, invalidPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to list approved peers: %v", err)
|
||||
util.WriteError(ctx, fmt.Errorf("internal error"), w)
|
||||
@@ -86,7 +86,9 @@ func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string,
|
||||
}
|
||||
|
||||
_, valid := validPeers[peer.ID]
|
||||
util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, grpsInfoMap[peerID], dnsDomain, valid))
|
||||
reason := invalidPeers[peer.ID]
|
||||
|
||||
util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, grpsInfoMap[peerID], dnsDomain, valid, reason))
|
||||
}
|
||||
|
||||
func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID string, w http.ResponseWriter, r *http.Request) {
|
||||
@@ -147,16 +149,17 @@ func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID stri
|
||||
|
||||
grpsInfoMap := groups.ToGroupsInfoMap(peerGroups, 0)
|
||||
|
||||
validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID)
|
||||
validPeers, invalidPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to list appreoved peers: %v", err)
|
||||
log.WithContext(ctx).Errorf("failed to get validated peers: %v", err)
|
||||
util.WriteError(ctx, fmt.Errorf("internal error"), w)
|
||||
return
|
||||
}
|
||||
|
||||
_, valid := validPeers[peer.ID]
|
||||
reason := invalidPeers[peer.ID]
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, grpsInfoMap[peerID], dnsDomain, valid))
|
||||
util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, grpsInfoMap[peerID], dnsDomain, valid, reason))
|
||||
}
|
||||
|
||||
func (h *Handler) deletePeer(ctx context.Context, accountID, userID string, peerID string, w http.ResponseWriter) {
|
||||
@@ -240,22 +243,25 @@ func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
|
||||
respBody = append(respBody, toPeerListItemResponse(peerToReturn, grpsInfoMap[peer.ID], dnsDomain, 0))
|
||||
}
|
||||
|
||||
validPeersMap, err := h.accountManager.GetValidatedPeers(r.Context(), accountID)
|
||||
validPeersMap, invalidPeersMap, err := h.accountManager.GetValidatedPeers(r.Context(), accountID)
|
||||
if err != nil {
|
||||
log.WithContext(r.Context()).Errorf("failed to list appreoved peers: %v", err)
|
||||
log.WithContext(r.Context()).Errorf("failed to get validated peers: %v", err)
|
||||
util.WriteError(r.Context(), fmt.Errorf("internal error"), w)
|
||||
return
|
||||
}
|
||||
h.setApprovalRequiredFlag(respBody, validPeersMap)
|
||||
h.setApprovalRequiredFlag(respBody, validPeersMap, invalidPeersMap)
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, respBody)
|
||||
}
|
||||
|
||||
func (h *Handler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approvedPeersMap map[string]struct{}) {
|
||||
func (h *Handler) setApprovalRequiredFlag(respBody []*api.PeerBatch, validPeersMap map[string]struct{}, invalidPeersMap map[string]string) {
|
||||
for _, peer := range respBody {
|
||||
_, ok := approvedPeersMap[peer.Id]
|
||||
_, ok := validPeersMap[peer.Id]
|
||||
if !ok {
|
||||
peer.ApprovalRequired = true
|
||||
|
||||
reason := invalidPeersMap[peer.Id]
|
||||
peer.DisapprovalReason = &reason
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -304,7 +310,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
validPeers, err := h.accountManager.GetValidatedPeers(r.Context(), accountID)
|
||||
validPeers, _, err := h.accountManager.GetValidatedPeers(r.Context(), accountID)
|
||||
if err != nil {
|
||||
log.WithContext(r.Context()).Errorf("failed to list approved peers: %v", err)
|
||||
util.WriteError(r.Context(), fmt.Errorf("internal error"), w)
|
||||
@@ -430,13 +436,13 @@ func peerToAccessiblePeer(peer *nbpeer.Peer, dnsDomain string) api.AccessiblePee
|
||||
}
|
||||
}
|
||||
|
||||
func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, approved bool) *api.Peer {
|
||||
func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, approved bool, reason string) *api.Peer {
|
||||
osVersion := peer.Meta.OSVersion
|
||||
if osVersion == "" {
|
||||
osVersion = peer.Meta.Core
|
||||
}
|
||||
|
||||
return &api.Peer{
|
||||
apiPeer := &api.Peer{
|
||||
CreatedAt: peer.CreatedAt,
|
||||
Id: peer.ID,
|
||||
Name: peer.Name,
|
||||
@@ -465,6 +471,12 @@ func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsD
|
||||
InactivityExpirationEnabled: peer.InactivityExpirationEnabled,
|
||||
Ephemeral: peer.Ephemeral,
|
||||
}
|
||||
|
||||
if !approved {
|
||||
apiPeer.DisapprovalReason = &reason
|
||||
}
|
||||
|
||||
return apiPeer
|
||||
}
|
||||
|
||||
func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, accessiblePeersCount int) *api.PeerBatch {
|
||||
|
||||
@@ -88,7 +88,7 @@ func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error) {
|
||||
func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error) {
|
||||
var err error
|
||||
var groups []*types.Group
|
||||
var peers []*nbpeer.Peer
|
||||
@@ -96,20 +96,30 @@ func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountI
|
||||
|
||||
groups, err = am.Store.GetAccountGroups(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
peers, err = am.Store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
settings, err = am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return am.integratedPeerValidator.GetValidatedPeers(ctx, accountID, groups, peers, settings.Extra)
|
||||
validPeers, err := am.integratedPeerValidator.GetValidatedPeers(ctx, accountID, groups, peers, settings.Extra)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
invalidPeers, err := am.integratedPeerValidator.GetInvalidPeers(ctx, accountID, settings.Extra)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return validPeers, invalidPeers, nil
|
||||
}
|
||||
|
||||
type MockIntegratedValidator struct {
|
||||
@@ -136,6 +146,10 @@ func (a MockIntegratedValidator) GetValidatedPeers(_ context.Context, accountID
|
||||
return validatedPeers, nil
|
||||
}
|
||||
|
||||
func (a MockIntegratedValidator) GetInvalidPeers(_ context.Context, accountID string, extraSettings *types.ExtraSettings) (map[string]string, error) {
|
||||
return make(map[string]string), nil
|
||||
}
|
||||
|
||||
func (MockIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings, temporary bool) *nbpeer.Peer {
|
||||
return peer
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@ type IntegratedValidator interface {
|
||||
PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings, temporary bool) *nbpeer.Peer
|
||||
IsNotValidPeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) (bool, bool, error)
|
||||
GetValidatedPeers(ctx context.Context, accountID string, groups []*types.Group, peers []*nbpeer.Peer, extraSettings *types.ExtraSettings) (map[string]struct{}, error)
|
||||
GetInvalidPeers(ctx context.Context, accountID string, extraSettings *types.ExtraSettings) (map[string]string, error)
|
||||
PeerDeleted(ctx context.Context, accountID, peerID string, extraSettings *types.ExtraSettings) error
|
||||
SetPeerInvalidationListener(fn func(accountID string, peerIDs []string))
|
||||
Stop(ctx context.Context)
|
||||
|
||||
@@ -189,17 +189,17 @@ func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID st
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error) {
|
||||
func (am *MockAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error) {
|
||||
account, err := am.GetAccountFunc(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
approvedPeers := make(map[string]struct{})
|
||||
for id := range account.Peers {
|
||||
approvedPeers[id] = struct{}{}
|
||||
}
|
||||
return approvedPeers, nil
|
||||
return approvedPeers, nil, nil
|
||||
}
|
||||
|
||||
// GetGroup mock implementation of GetGroup from server.AccountManager interface
|
||||
|
||||
@@ -1575,6 +1575,7 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto
|
||||
RemotePeersIsEmpty: true,
|
||||
FirewallRules: []*proto.FirewallRule{},
|
||||
FirewallRulesIsEmpty: true,
|
||||
PeerConfig: toPeerConfig(peer, network, dnsDomain, settings),
|
||||
DNSConfig: &proto.DNSConfig{
|
||||
ForwarderPort: dnsFwdPort,
|
||||
},
|
||||
|
||||
@@ -52,6 +52,9 @@ type Settings struct {
|
||||
|
||||
// LazyConnectionEnabled indicates if the experimental feature is enabled or disabled
|
||||
LazyConnectionEnabled bool `gorm:"default:false"`
|
||||
|
||||
// AutoUpdateVersion client auto-update version
|
||||
AutoUpdateVersion string `gorm:"default:'latest'"`
|
||||
}
|
||||
|
||||
// Copy copies the Settings struct
|
||||
@@ -72,6 +75,7 @@ func (s *Settings) Copy() *Settings {
|
||||
LazyConnectionEnabled: s.LazyConnectionEnabled,
|
||||
DNSDomain: s.DNSDomain,
|
||||
NetworkRange: s.NetworkRange,
|
||||
AutoUpdateVersion: s.AutoUpdateVersion,
|
||||
}
|
||||
if s.Extra != nil {
|
||||
settings.Extra = s.Extra.Copy()
|
||||
|
||||
@@ -145,6 +145,10 @@ components:
|
||||
description: Enables or disables experimental lazy connection
|
||||
type: boolean
|
||||
example: true
|
||||
auto_update_version:
|
||||
description: Set Clients auto-update version. "latest", "disabled", or a specific version (e.g "0.50.1")
|
||||
type: string
|
||||
example: "0.51.2"
|
||||
required:
|
||||
- peer_login_expiration_enabled
|
||||
- peer_login_expiration
|
||||
@@ -463,6 +467,9 @@ components:
|
||||
description: (Cloud only) Indicates whether peer needs approval
|
||||
type: boolean
|
||||
example: true
|
||||
disapproval_reason:
|
||||
description: (Cloud only) Reason why the peer requires approval
|
||||
type: string
|
||||
country_code:
|
||||
$ref: '#/components/schemas/CountryCode'
|
||||
city_name:
|
||||
|
||||
@@ -291,6 +291,9 @@ type AccountRequest struct {
|
||||
|
||||
// AccountSettings defines model for AccountSettings.
|
||||
type AccountSettings struct {
|
||||
// AutoUpdateVersion Set Clients auto-update version. "latest", "disabled", or a specific version (e.g "0.50.1")
|
||||
AutoUpdateVersion *string `json:"auto_update_version,omitempty"`
|
||||
|
||||
// DnsDomain Allows to define a custom dns domain for the account
|
||||
DnsDomain *string `json:"dns_domain,omitempty"`
|
||||
Extra *AccountExtraSettings `json:"extra,omitempty"`
|
||||
@@ -1037,6 +1040,9 @@ type Peer struct {
|
||||
// CreatedAt Peer creation date (UTC)
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
|
||||
// DisapprovalReason (Cloud only) Reason why the peer requires approval
|
||||
DisapprovalReason *string `json:"disapproval_reason,omitempty"`
|
||||
|
||||
// DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud
|
||||
DnsLabel string `json:"dns_label"`
|
||||
|
||||
@@ -1124,6 +1130,9 @@ type PeerBatch struct {
|
||||
// CreatedAt Peer creation date (UTC)
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
|
||||
// DisapprovalReason (Cloud only) Reason why the peer requires approval
|
||||
DisapprovalReason *string `json:"disapproval_reason,omitempty"`
|
||||
|
||||
// DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud
|
||||
DnsLabel string `json:"dns_label"`
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -266,6 +266,18 @@ message PeerConfig {
|
||||
bool LazyConnectionEnabled = 6;
|
||||
|
||||
int32 mtu = 7;
|
||||
|
||||
// Auto-update config
|
||||
AutoUpdateSettings autoUpdate = 8;
|
||||
}
|
||||
|
||||
message AutoUpdateSettings {
|
||||
string version = 1;
|
||||
/*
|
||||
alwaysUpdate = true → Updates happen automatically in the background
|
||||
alwaysUpdate = false → Updates only happen when triggered by a peer connection
|
||||
*/
|
||||
bool alwaysUpdate = 2;
|
||||
}
|
||||
|
||||
// NetworkMap represents a network state of the peer with the corresponding configuration parameters to establish peer-to-peer connections
|
||||
|
||||
@@ -94,7 +94,7 @@ var (
|
||||
|
||||
startPprof()
|
||||
|
||||
opts, certManager, err := getTLSConfigurations()
|
||||
opts, certManager, tlsConfig, err := getTLSConfigurations()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -132,7 +132,7 @@ var (
|
||||
|
||||
// Start the main server - always serve HTTP with WebSocket proxy support
|
||||
// If certManager is configured and signalPort == 443, it's already handled by startServerWithCertManager
|
||||
if certManager == nil {
|
||||
if tlsConfig == nil {
|
||||
// Without TLS, serve plain HTTP
|
||||
httpListener, err = net.Listen("tcp", fmt.Sprintf(":%d", signalPort))
|
||||
if err != nil {
|
||||
@@ -140,9 +140,10 @@ var (
|
||||
}
|
||||
log.Infof("running HTTP server with WebSocket proxy (no TLS): %s", httpListener.Addr().String())
|
||||
serveHTTP(httpListener, grpcRootHandler)
|
||||
} else if signalPort != 443 {
|
||||
// With TLS but not on port 443, serve HTTPS
|
||||
httpListener, err = tls.Listen("tcp", fmt.Sprintf(":%d", signalPort), certManager.TLSConfig())
|
||||
} else if certManager == nil || signalPort != 443 {
|
||||
// Serve HTTPS if not already handled by startServerWithCertManager
|
||||
// (custom certificates or Let's Encrypt with custom port)
|
||||
httpListener, err = tls.Listen("tcp", fmt.Sprintf(":%d", signalPort), tlsConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -202,7 +203,7 @@ func startPprof() {
|
||||
}()
|
||||
}
|
||||
|
||||
func getTLSConfigurations() ([]grpc.ServerOption, *autocert.Manager, error) {
|
||||
func getTLSConfigurations() ([]grpc.ServerOption, *autocert.Manager, *tls.Config, error) {
|
||||
var (
|
||||
err error
|
||||
certManager *autocert.Manager
|
||||
@@ -211,33 +212,33 @@ func getTLSConfigurations() ([]grpc.ServerOption, *autocert.Manager, error) {
|
||||
|
||||
if signalLetsencryptDomain == "" && signalCertFile == "" && signalCertKey == "" {
|
||||
log.Infof("running without TLS")
|
||||
return nil, nil, nil
|
||||
return nil, nil, nil, nil
|
||||
}
|
||||
|
||||
if signalLetsencryptDomain != "" {
|
||||
certManager, err = encryption.CreateCertManager(signalSSLDir, signalLetsencryptDomain)
|
||||
if err != nil {
|
||||
return nil, certManager, err
|
||||
return nil, certManager, nil, err
|
||||
}
|
||||
tlsConfig = certManager.TLSConfig()
|
||||
log.Infof("setting up TLS with LetsEncrypt.")
|
||||
} else {
|
||||
if signalCertFile == "" || signalCertKey == "" {
|
||||
log.Errorf("both cert-file and cert-key must be provided when not using LetsEncrypt")
|
||||
return nil, certManager, errors.New("both cert-file and cert-key must be provided when not using LetsEncrypt")
|
||||
return nil, certManager, nil, errors.New("both cert-file and cert-key must be provided when not using LetsEncrypt")
|
||||
}
|
||||
|
||||
tlsConfig, err = loadTLSConfig(signalCertFile, signalCertKey)
|
||||
if err != nil {
|
||||
log.Errorf("cannot load TLS credentials: %v", err)
|
||||
return nil, certManager, err
|
||||
return nil, certManager, nil, err
|
||||
}
|
||||
log.Infof("setting up TLS with custom certificates.")
|
||||
}
|
||||
|
||||
transportCredentials := credentials.NewTLS(tlsConfig)
|
||||
|
||||
return []grpc.ServerOption{grpc.Creds(transportCredentials)}, certManager, err
|
||||
return []grpc.ServerOption{grpc.Creds(transportCredentials)}, certManager, tlsConfig, err
|
||||
}
|
||||
|
||||
func startServerWithCertManager(certManager *autocert.Manager, grpcRootHandler http.Handler) {
|
||||
|
||||
@@ -41,21 +41,28 @@ func NewUpdate(httpAgent string) *Update {
|
||||
currentVersion, _ = goversion.NewVersion("0.0.0")
|
||||
}
|
||||
|
||||
latestAvailable, _ := goversion.NewVersion("0.0.0")
|
||||
|
||||
u := &Update{
|
||||
httpAgent: httpAgent,
|
||||
latestAvailable: latestAvailable,
|
||||
uiVersion: currentVersion,
|
||||
fetchTicker: time.NewTicker(fetchPeriod),
|
||||
fetchDone: make(chan struct{}),
|
||||
httpAgent: httpAgent,
|
||||
uiVersion: currentVersion,
|
||||
fetchDone: make(chan struct{}),
|
||||
}
|
||||
go u.startFetcher()
|
||||
|
||||
return u
|
||||
}
|
||||
|
||||
func NewUpdateAndStart(httpAgent string) *Update {
|
||||
u := NewUpdate(httpAgent)
|
||||
go u.StartFetcher()
|
||||
|
||||
return u
|
||||
}
|
||||
|
||||
// StopWatch stop the version info fetch loop
|
||||
func (u *Update) StopWatch() {
|
||||
if u.fetchTicker == nil {
|
||||
return
|
||||
}
|
||||
|
||||
u.fetchTicker.Stop()
|
||||
|
||||
select {
|
||||
@@ -94,7 +101,18 @@ func (u *Update) SetOnUpdateListener(updateFn func()) {
|
||||
}
|
||||
}
|
||||
|
||||
func (u *Update) startFetcher() {
|
||||
func (u *Update) LatestVersion() *goversion.Version {
|
||||
u.versionsLock.Lock()
|
||||
defer u.versionsLock.Unlock()
|
||||
return u.latestAvailable
|
||||
}
|
||||
|
||||
func (u *Update) StartFetcher() {
|
||||
if u.fetchTicker != nil {
|
||||
return
|
||||
}
|
||||
u.fetchTicker = time.NewTicker(fetchPeriod)
|
||||
|
||||
if changed := u.fetchVersion(); changed {
|
||||
u.checkUpdate()
|
||||
}
|
||||
@@ -181,6 +199,10 @@ func (u *Update) isUpdateAvailable() bool {
|
||||
u.versionsLock.Lock()
|
||||
defer u.versionsLock.Unlock()
|
||||
|
||||
if u.latestAvailable == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if u.latestAvailable.GreaterThan(u.uiVersion) {
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -23,7 +23,7 @@ func TestNewUpdate(t *testing.T) {
|
||||
wg.Add(1)
|
||||
|
||||
onUpdate := false
|
||||
u := NewUpdate(httpAgent)
|
||||
u := NewUpdateAndStart(httpAgent)
|
||||
defer u.StopWatch()
|
||||
u.SetOnUpdateListener(func() {
|
||||
onUpdate = true
|
||||
@@ -48,7 +48,7 @@ func TestDoNotUpdate(t *testing.T) {
|
||||
wg.Add(1)
|
||||
|
||||
onUpdate := false
|
||||
u := NewUpdate(httpAgent)
|
||||
u := NewUpdateAndStart(httpAgent)
|
||||
defer u.StopWatch()
|
||||
u.SetOnUpdateListener(func() {
|
||||
onUpdate = true
|
||||
@@ -73,7 +73,7 @@ func TestDaemonUpdate(t *testing.T) {
|
||||
wg.Add(1)
|
||||
|
||||
onUpdate := false
|
||||
u := NewUpdate(httpAgent)
|
||||
u := NewUpdateAndStart(httpAgent)
|
||||
defer u.StopWatch()
|
||||
u.SetOnUpdateListener(func() {
|
||||
onUpdate = true
|
||||
|
||||
Reference in New Issue
Block a user