Compare commits

..

1 Commits

Author SHA1 Message Date
bcmmbaga
5b344f9b3f test migration
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-11 22:04:00 +03:00
63 changed files with 763 additions and 2541 deletions

View File

@@ -18,14 +18,14 @@ jobs:
runs-on: macos-latest
steps:
- name: Install Go
uses: actions/setup-go@v5
uses: actions/setup-go@v4
with:
go-version: "1.23.x"
go-version: "1.21.x"
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@v3
- name: Cache Go modules
uses: actions/cache@v4
uses: actions/cache@v3
with:
path: ~/go/pkg/mod
key: macos-go-${{ hashFiles('**/go.sum') }}

View File

@@ -19,13 +19,13 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Install Go
uses: actions/setup-go@v5
uses: actions/setup-go@v4
with:
go-version: "1.23.x"
go-version: "1.21.x"
- name: Cache Go modules
uses: actions/cache@v4
uses: actions/cache@v3
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
@@ -33,7 +33,7 @@ jobs:
${{ runner.os }}-go-
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@v3
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
@@ -49,18 +49,18 @@ jobs:
run: git --no-pager diff --exit-code
- name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 6m -p 1 ./...
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 ./...
test_client_on_docker:
runs-on: ubuntu-20.04
steps:
- name: Install Go
uses: actions/setup-go@v5
uses: actions/setup-go@v4
with:
go-version: "1.23.x"
go-version: "1.21.x"
- name: Cache Go modules
uses: actions/cache@v4
uses: actions/cache@v3
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
@@ -68,7 +68,7 @@ jobs:
${{ runner.os }}-go-
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@v3
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
@@ -124,4 +124,4 @@ jobs:
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal -e NETBIRD_STORE_ENGINE="sqlite" --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1
- name: Run Peer tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/peer --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/peer-testing.bin -test.timeout 5m -test.parallel 1
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/peer --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/peer-testing.bin -test.timeout 5m -test.parallel 1

View File

@@ -17,13 +17,13 @@ jobs:
runs-on: windows-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@v3
- name: Install Go
uses: actions/setup-go@v5
uses: actions/setup-go@v4
id: go
with:
go-version: "1.23.x"
go-version: "1.21.x"
- name: Download wintun
uses: carlosperate/download-file-action@v2

View File

@@ -15,7 +15,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@v3
- name: codespell
uses: codespell-project/actions-codespell@v2
with:
@@ -32,15 +32,15 @@ jobs:
timeout-minutes: 15
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@v3
- name: Check for duplicate constants
if: matrix.os == 'ubuntu-latest'
run: |
! awk '/const \(/,/)/{print $0}' management/server/activity/codes.go | grep -o '= [0-9]*' | sort | uniq -d | grep .
- name: Install Go
uses: actions/setup-go@v5
uses: actions/setup-go@v4
with:
go-version: "1.23.x"
go-version: "1.21.x"
cache: false
- name: Install dependencies
if: matrix.os == 'ubuntu-latest'
@@ -49,4 +49,4 @@ jobs:
uses: golangci/golangci-lint-action@v3
with:
version: latest
args: --timeout=12m
args: --timeout=12m

View File

@@ -21,7 +21,7 @@ jobs:
runs-on: ${{ matrix.os }}
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@v3
- name: run install script
env:

View File

@@ -15,23 +15,23 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
uses: actions/checkout@v3
- name: Install Go
uses: actions/setup-go@v5
uses: actions/setup-go@v4
with:
go-version: "1.23.x"
go-version: "1.21.x"
- name: Setup Android SDK
uses: android-actions/setup-android@v3
with:
cmdline-tools-version: 8512546
- name: Setup Java
uses: actions/setup-java@v4
uses: actions/setup-java@v3
with:
java-version: "11"
distribution: "adopt"
- name: NDK Cache
id: ndk-cache
uses: actions/cache@v4
uses: actions/cache@v3
with:
path: /usr/local/lib/android/sdk/ndk
key: ndk-cache-23.1.7779620
@@ -50,11 +50,11 @@ jobs:
runs-on: macos-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
uses: actions/checkout@v3
- name: Install Go
uses: actions/setup-go@v5
uses: actions/setup-go@v4
with:
go-version: "1.23.x"
go-version: "1.21.x"
- name: install gomobile
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20240404231514-09dbf07665ed
- name: gomobile init
@@ -62,4 +62,4 @@ jobs:
- name: build iOS netbird lib
run: PATH=$PATH:$(go env GOPATH) gomobile bind -target=ios -bundleid=io.netbird.framework -ldflags="-X github.com/netbirdio/netbird/version.version=buildtest" -o ./NetBirdSDK.xcframework ./client/ios/NetBirdSDK
env:
CGO_ENABLED: 0
CGO_ENABLED: 0

View File

@@ -36,18 +36,18 @@ jobs:
run: echo "flags=--snapshot" >> $GITHUB_ENV
-
name: Checkout
uses: actions/checkout@v4
uses: actions/checkout@v3
with:
fetch-depth: 0 # It is required for GoReleaser to work properly
-
name: Set up Go
uses: actions/setup-go@v5
uses: actions/setup-go@v4
with:
go-version: "1.23"
go-version: "1.21"
cache: false
-
name: Cache Go modules
uses: actions/cache@v4
uses: actions/cache@v3
with:
path: |
~/go/pkg/mod
@@ -93,28 +93,28 @@ jobs:
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
-
name: upload non tags for debug purposes
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v3
with:
name: release
path: dist/
retention-days: 3
-
name: upload linux packages
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v3
with:
name: linux-packages
path: dist/netbird_linux**
retention-days: 3
-
name: upload windows packages
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v3
with:
name: windows-packages
path: dist/netbird_windows**
retention-days: 3
-
name: upload macos packages
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v3
with:
name: macos-packages
path: dist/netbird_darwin**
@@ -133,17 +133,17 @@ jobs:
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV
- name: Checkout
uses: actions/checkout@v4
uses: actions/checkout@v3
with:
fetch-depth: 0 # It is required for GoReleaser to work properly
- name: Set up Go
uses: actions/setup-go@v5
uses: actions/setup-go@v4
with:
go-version: "1.23"
go-version: "1.21"
cache: false
- name: Cache Go modules
uses: actions/cache@v4
uses: actions/cache@v3
with:
path: |
~/go/pkg/mod
@@ -176,7 +176,7 @@ jobs:
UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
- name: upload non tags for debug purposes
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v3
with:
name: release-ui
path: dist/
@@ -189,18 +189,18 @@ jobs:
run: echo "flags=--snapshot" >> $GITHUB_ENV
-
name: Checkout
uses: actions/checkout@v4
uses: actions/checkout@v3
with:
fetch-depth: 0 # It is required for GoReleaser to work properly
-
name: Set up Go
uses: actions/setup-go@v5
uses: actions/setup-go@v4
with:
go-version: "1.23"
go-version: "1.21"
cache: false
-
name: Cache Go modules
uses: actions/cache@v4
uses: actions/cache@v3
with:
path: |
~/go/pkg/mod
@@ -225,7 +225,7 @@ jobs:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
-
name: upload non tags for debug purposes
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v3
with:
name: release-ui-darwin
path: dist/

View File

@@ -50,12 +50,12 @@ jobs:
run: sudo apt-get install -y curl
- name: Install Go
uses: actions/setup-go@v5
uses: actions/setup-go@v4
with:
go-version: "1.23.x"
go-version: "1.21.x"
- name: Cache Go modules
uses: actions/cache@v4
uses: actions/cache@v3
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
@@ -63,7 +63,7 @@ jobs:
${{ runner.os }}-go-
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@v3
- name: cp setup.env
run: cp infrastructure_files/tests/setup.env infrastructure_files/
@@ -219,7 +219,7 @@ jobs:
run: sudo apt-get install -y jq
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@v3
- name: run script with Zitadel PostgreSQL
run: NETBIRD_DOMAIN=use-ip bash -x infrastructure_files/getting-started-with-zitadel.sh

View File

@@ -117,11 +117,6 @@ type Config struct {
// ReadConfig read config file and return with Config. If it is not exists create a new with default values
func ReadConfig(configPath string) (*Config, error) {
if configFileIsExists(configPath) {
err := util.EnforcePermission(configPath)
if err != nil {
log.Errorf("failed to enforce permission on config dir: %v", err)
}
config := &Config{}
if _, err := util.ReadJson(configPath, config); err != nil {
return nil, err
@@ -164,17 +159,13 @@ func UpdateOrCreateConfig(input ConfigInput) (*Config, error) {
if err != nil {
return nil, err
}
err = util.WriteJsonWithRestrictedPermission(input.ConfigPath, cfg)
err = WriteOutConfig(input.ConfigPath, cfg)
return cfg, err
}
if isPreSharedKeyHidden(input.PreSharedKey) {
input.PreSharedKey = nil
}
err := util.EnforcePermission(input.ConfigPath)
if err != nil {
log.Errorf("failed to enforce permission on config dir: %v", err)
}
return update(input)
}

View File

@@ -158,7 +158,6 @@ func (c *ConnectClient) run(
}
defer c.statusRecorder.ClientStop()
runningChanOpen := true
operation := func() error {
// if context cancelled we not start new backoff cycle
if c.isContextCancelled() {
@@ -268,12 +267,6 @@ func (c *ConnectClient) run(
checks := loginResp.GetChecks()
c.engineMutex.Lock()
if c.engine != nil && c.engine.ctx.Err() != nil {
log.Info("Stopping Netbird Engine")
if err := c.engine.Stop(); err != nil {
log.Errorf("Failed to stop engine: %v", err)
}
}
c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, probes, checks)
c.engineMutex.Unlock()
@@ -286,10 +279,9 @@ func (c *ConnectClient) run(
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
state.Set(StatusConnected)
if runningChan != nil && runningChanOpen {
if runningChan != nil {
runningChan <- nil
close(runningChan)
runningChanOpen = false
}
<-engineCtx.Done()

View File

@@ -1115,7 +1115,10 @@ func (e *Engine) close() {
}
// stop/restore DNS first so dbus and friends don't complain because of a missing interface
e.stopDNSServer()
if e.dnsServer != nil {
e.dnsServer.Stop()
e.dnsServer = nil
}
if e.routeManager != nil {
e.routeManager.Stop()
@@ -1357,16 +1360,12 @@ func (e *Engine) probeTURNs() []relay.ProbeResult {
}
func (e *Engine) restartEngine() {
log.Info("restarting engine")
CtxGetState(e.ctx).Set(StatusConnecting)
if err := e.Stop(); err != nil {
log.Errorf("Failed to stop engine: %v", err)
}
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
log.Infof("cancelling client, engine will be recreated")
e.clientCancel()
if err := e.Start(); err != nil {
log.Errorf("Failed to start engine: %v", err)
}
}
func (e *Engine) startNetworkMonitor() {
@@ -1388,7 +1387,6 @@ func (e *Engine) startNetworkMonitor() {
defer mu.Unlock()
if debounceTimer != nil {
log.Infof("Network monitor: detected network change, reset debounceTimer")
debounceTimer.Stop()
}
@@ -1398,7 +1396,7 @@ func (e *Engine) startNetworkMonitor() {
mu.Lock()
defer mu.Unlock()
log.Infof("Network monitor: detected network change, restarting engine")
log.Infof("Network monitor detected network change, restarting engine")
e.restartEngine()
})
})
@@ -1423,20 +1421,6 @@ func (e *Engine) addrViaRoutes(addr netip.Addr) (bool, netip.Prefix, error) {
return false, netip.Prefix{}, nil
}
func (e *Engine) stopDNSServer() {
err := fmt.Errorf("DNS server stopped")
nsGroupStates := e.statusRecorder.GetDNSStates()
for i := range nsGroupStates {
nsGroupStates[i].Enabled = false
nsGroupStates[i].Error = err
}
e.statusRecorder.UpdateDNSStates(nsGroupStates)
if e.dnsServer != nil {
e.dnsServer.Stop()
e.dnsServer = nil
}
}
// isChecksEqual checks if two slices of checks are equal.
func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool {
return slices.EqualFunc(checks, oChecks, func(checks, oChecks *mgmProto.Checks) bool {

View File

@@ -89,8 +89,8 @@ type Conn struct {
onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)
onDisconnected func(remotePeer string, wgIP string)
statusRelay *AtomicConnStatus
statusICE *AtomicConnStatus
statusRelay ConnStatus
statusICE ConnStatus
currentConnPriority ConnPriority
opened bool // this flag is used to prevent close in case of not opened connection
@@ -131,8 +131,8 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu
signaler: signaler,
relayManager: relayManager,
allowedIPsIP: allowedIPsIP.String(),
statusRelay: NewAtomicConnStatus(),
statusICE: NewAtomicConnStatus(),
statusRelay: StatusDisconnected,
statusICE: StatusDisconnected,
iCEDisconnected: make(chan bool, 1),
relayDisconnected: make(chan bool, 1),
}
@@ -323,11 +323,11 @@ func (conn *Conn) reconnectLoopWithRetry() {
}
if conn.workerRelay.IsRelayConnectionSupportedWithPeer() {
if conn.statusRelay.Get() == StatusDisconnected || conn.statusICE.Get() == StatusDisconnected {
if conn.statusRelay == StatusDisconnected || conn.statusICE == StatusDisconnected {
conn.log.Tracef("connectivity guard timedout, relay state: %s, ice state: %s", conn.statusRelay, conn.statusICE)
}
} else {
if conn.statusICE.Get() == StatusDisconnected {
if conn.statusICE == StatusDisconnected {
conn.log.Tracef("connectivity guard timedout, ice state: %s", conn.statusICE)
}
}
@@ -419,7 +419,7 @@ func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICECon
conn.log.Debugf("ICE connection is ready")
conn.statusICE.Set(StatusConnected)
conn.statusICE = StatusConnected
defer conn.updateIceState(iceConnInfo)
@@ -492,8 +492,8 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) {
conn.currentConnPriority = connPriorityRelay
}
changed := conn.statusICE.Get() != newState && newState != StatusConnecting
conn.statusICE.Set(newState)
changed := conn.statusICE != newState && newState != StatusConnecting
conn.statusICE = newState
select {
case conn.iCEDisconnected <- changed:
@@ -518,14 +518,11 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
defer conn.mu.Unlock()
if conn.ctx.Err() != nil {
if err := rci.relayedConn.Close(); err != nil {
log.Warnf("failed to close unnecessary relayed connection: %v", err)
}
return
}
conn.log.Debugf("Relay connection is ready to use")
conn.statusRelay.Set(StatusConnected)
conn.statusRelay = StatusConnected
wgProxy := conn.wgProxyFactory.GetProxy(conn.ctx)
endpoint, err := wgProxy.AddTurnConn(rci.relayedConn)
@@ -533,7 +530,6 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
conn.log.Errorf("failed to add relayed net.Conn to local proxy: %v", err)
return
}
conn.log.Infof("created new wgProxy for relay connection: %s", endpoint)
endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String())
conn.endpointRelay = endpointUdpAddr
@@ -542,7 +538,7 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
defer conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
if conn.currentConnPriority > connPriorityRelay {
if conn.statusICE.Get() == StatusConnected {
if conn.statusICE == StatusConnected {
log.Debugf("do not switch to relay because current priority is: %v", conn.currentConnPriority)
return
}
@@ -563,8 +559,8 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
conn.log.Errorf("Failed to update wg peer configuration: %v", err)
return
}
conn.workerRelay.EnableWgWatcher(conn.ctx)
wgConfigWorkaround()
conn.workerRelay.EnableWgWatcher(conn.ctx)
if conn.wgProxyRelay != nil {
if err := conn.wgProxyRelay.CloseConn(); err != nil {
@@ -598,8 +594,8 @@ func (conn *Conn) onWorkerRelayStateDisconnected() {
conn.wgProxyRelay = nil
}
changed := conn.statusRelay.Get() != StatusDisconnected
conn.statusRelay.Set(StatusDisconnected)
changed := conn.statusRelay != StatusDisconnected
conn.statusRelay = StatusDisconnected
select {
case conn.relayDisconnected <- changed:
@@ -665,8 +661,8 @@ func (conn *Conn) updateIceState(iceConnInfo ICEConnInfo) {
}
func (conn *Conn) setStatusToDisconnected() {
conn.statusRelay.Set(StatusDisconnected)
conn.statusICE.Set(StatusDisconnected)
conn.statusRelay = StatusDisconnected
conn.statusICE = StatusDisconnected
peerState := State{
PubKey: conn.config.Key,
@@ -710,7 +706,7 @@ func (conn *Conn) waitInitialRandomSleepTime() {
}
func (conn *Conn) isRelayed() bool {
if conn.statusRelay.Get() == StatusDisconnected && (conn.statusICE.Get() == StatusDisconnected || conn.statusICE.Get() == StatusConnecting) {
if conn.statusRelay == StatusDisconnected && (conn.statusICE == StatusDisconnected || conn.statusICE == StatusConnecting) {
return false
}
@@ -722,11 +718,11 @@ func (conn *Conn) isRelayed() bool {
}
func (conn *Conn) evalStatus() ConnStatus {
if conn.statusRelay.Get() == StatusConnected || conn.statusICE.Get() == StatusConnected {
if conn.statusRelay == StatusConnected || conn.statusICE == StatusConnected {
return StatusConnected
}
if conn.statusRelay.Get() == StatusConnecting || conn.statusICE.Get() == StatusConnecting {
if conn.statusRelay == StatusConnecting || conn.statusICE == StatusConnecting {
return StatusConnecting
}
@@ -737,12 +733,12 @@ func (conn *Conn) isConnected() bool {
conn.mu.Lock()
defer conn.mu.Unlock()
if conn.statusICE.Get() != StatusConnected && conn.statusICE.Get() != StatusConnecting {
if conn.statusICE != StatusConnected && conn.statusICE != StatusConnecting {
return false
}
if conn.workerRelay.IsRelayConnectionSupportedWithPeer() {
if conn.statusRelay.Get() != StatusConnected {
if conn.statusRelay != StatusConnected {
return false
}
}
@@ -779,8 +775,9 @@ func (conn *Conn) getEndpointForICEConnInfo(iceConnInfo ICEConnInfo) (net.Addr,
ep, err := wgProxy.AddTurnConn(iceConnInfo.RemoteConn)
if err != nil {
conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err)
if errClose := wgProxy.CloseConn(); errClose != nil {
conn.log.Warnf("failed to close turn proxy connection: %v", errClose)
err = wgProxy.CloseConn()
if err != nil {
conn.log.Warnf("failed to close turn proxy connection: %v", err)
}
return nil, nil, err
}

View File

@@ -1,10 +1,6 @@
package peer
import (
"sync/atomic"
log "github.com/sirupsen/logrus"
)
import log "github.com/sirupsen/logrus"
const (
// StatusConnected indicate the peer is in connected state
@@ -16,34 +12,7 @@ const (
)
// ConnStatus describe the status of a peer's connection
type ConnStatus int32
// AtomicConnStatus is a thread-safe wrapper for ConnStatus
type AtomicConnStatus struct {
status atomic.Int32
}
// NewAtomicConnStatus creates a new AtomicConnStatus with the given initial status
func NewAtomicConnStatus() *AtomicConnStatus {
acs := &AtomicConnStatus{}
acs.Set(StatusDisconnected)
return acs
}
// Get returns the current connection status
func (acs *AtomicConnStatus) Get() ConnStatus {
return ConnStatus(acs.status.Load())
}
// Set updates the connection status
func (acs *AtomicConnStatus) Set(status ConnStatus) {
acs.status.Store(int32(status))
}
// String returns the string representation of the current status
func (acs *AtomicConnStatus) String() string {
return acs.Get().String()
}
type ConnStatus int
func (s ConnStatus) String() string {
switch s {

View File

@@ -158,13 +158,8 @@ func TestConn_Status(t *testing.T) {
for _, table := range tables {
t.Run(table.name, func(t *testing.T) {
si := NewAtomicConnStatus()
si.Set(table.statusIce)
conn.statusICE = si
sr := NewAtomicConnStatus()
sr.Set(table.statusRelay)
conn.statusRelay = sr
conn.statusICE = table.statusIce
conn.statusRelay = table.statusRelay
got := conn.Status()
assert.Equal(t, got, table.want, "they should be equal")

View File

@@ -597,8 +597,6 @@ func (d *Status) DeleteResolvedDomainsStates(domain domain.Domain) {
}
func (d *Status) GetRosenpassState() RosenpassState {
d.mux.Lock()
defer d.mux.Unlock()
return RosenpassState{
d.rosenpassEnabled,
d.rosenpassPermissive,
@@ -606,8 +604,6 @@ func (d *Status) GetRosenpassState() RosenpassState {
}
func (d *Status) GetManagementState() ManagementState {
d.mux.Lock()
defer d.mux.Unlock()
return ManagementState{
d.mgmAddress,
d.managementState,
@@ -649,8 +645,6 @@ func (d *Status) IsLoginRequired() bool {
}
func (d *Status) GetSignalState() SignalState {
d.mux.Lock()
defer d.mux.Unlock()
return SignalState{
d.signalAddress,
d.signalState,
@@ -660,8 +654,6 @@ func (d *Status) GetSignalState() SignalState {
// GetRelayStates returns the stun/turn/permanent relay states
func (d *Status) GetRelayStates() []relay.ProbeResult {
d.mux.Lock()
defer d.mux.Unlock()
if d.relayMgr == nil {
return d.relayStates
}
@@ -692,8 +684,6 @@ func (d *Status) GetRelayStates() []relay.ProbeResult {
}
func (d *Status) GetDNSStates() []NSGroupState {
d.mux.Lock()
defer d.mux.Unlock()
return d.nsGroupStates
}
@@ -705,19 +695,18 @@ func (d *Status) GetResolvedDomainsStates() map[domain.Domain][]netip.Prefix {
// GetFullStatus gets full status
func (d *Status) GetFullStatus() FullStatus {
d.mux.Lock()
defer d.mux.Unlock()
fullStatus := FullStatus{
ManagementState: d.GetManagementState(),
SignalState: d.GetSignalState(),
LocalPeerState: d.localPeer,
Relays: d.GetRelayStates(),
RosenpassState: d.GetRosenpassState(),
NSGroupStates: d.GetDNSStates(),
}
d.mux.Lock()
defer d.mux.Unlock()
fullStatus.LocalPeerState = d.localPeer
for _, status := range d.peers {
fullStatus.Peers = append(fullStatus.Peers, status)
}

View File

@@ -14,7 +14,7 @@ import (
)
var (
wgHandshakePeriod = 3 * time.Minute
wgHandshakePeriod = 2 * time.Minute
wgHandshakeOvertime = 30 * time.Second
)
@@ -109,10 +109,10 @@ func (w *WorkerRelay) EnableWgWatcher(ctx context.Context) {
}
ctx, ctxCancel := context.WithCancel(ctx)
go w.wgStateCheck(ctx)
w.ctxWgWatch = ctx
w.ctxCancelWgWatch = ctxCancel
w.wgStateCheck(ctx, ctxCancel)
}
func (w *WorkerRelay) DisableWgWatcher() {
@@ -157,51 +157,37 @@ func (w *WorkerRelay) CloseConn() {
}
}
// wgStateCheck help to check the state of the WireGuard handshake and relay connection
func (w *WorkerRelay) wgStateCheck(ctx context.Context, ctxCancel context.CancelFunc) {
w.log.Debugf("WireGuard watcher started")
lastHandshake, err := w.wgState()
if err != nil {
w.log.Warnf("failed to read wg stats: %v", err)
lastHandshake = time.Time{}
}
// wgStateCheck help to check the state of the wireguard handshake and relay connection
func (w *WorkerRelay) wgStateCheck(ctx context.Context) {
timer := time.NewTimer(wgHandshakeOvertime)
defer timer.Stop()
expected := wgHandshakeOvertime
for {
select {
case <-timer.C:
lastHandshake, err := w.wgState()
if err != nil {
w.log.Errorf("failed to read wg stats: %v", err)
continue
}
w.log.Tracef("last handshake: %v", lastHandshake)
go func(lastHandshake time.Time) {
timer := time.NewTimer(wgHandshakeOvertime)
defer timer.Stop()
defer ctxCancel()
for {
select {
case <-timer.C:
handshake, err := w.wgState()
if err != nil {
w.log.Errorf("failed to read wg stats: %v", err)
timer.Reset(wgHandshakeOvertime)
continue
}
w.log.Tracef("previous handshake, handshake: %v, %v", lastHandshake, handshake)
if handshake.Equal(lastHandshake) {
w.log.Infof("WireGuard handshake timed out, closing relay connection: %v", handshake)
w.relayLock.Lock()
_ = w.relayedConn.Close()
w.relayLock.Unlock()
w.callBacks.OnDisconnected()
return
}
resetTime := time.Until(handshake.Add(wgHandshakePeriod + wgHandshakeOvertime))
lastHandshake = handshake
timer.Reset(resetTime)
case <-ctx.Done():
w.log.Debugf("WireGuard watcher stopped")
if time.Since(lastHandshake) > expected {
w.log.Infof("Wireguard handshake timed out, closing relay connection")
w.relayLock.Lock()
_ = w.relayedConn.Close()
w.relayLock.Unlock()
w.callBacks.OnDisconnected()
return
}
resetTime := time.Until(lastHandshake.Add(wgHandshakePeriod + wgHandshakeOvertime))
timer.Reset(resetTime)
expected = wgHandshakePeriod
case <-ctx.Done():
w.log.Debugf("WireGuard watcher stopped")
return
}
}(lastHandshake)
}
}
func (w *WorkerRelay) isRelaySupported(answer *OfferAnswer) bool {

View File

@@ -32,8 +32,8 @@ func NewWGUserSpaceProxy(ctx context.Context, wgPort int) *WGUserSpaceProxy {
}
// AddTurnConn start the proxy with the given remote conn
func (p *WGUserSpaceProxy) AddTurnConn(remoteConn net.Conn) (net.Addr, error) {
p.remoteConn = remoteConn
func (p *WGUserSpaceProxy) AddTurnConn(turnConn net.Conn) (net.Addr, error) {
p.remoteConn = turnConn
var err error
p.localConn, err = nbnet.NewDialer().DialContext(p.ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort))
@@ -54,14 +54,6 @@ func (p *WGUserSpaceProxy) CloseConn() error {
if p.localConn == nil {
return nil
}
if p.remoteConn == nil {
return nil
}
if err := p.remoteConn.Close(); err != nil {
log.Warnf("failed to close remote conn: %s", err)
}
return p.localConn.Close()
}
@@ -73,8 +65,6 @@ func (p *WGUserSpaceProxy) Free() error {
// proxyToRemote proxies everything from Wireguard to the RemoteKey peer
// blocks
func (p *WGUserSpaceProxy) proxyToRemote() {
defer log.Infof("exit from proxyToRemote: %s", p.localConn.LocalAddr())
buf := make([]byte, 1500)
for {
select {
@@ -103,8 +93,7 @@ func (p *WGUserSpaceProxy) proxyToRemote() {
// proxyToLocal proxies everything from the RemoteKey peer to local Wireguard
// blocks
func (p *WGUserSpaceProxy) proxyToLocal() {
defer p.cancel()
defer log.Infof("exit from proxyToLocal: %s", p.localConn.LocalAddr())
buf := make([]byte, 1500)
for {
select {
@@ -114,6 +103,7 @@ func (p *WGUserSpaceProxy) proxyToLocal() {
n, err := p.remoteConn.Read(buf)
if err != nil {
if err == io.EOF {
p.cancel()
return
}
log.Errorf("failed to read from remote conn: %s", err)

View File

@@ -1,41 +1,13 @@
package main
import (
"net/http"
_ "net/http/pprof"
"os"
"path/filepath"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/cmd"
)
func main() {
// only if env is not set
if os.Getenv("NB_LOG_LEVEL") == "" {
if err := os.Setenv("NB_LOG_LEVEL", "debug"); err != nil {
log.Errorf("Failed setting log-level: %v", err)
}
}
if err := os.Setenv("NB_LOG_MAX_SIZE_MB", "100"); err != nil {
log.Errorf("Failed setting log-size: %v", err)
}
if err := os.Setenv("NB_WINDOWS_PANIC_LOG", filepath.Join(os.Getenv("ProgramData"), "netbird", "netbird.err")); err != nil {
log.Errorf("Failed setting panic log path: %v", err)
}
go startPprofServer()
if err := cmd.Execute(); err != nil {
os.Exit(1)
}
}
func startPprofServer() {
pprofAddr := "localhost:6969"
log.Infof("Starting pprof debugging server on %s", pprofAddr)
if err := http.ListenAndServe(pprofAddr, nil); err != nil {
log.Infof("pprof server failed: %v", err)
}
}

View File

@@ -110,35 +110,6 @@ func (s *Server) Start() error {
ctx, cancel := context.WithCancel(s.rootCtx)
s.actCancel = cancel
go func() {
ticker := time.NewTicker(30 * time.Second)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if statusResp, err := s.Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true}); err != nil {
log.Infof("Error getting status: %v", err)
} else if statusResp.FullStatus != nil {
log.Infof("Status --------")
for _, peer := range statusResp.FullStatus.Peers {
log.Infof("[Peer Connection] Name: %s, IP: %s, Key: %s, Connection Status: %s, Relayed: %v, RelayedAddress: %v, Last WireGuard Handshake: %v",
peer.Fqdn,
peer.IP,
peer.PubKey,
peer.ConnStatus,
peer.Relayed,
peer.RelayAddress,
peer.LastWireguardHandshake.AsTime().Format("15:04:05"),
)
}
}
}
}
}()
// if configuration exists, we just start connections. if is new config we skip and set status NeedsLogin
// on failure we return error to retry
config, err := internal.UpdateConfig(s.latestConfigInput)

2
go.mod
View File

@@ -1,6 +1,6 @@
module github.com/netbirdio/netbird
go 1.23.0
go 1.21.0
require (
cunicu.li/go-rosenpass v0.4.0

View File

@@ -54,7 +54,7 @@ func Execute() error {
func init() {
stopCh = make(chan int)
mgmtCmd.Flags().IntVar(&mgmtPort, "port", 80, "server port to listen on (defaults to 443 if TLS is enabled, 80 otherwise")
mgmtCmd.Flags().IntVar(&mgmtMetricsPort, "metrics-port", 9090, "metrics endpoint http port. Metrics are accessible under host:metrics-port/metrics")
mgmtCmd.Flags().IntVar(&mgmtMetricsPort, "metrics-port", 8081, "metrics endpoint http port. Metrics are accessible under host:metrics-port/metrics")
mgmtCmd.Flags().StringVar(&mgmtDataDir, "datadir", defaultMgmtDataDir, "server data directory location")
mgmtCmd.Flags().StringVar(&mgmtConfig, "config", defaultMgmtConfig, "Netbird config file location. Config params specified via command line (e.g. datadir) have a precedence over configuration from this file")
mgmtCmd.Flags().StringVar(&mgmtLetsencryptDomain, "letsencrypt-domain", "", "a domain to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS")

View File

@@ -263,11 +263,6 @@ type AccountSettings struct {
Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"`
}
// Subclass used in gorm to only load network and not whole account
type AccountNetwork struct {
Network *Network `gorm:"embedded;embeddedPrefix:network_"`
}
type UserPermissions struct {
DashboardView string `json:"dashboard_view"`
}
@@ -705,6 +700,14 @@ func (a *Account) GetPeerGroupsList(peerID string) []string {
return grps
}
func (a *Account) getUserGroups(userID string) ([]string, error) {
user, err := a.FindUser(userID)
if err != nil {
return nil, err
}
return user.AutoGroups, nil
}
func (a *Account) getPeerDNSManagementStatus(peerID string) bool {
peerGroups := a.getPeerGroups(peerID)
enabled := true
@@ -731,6 +734,14 @@ func (a *Account) getPeerGroups(peerID string) lookupMap {
return groupList
}
func (a *Account) getSetupKeyGroups(setupKey string) ([]string, error) {
key, err := a.FindSetupKey(setupKey)
if err != nil {
return nil, err
}
return key.AutoGroups, nil
}
func (a *Account) getTakenIPs() []net.IP {
var takenIps []net.IP
for _, existingPeer := range a.Peers {
@@ -2071,7 +2082,7 @@ func (am *DefaultAccountManager) GetAccountIDForPeerKey(ctx context.Context, pee
}
func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, peer *nbpeer.Peer, settings *Settings) (bool, error) {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, peer.UserID)
user, err := am.Store.GetUserByUserID(ctx, peer.UserID)
if err != nil {
return false, err
}
@@ -2092,25 +2103,6 @@ func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, peer *nbpee
return false, nil
}
func (am *DefaultAccountManager) getFreeDNSLabel(ctx context.Context, store Store, accountID string, peerHostName string) (string, error) {
existingLabels, err := store.GetPeerLabelsInAccount(ctx, LockingStrengthShare, accountID)
if err != nil {
return "", fmt.Errorf("failed to get peer dns labels: %w", err)
}
labelMap := ConvertSliceToMap(existingLabels)
newLabel, err := getPeerHostLabel(peerHostName, labelMap)
if err != nil {
return "", fmt.Errorf("failed to get new host label: %w", err)
}
if newLabel == "" {
return "", fmt.Errorf("failed to get new host label: %w", err)
}
return newLabel, nil
}
// addAllGroup to account object if it doesn't exist
func addAllGroup(account *Account) error {
if len(account.Groups) == 0 {

View File

@@ -0,0 +1,23 @@
package main
import (
"context"
_ "github.com/mattn/go-sqlite3"
"github.com/netbirdio/netbird/management/server/activity/sqlite"
log "github.com/sirupsen/logrus"
)
func main() {
encryptionKey := "<enc_key>"
eventsDBBase := "management/server/activity/cmd/events.db"
store, err := sqlite.NewSQLiteStore(context.Background(), eventsDBBase, encryptionKey)
if err != nil {
log.Fatalf("failed to create sqlite store: %v", err)
}
if err = store.GetLegacyEvents(); err != nil {
log.Fatalf("failed to get legacy events: %v", err)
}
}

View File

@@ -7,6 +7,7 @@ import (
"crypto/rand"
"encoding/base64"
"errors"
"fmt"
)
var iv = []byte{10, 22, 13, 79, 05, 8, 52, 91, 87, 98, 88, 98, 35, 25, 13, 05}
@@ -114,23 +115,12 @@ func pkcs5Padding(ciphertext []byte) []byte {
padText := bytes.Repeat([]byte{byte(padding)}, padding)
return append(ciphertext, padText...)
}
func pkcs5UnPadding(src []byte) ([]byte, error) {
srcLen := len(src)
if srcLen == 0 {
return nil, errors.New("input data is empty")
}
paddingLen := int(src[srcLen-1])
if paddingLen == 0 || paddingLen > aes.BlockSize || paddingLen > srcLen {
return nil, errors.New("invalid padding size")
if paddingLen >= srcLen || paddingLen > aes.BlockSize {
return nil, fmt.Errorf("padding size error")
}
// Verify that all padding bytes are the same
for i := 0; i < paddingLen; i++ {
if src[srcLen-1-i] != byte(paddingLen) {
return nil, errors.New("invalid padding")
}
}
return src[:srcLen-paddingLen], nil
}

View File

@@ -1,7 +1,6 @@
package sqlite
import (
"bytes"
"testing"
)
@@ -96,215 +95,3 @@ func TestCorruptKey(t *testing.T) {
t.Fatalf("incorrect decryption, the result is: %s", res)
}
}
func TestEncryptDecrypt(t *testing.T) {
// Generate a key for encryption/decryption
key, err := GenerateKey()
if err != nil {
t.Fatalf("Failed to generate key: %v", err)
}
// Initialize the FieldEncrypt with the generated key
ec, err := NewFieldEncrypt(key)
if err != nil {
t.Fatalf("Failed to create FieldEncrypt: %v", err)
}
// Test cases
testCases := []struct {
name string
input string
}{
{
name: "Empty String",
input: "",
},
{
name: "Short String",
input: "Hello",
},
{
name: "String with Spaces",
input: "Hello, World!",
},
{
name: "Long String",
input: "The quick brown fox jumps over the lazy dog.",
},
{
name: "Unicode Characters",
input: "こんにちは世界",
},
{
name: "Special Characters",
input: "!@#$%^&*()_+-=[]{}|;':\",./<>?",
},
{
name: "Numeric String",
input: "1234567890",
},
{
name: "Repeated Characters",
input: "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
},
{
name: "Multi-block String",
input: "This is a longer string that will span multiple blocks in the encryption algorithm.",
},
{
name: "Non-ASCII and ASCII Mix",
input: "Hello 世界 123",
},
}
for _, tc := range testCases {
t.Run(tc.name+" - Legacy", func(t *testing.T) {
// Legacy Encryption
encryptedLegacy := ec.LegacyEncrypt(tc.input)
if encryptedLegacy == "" {
t.Errorf("LegacyEncrypt returned empty string for input '%s'", tc.input)
}
// Legacy Decryption
decryptedLegacy, err := ec.LegacyDecrypt(encryptedLegacy)
if err != nil {
t.Errorf("LegacyDecrypt failed for input '%s': %v", tc.input, err)
}
// Verify that the decrypted value matches the original input
if decryptedLegacy != tc.input {
t.Errorf("LegacyDecrypt output '%s' does not match original input '%s'", decryptedLegacy, tc.input)
}
})
t.Run(tc.name+" - New", func(t *testing.T) {
// New Encryption
encryptedNew, err := ec.Encrypt(tc.input)
if err != nil {
t.Errorf("Encrypt failed for input '%s': %v", tc.input, err)
}
if encryptedNew == "" {
t.Errorf("Encrypt returned empty string for input '%s'", tc.input)
}
// New Decryption
decryptedNew, err := ec.Decrypt(encryptedNew)
if err != nil {
t.Errorf("Decrypt failed for input '%s': %v", tc.input, err)
}
// Verify that the decrypted value matches the original input
if decryptedNew != tc.input {
t.Errorf("Decrypt output '%s' does not match original input '%s'", decryptedNew, tc.input)
}
})
}
}
func TestPKCS5UnPadding(t *testing.T) {
tests := []struct {
name string
input []byte
expected []byte
expectError bool
}{
{
name: "Valid Padding",
input: append([]byte("Hello, World!"), bytes.Repeat([]byte{4}, 4)...),
expected: []byte("Hello, World!"),
},
{
name: "Empty Input",
input: []byte{},
expectError: true,
},
{
name: "Padding Length Zero",
input: append([]byte("Hello, World!"), bytes.Repeat([]byte{0}, 4)...),
expectError: true,
},
{
name: "Padding Length Exceeds Block Size",
input: append([]byte("Hello, World!"), bytes.Repeat([]byte{17}, 17)...),
expectError: true,
},
{
name: "Padding Length Exceeds Input Length",
input: []byte{5, 5, 5},
expectError: true,
},
{
name: "Invalid Padding Bytes",
input: append([]byte("Hello, World!"), []byte{2, 3, 4, 5}...),
expectError: true,
},
{
name: "Valid Single Byte Padding",
input: append([]byte("Hello, World!"), byte(1)),
expected: []byte("Hello, World!"),
},
{
name: "Invalid Mixed Padding Bytes",
input: append([]byte("Hello, World!"), []byte{3, 3, 2}...),
expectError: true,
},
{
name: "Valid Full Block Padding",
input: append([]byte("Hello, World!"), bytes.Repeat([]byte{16}, 16)...),
expected: []byte("Hello, World!"),
},
{
name: "Non-Padding Byte at End",
input: append([]byte("Hello, World!"), []byte{4, 4, 4, 5}...),
expectError: true,
},
{
name: "Valid Padding with Different Text Length",
input: append([]byte("Test"), bytes.Repeat([]byte{12}, 12)...),
expected: []byte("Test"),
},
{
name: "Padding Length Equal to Input Length",
input: bytes.Repeat([]byte{8}, 8),
expected: []byte{},
},
{
name: "Invalid Padding Length Zero (Again)",
input: append([]byte("Test"), byte(0)),
expectError: true,
},
{
name: "Padding Length Greater Than Input",
input: []byte{10},
expectError: true,
},
{
name: "Input Length Not Multiple of Block Size",
input: append([]byte("Invalid Length"), byte(1)),
expected: []byte("Invalid Length"),
},
{
name: "Valid Padding with Non-ASCII Characters",
input: append([]byte("こんにちは"), bytes.Repeat([]byte{2}, 2)...),
expected: []byte("こんにちは"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := pkcs5UnPadding(tt.input)
if tt.expectError {
if err == nil {
t.Errorf("Expected error but got nil")
}
} else {
if err != nil {
t.Errorf("Did not expect error but got: %v", err)
}
if !bytes.Equal(result, tt.expected) {
t.Errorf("Expected output %v, got %v", tt.expected, result)
}
}
})
}
}

View File

@@ -0,0 +1,78 @@
package sqlite
import (
"database/sql"
"fmt"
log "github.com/sirupsen/logrus"
)
func (store *Store) GetLegacyEvents() error {
rows, err := store.db.Query(`SELECT id, email, name FROM deleted_users`)
if err != nil {
return fmt.Errorf("failed to execute select query: %v", err)
}
defer rows.Close()
if err = processLegacyEvents(store.fieldEncrypt, rows); err != nil {
return err
}
return nil
}
// processUserRows processes database rows of user data, decrypts legacy encryption fields, and re-encrypts them using GCM.
func processLegacyEvents(crypt *FieldEncrypt, rows *sql.Rows) error {
var (
successCount int
failureCount int
)
for rows.Next() {
var (
id string
email, name *string
)
err := rows.Scan(&id, &email, &name)
if err != nil {
return err
}
if email != nil {
_, err = crypt.LegacyDecrypt(*email)
if err != nil {
log.Warnf("failed to decrypt email for user %s: %v",
id,
fmt.Errorf("failed to decrypt email: %w", err),
)
failureCount++
continue
}
}
if name != nil {
_, err = crypt.LegacyDecrypt(*name)
if err != nil {
log.Warnf("failed to decrypt name for user %s: %v",
id,
fmt.Errorf("failed to decrypt name: %w", err),
)
failureCount++
continue
}
}
successCount++
}
if err := rows.Err(); err != nil {
return err
}
log.Infof("Successfully decoded entries: %d", successCount)
log.Infof("Failed decoded entries: %d", failureCount)
return nil
}

View File

@@ -5,7 +5,7 @@ import (
"database/sql"
"encoding/json"
"fmt"
"path/filepath"
"os"
"time"
_ "github.com/mattn/go-sqlite3"
@@ -26,7 +26,7 @@ const (
"meta TEXT," +
" target_id TEXT);"
creatTableDeletedUsersQuery = `CREATE TABLE IF NOT EXISTS deleted_users (id TEXT NOT NULL, email TEXT NOT NULL, name TEXT, enc_algo TEXT NOT NULL);`
creatTableDeletedUsersQuery = `CREATE TABLE IF NOT EXISTS deleted_users (id TEXT NOT NULL, email TEXT NOT NULL, name TEXT);`
selectDescQuery = `SELECT events.id, activity, timestamp, initiator_id, i.name as "initiator_name", i.email as "initiator_email", target_id, t.name as "target_name", t.email as "target_email", account_id, meta
FROM events
@@ -69,7 +69,7 @@ const (
and some selfhosted deployments might have duplicates already so we need to clean the table first.
*/
insertDeleteUserQuery = `INSERT INTO deleted_users(id, email, name, enc_algo) VALUES(?, ?, ?, ?)`
insertDeleteUserQuery = `INSERT INTO deleted_users(id, email, name) VALUES(?, ?, ?)`
fallbackName = "unknown"
fallbackEmail = "unknown@unknown.com"
@@ -89,9 +89,15 @@ type Store struct {
}
// NewSQLiteStore creates a new Store with an event table if not exists.
func NewSQLiteStore(ctx context.Context, dataDir string, encryptionKey string) (*Store, error) {
dbFile := filepath.Join(dataDir, eventSinkDB)
db, err := sql.Open("sqlite3", dbFile)
func NewSQLiteStore(ctx context.Context, dbPath string, encryptionKey string) (*Store, error) {
//dbFile := filepath.Join(dataDir, eventSinkDB)
stats, err := os.Stat(dbPath)
if err != nil {
return nil, err
}
_ = stats
db, err := sql.Open("sqlite3", dbPath)
if err != nil {
return nil, err
}
@@ -102,10 +108,10 @@ func NewSQLiteStore(ctx context.Context, dataDir string, encryptionKey string) (
return nil, err
}
if err = migrate(ctx, crypt, db); err != nil {
_ = db.Close()
return nil, fmt.Errorf("events database migration: %w", err)
}
//if err = migrate(ctx, crypt, db); err != nil {
// _ = db.Close()
// return nil, fmt.Errorf("events database migration: %w", err)
//}
return createStore(crypt, db)
}

View File

@@ -7,7 +7,6 @@ import (
"time"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status"
)
type MockStore struct {
@@ -25,7 +24,7 @@ func (s *MockStore) GetAccountByPeerID(_ context.Context, peerId string) (*Accou
return s.account, nil
}
return nil, status.NewPeerNotFoundError(peerId)
return nil, fmt.Errorf("account not found")
}
type MocAccountManager struct {

View File

@@ -2,8 +2,6 @@ package server
import (
"context"
"errors"
"net"
"os"
"path/filepath"
"strings"
@@ -48,158 +46,6 @@ type FileStore struct {
metrics telemetry.AppMetrics `json:"-"`
}
func (s *FileStore) ExecuteInTransaction(ctx context.Context, f func(store Store) error) error {
return f(s)
}
func (s *FileStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error {
s.mux.Lock()
defer s.mux.Unlock()
accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKeyID)]
if !ok {
return status.NewSetupKeyNotFoundError()
}
account, err := s.getAccount(accountID)
if err != nil {
return err
}
account.SetupKeys[setupKeyID].UsedTimes++
return s.SaveAccount(ctx, account)
}
func (s *FileStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error {
s.mux.Lock()
defer s.mux.Unlock()
account, err := s.getAccount(accountID)
if err != nil {
return err
}
allGroup, err := account.GetGroupAll()
if err != nil || allGroup == nil {
return errors.New("all group not found")
}
allGroup.Peers = append(allGroup.Peers, peerID)
return nil
}
func (s *FileStore) AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error {
s.mux.Lock()
defer s.mux.Unlock()
account, err := s.getAccount(accountId)
if err != nil {
return err
}
account.Groups[groupID].Peers = append(account.Groups[groupID].Peers, peerId)
return nil
}
func (s *FileStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error {
s.mux.Lock()
defer s.mux.Unlock()
account, ok := s.Accounts[peer.AccountID]
if !ok {
return status.NewAccountNotFoundError(peer.AccountID)
}
account.Peers[peer.ID] = peer
return s.SaveAccount(ctx, account)
}
func (s *FileStore) IncrementNetworkSerial(ctx context.Context, accountId string) error {
s.mux.Lock()
defer s.mux.Unlock()
account, ok := s.Accounts[accountId]
if !ok {
return status.NewAccountNotFoundError(accountId)
}
account.Network.Serial++
return s.SaveAccount(ctx, account)
}
func (s *FileStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) {
s.mux.Lock()
defer s.mux.Unlock()
accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(key)]
if !ok {
return nil, status.NewSetupKeyNotFoundError()
}
account, err := s.getAccount(accountID)
if err != nil {
return nil, err
}
setupKey, ok := account.SetupKeys[key]
if !ok {
return nil, status.Errorf(status.NotFound, "setup key not found")
}
return setupKey, nil
}
func (s *FileStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountID string) ([]net.IP, error) {
s.mux.Lock()
defer s.mux.Unlock()
account, err := s.getAccount(accountID)
if err != nil {
return nil, err
}
var takenIps []net.IP
for _, existingPeer := range account.Peers {
takenIps = append(takenIps, existingPeer.IP)
}
return takenIps, nil
}
func (s *FileStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error) {
s.mux.Lock()
defer s.mux.Unlock()
account, err := s.getAccount(accountID)
if err != nil {
return nil, err
}
existingLabels := []string{}
for _, peer := range account.Peers {
if peer.DNSLabel != "" {
existingLabels = append(existingLabels, peer.DNSLabel)
}
}
return existingLabels, nil
}
func (s *FileStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*Network, error) {
s.mux.Lock()
defer s.mux.Unlock()
account, err := s.getAccount(accountID)
if err != nil {
return nil, err
}
return account.Network, nil
}
type StoredAccount struct{}
// NewFileStore restores a store from the file located in the datadir
@@ -576,7 +422,7 @@ func (s *FileStore) GetAccountBySetupKey(_ context.Context, setupKey string) (*A
accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKey)]
if !ok {
return nil, status.NewSetupKeyNotFoundError()
return nil, status.Errorf(status.NotFound, "account not found: provided setup key doesn't exists")
}
account, err := s.getAccount(accountID)
@@ -623,7 +469,7 @@ func (s *FileStore) GetUserByTokenID(_ context.Context, tokenID string) (*User,
return account.Users[userID].Copy(), nil
}
func (s *FileStore) GetUserByUserID(_ context.Context, _ LockingStrength, userID string) (*User, error) {
func (s *FileStore) GetUserByUserID(_ context.Context, userID string) (*User, error) {
accountID, ok := s.UserID2AccountID[userID]
if !ok {
return nil, status.Errorf(status.NotFound, "accountID not found: provided userID doesn't exists")
@@ -667,7 +513,7 @@ func (s *FileStore) GetAllAccounts(_ context.Context) (all []*Account) {
func (s *FileStore) getAccount(accountID string) (*Account, error) {
account, ok := s.Accounts[accountID]
if !ok {
return nil, status.NewAccountNotFoundError(accountID)
return nil, status.Errorf(status.NotFound, "account not found")
}
return account, nil
@@ -793,13 +639,13 @@ func (s *FileStore) GetAccountIDBySetupKey(_ context.Context, setupKey string) (
accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKey)]
if !ok {
return "", status.NewSetupKeyNotFoundError()
return "", status.Errorf(status.NotFound, "account not found: provided setup key doesn't exists")
}
return accountID, nil
}
func (s *FileStore) GetPeerByPeerPubKey(_ context.Context, _ LockingStrength, peerKey string) (*nbpeer.Peer, error) {
func (s *FileStore) GetPeerByPeerPubKey(_ context.Context, peerKey string) (*nbpeer.Peer, error) {
s.mux.Lock()
defer s.mux.Unlock()
@@ -822,7 +668,7 @@ func (s *FileStore) GetPeerByPeerPubKey(_ context.Context, _ LockingStrength, pe
return nil, status.NewPeerNotFoundError(peerKey)
}
func (s *FileStore) GetAccountSettings(_ context.Context, _ LockingStrength, accountID string) (*Settings, error) {
func (s *FileStore) GetAccountSettings(_ context.Context, accountID string) (*Settings, error) {
s.mux.Lock()
defer s.mux.Unlock()
@@ -912,7 +758,7 @@ func (s *FileStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.
}
// SaveUserLastLogin stores the last login time for a user in memory. It doesn't attempt to persist data to speed up things.
func (s *FileStore) SaveUserLastLogin(_ context.Context, accountID, userID string, lastLogin time.Time) error {
func (s *FileStore) SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error {
s.mux.Lock()
defer s.mux.Unlock()

View File

@@ -251,7 +251,7 @@ components:
- name
- ssh_enabled
- login_expiration_enabled
Peer:
PeerBase:
allOf:
- $ref: '#/components/schemas/PeerMinimum'
- type: object
@@ -378,40 +378,25 @@ components:
description: User ID of the user that enrolled this peer
type: string
example: google-oauth2|277474792786460067937
os:
description: Peer's operating system and version
type: string
example: linux
country_code:
$ref: '#/components/schemas/CountryCode'
city_name:
$ref: '#/components/schemas/CityName'
geoname_id:
description: Unique identifier from the GeoNames database for a specific geographical location.
type: integer
example: 2643743
connected:
description: Peer to Management connection status
type: boolean
example: true
last_seen:
description: Last time peer connected to Netbird's management service
type: string
format: date-time
example: "2023-05-05T10:05:26.420578Z"
required:
- ip
- dns_label
- user_id
- os
- country_code
- city_name
- geoname_id
- connected
- last_seen
Peer:
allOf:
- $ref: '#/components/schemas/PeerBase'
- type: object
properties:
accessible_peers:
description: List of accessible peers
type: array
items:
$ref: '#/components/schemas/AccessiblePeer'
required:
- accessible_peers
PeerBatch:
allOf:
- $ref: '#/components/schemas/Peer'
- $ref: '#/components/schemas/PeerBase'
- type: object
properties:
accessible_peers_count:
@@ -1821,38 +1806,6 @@ paths:
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
/api/peers/{peerId}/accessible-peers:
get:
summary: List accessible Peers
description: Returns a list of peers that the specified peer can connect to within the network.
tags: [ Peers ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
parameters:
- in: path
name: peerId
required: true
schema:
type: string
description: The unique identifier of a peer
responses:
'200':
description: A JSON Array of Accessible Peers
content:
application/json:
schema:
type: array
items:
$ref: '#/components/schemas/AccessiblePeer'
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
/api/setup-keys:
get:
summary: List all Setup Keys

View File

@@ -152,36 +152,18 @@ const (
// AccessiblePeer defines model for AccessiblePeer.
type AccessiblePeer struct {
// CityName Commonly used English name of the city
CityName CityName `json:"city_name"`
// Connected Peer to Management connection status
Connected bool `json:"connected"`
// CountryCode 2-letter ISO 3166-1 alpha-2 code that represents the country
CountryCode CountryCode `json:"country_code"`
// DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud
DnsLabel string `json:"dns_label"`
// GeonameId Unique identifier from the GeoNames database for a specific geographical location.
GeonameId int `json:"geoname_id"`
// Id Peer ID
Id string `json:"id"`
// Ip Peer's IP address
Ip string `json:"ip"`
// LastSeen Last time peer connected to Netbird's management service
LastSeen time.Time `json:"last_seen"`
// Name Peer's hostname
Name string `json:"name"`
// Os Peer's operating system and version
Os string `json:"os"`
// UserId User ID of the user that enrolled this peer
UserId string `json:"user_id"`
}
@@ -508,6 +490,81 @@ type OSVersionCheck struct {
// Peer defines model for Peer.
type Peer struct {
// AccessiblePeers List of accessible peers
AccessiblePeers []AccessiblePeer `json:"accessible_peers"`
// ApprovalRequired (Cloud only) Indicates whether peer needs approval
ApprovalRequired bool `json:"approval_required"`
// CityName Commonly used English name of the city
CityName CityName `json:"city_name"`
// Connected Peer to Management connection status
Connected bool `json:"connected"`
// ConnectionIp Peer's public connection IP address
ConnectionIp string `json:"connection_ip"`
// CountryCode 2-letter ISO 3166-1 alpha-2 code that represents the country
CountryCode CountryCode `json:"country_code"`
// DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud
DnsLabel string `json:"dns_label"`
// GeonameId Unique identifier from the GeoNames database for a specific geographical location.
GeonameId int `json:"geoname_id"`
// Groups Groups that the peer belongs to
Groups []GroupMinimum `json:"groups"`
// Hostname Hostname of the machine
Hostname string `json:"hostname"`
// Id Peer ID
Id string `json:"id"`
// Ip Peer's IP address
Ip string `json:"ip"`
// KernelVersion Peer's operating system kernel version
KernelVersion string `json:"kernel_version"`
// LastLogin Last time this peer performed log in (authentication). E.g., user authenticated.
LastLogin time.Time `json:"last_login"`
// LastSeen Last time peer connected to Netbird's management service
LastSeen time.Time `json:"last_seen"`
// LoginExpirationEnabled Indicates whether peer login expiration has been enabled or not
LoginExpirationEnabled bool `json:"login_expiration_enabled"`
// LoginExpired Indicates whether peer's login expired or not
LoginExpired bool `json:"login_expired"`
// Name Peer's hostname
Name string `json:"name"`
// Os Peer's operating system and version
Os string `json:"os"`
// SerialNumber System serial number
SerialNumber string `json:"serial_number"`
// SshEnabled Indicates whether SSH server is enabled on this peer
SshEnabled bool `json:"ssh_enabled"`
// UiVersion Peer's desktop UI version
UiVersion string `json:"ui_version"`
// UserId User ID of the user that enrolled this peer
UserId string `json:"user_id"`
// Version Peer's daemon or cli version
Version string `json:"version"`
}
// PeerBase defines model for PeerBase.
type PeerBase struct {
// ApprovalRequired (Cloud only) Indicates whether peer needs approval
ApprovalRequired bool `json:"approval_required"`

View File

@@ -115,7 +115,6 @@ func (apiHandler *apiHandler) addPeersEndpoint() {
apiHandler.Router.HandleFunc("/peers", peersHandler.GetAllPeers).Methods("GET", "OPTIONS")
apiHandler.Router.HandleFunc("/peers/{peerId}", peersHandler.HandlePeer).
Methods("GET", "PUT", "DELETE", "OPTIONS")
apiHandler.Router.HandleFunc("/peers/{peerId}/accessible-peers", peersHandler.GetAccessiblePeers).Methods("GET", "OPTIONS")
}
func (apiHandler *apiHandler) addUsersEndpoint() {

View File

@@ -71,8 +71,12 @@ func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, pee
return
}
customZone := account.GetPeersCustomZone(ctx, h.accountManager.GetDNSDomain())
netMap := account.GetPeerNetworkMap(ctx, peerID, customZone, validPeers, nil)
accessiblePeers := toAccessiblePeers(netMap, dnsDomain)
_, valid := validPeers[peer.ID]
util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, valid))
util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, accessiblePeers, valid))
}
func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, user *server.User, peerID string, w http.ResponseWriter, r *http.Request) {
@@ -113,9 +117,13 @@ func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account,
return
}
customZone := account.GetPeersCustomZone(ctx, h.accountManager.GetDNSDomain())
netMap := account.GetPeerNetworkMap(ctx, peerID, customZone, validPeers, nil)
accessiblePeers := toAccessiblePeers(netMap, dnsDomain)
_, valid := validPeers[peer.ID]
util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, groupMinimumInfo, dnsDomain, valid))
util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, groupMinimumInfo, dnsDomain, accessiblePeers, valid))
}
func (h *PeersHandler) deletePeer(ctx context.Context, accountID, userID string, peerID string, w http.ResponseWriter) {
@@ -212,66 +220,32 @@ func (h *PeersHandler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approv
}
}
// GetAccessiblePeers returns a list of all peers that the specified peer can connect to within the network.
func (h *PeersHandler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, _, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
vars := mux.Vars(r)
peerID := vars["peerId"]
if len(peerID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid peer ID"), w)
return
}
dnsDomain := h.accountManager.GetDNSDomain()
validPeers, err := h.accountManager.GetValidatedPeers(account)
if err != nil {
log.WithContext(r.Context()).Errorf("failed to list approved peers: %v", err)
util.WriteError(r.Context(), fmt.Errorf("internal error"), w)
return
}
customZone := account.GetPeersCustomZone(r.Context(), h.accountManager.GetDNSDomain())
netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, nil)
util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain))
}
func toAccessiblePeers(netMap *server.NetworkMap, dnsDomain string) []api.AccessiblePeer {
accessiblePeers := make([]api.AccessiblePeer, 0, len(netMap.Peers)+len(netMap.OfflinePeers))
for _, p := range netMap.Peers {
accessiblePeers = append(accessiblePeers, peerToAccessiblePeer(p, dnsDomain))
ap := api.AccessiblePeer{
Id: p.ID,
Name: p.Name,
Ip: p.IP.String(),
DnsLabel: fqdn(p, dnsDomain),
UserId: p.UserID,
}
accessiblePeers = append(accessiblePeers, ap)
}
for _, p := range netMap.OfflinePeers {
accessiblePeers = append(accessiblePeers, peerToAccessiblePeer(p, dnsDomain))
ap := api.AccessiblePeer{
Id: p.ID,
Name: p.Name,
Ip: p.IP.String(),
DnsLabel: fqdn(p, dnsDomain),
UserId: p.UserID,
}
accessiblePeers = append(accessiblePeers, ap)
}
return accessiblePeers
}
func peerToAccessiblePeer(peer *nbpeer.Peer, dnsDomain string) api.AccessiblePeer {
return api.AccessiblePeer{
CityName: peer.Location.CityName,
Connected: peer.Status.Connected,
CountryCode: peer.Location.CountryCode,
DnsLabel: fqdn(peer, dnsDomain),
GeonameId: int(peer.Location.GeoNameID),
Id: peer.ID,
Ip: peer.IP.String(),
LastSeen: peer.Status.LastSeen,
Name: peer.Name,
Os: peer.Meta.OS,
UserId: peer.UserID,
}
}
func toGroupsInfo(groups map[string]*nbgroup.Group, peerID string) []api.GroupMinimum {
var groupsInfo []api.GroupMinimum
groupsChecked := make(map[string]struct{})
@@ -296,7 +270,7 @@ func toGroupsInfo(groups map[string]*nbgroup.Group, peerID string) []api.GroupMi
return groupsInfo
}
func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, approved bool) *api.Peer {
func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, accessiblePeer []api.AccessiblePeer, approved bool) *api.Peer {
osVersion := peer.Meta.OSVersion
if osVersion == "" {
osVersion = peer.Meta.Core
@@ -322,6 +296,7 @@ func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsD
LoginExpirationEnabled: peer.LoginExpirationEnabled,
LastLogin: peer.LastLogin,
LoginExpired: peer.Status.LoginExpired,
AccessiblePeers: accessiblePeer,
ApprovalRequired: !approved,
CountryCode: peer.Location.CountryCode,
CityName: peer.Location.CityName,

View File

@@ -627,7 +627,7 @@ func testSyncStatusRace(t *testing.T) {
}
time.Sleep(10 * time.Millisecond)
peer, err := am.Store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, peerWithInvalidStatus.PublicKey().String())
peer, err := am.Store.GetPeerByPeerPubKey(context.Background(), peerWithInvalidStatus.PublicKey().String())
if err != nil {
t.Fatal(err)
return
@@ -638,8 +638,8 @@ func testSyncStatusRace(t *testing.T) {
}
func Test_LoginPerformance(t *testing.T) {
if os.Getenv("CI") == "true" || runtime.GOOS == "windows" {
t.Skip("Skipping test on CI or Windows")
if os.Getenv("CI") == "true" {
t.Skip("Skipping on CI")
}
t.Setenv("NETBIRD_STORE_ENGINE", "sqlite")
@@ -655,7 +655,7 @@ func Test_LoginPerformance(t *testing.T) {
// {"M", 250, 1},
// {"L", 500, 1},
// {"XL", 750, 1},
{"XXL", 5000, 1},
{"XXL", 2000, 1},
}
log.SetOutput(io.Discard)
@@ -700,18 +700,15 @@ func Test_LoginPerformance(t *testing.T) {
}
defer mgmtServer.GracefulStop()
t.Logf("management setup complete, start registering peers")
var counter int32
var counterStart int32
var wgAccount sync.WaitGroup
var wg sync.WaitGroup
var mu sync.Mutex
messageCalls := []func() error{}
for j := 0; j < bc.accounts; j++ {
wgAccount.Add(1)
var wgPeer sync.WaitGroup
wg.Add(1)
go func(j int, counter *int32, counterStart *int32) {
defer wgAccount.Done()
defer wg.Done()
account, err := createAccount(am, fmt.Sprintf("account-%d", j), fmt.Sprintf("user-%d", j), fmt.Sprintf("domain-%d", j))
if err != nil {
@@ -725,9 +722,7 @@ func Test_LoginPerformance(t *testing.T) {
return
}
startTime := time.Now()
for i := 0; i < bc.peers; i++ {
wgPeer.Add(1)
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
t.Logf("failed to generate key: %v", err)
@@ -768,29 +763,21 @@ func Test_LoginPerformance(t *testing.T) {
mu.Lock()
messageCalls = append(messageCalls, login)
mu.Unlock()
_, _, _, err = am.LoginPeer(context.Background(), peerLogin)
if err != nil {
t.Logf("failed to login peer: %v", err)
return
}
go func(peerLogin PeerLogin, counterStart *int32) {
defer wgPeer.Done()
_, _, _, err = am.LoginPeer(context.Background(), peerLogin)
if err != nil {
t.Logf("failed to login peer: %v", err)
return
}
atomic.AddInt32(counterStart, 1)
if *counterStart%100 == 0 {
t.Logf("registered %d peers", *counterStart)
}
}(peerLogin, counterStart)
atomic.AddInt32(counterStart, 1)
if *counterStart%100 == 0 {
t.Logf("registered %d peers", *counterStart)
}
}
wgPeer.Wait()
t.Logf("Time for registration: %s", time.Since(startTime))
}(j, &counter, &counterStart)
}
wgAccount.Wait()
wg.Wait()
t.Logf("prepared %d login calls", len(messageCalls))
testLoginPerformance(t, messageCalls)

View File

@@ -11,7 +11,6 @@ import (
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/proto"
@@ -372,175 +371,164 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
}
}()
var account *Account
// ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account)
account, err = am.Store.GetAccount(ctx, accountID)
if err != nil {
return nil, nil, nil, err
}
if strings.ToLower(peer.Meta.Hostname) == "iphone" || strings.ToLower(peer.Meta.Hostname) == "ipad" && userID != "" {
if am.idpManager != nil {
userdata, err := am.lookupUserInCache(ctx, userID, account)
if err == nil && userdata != nil {
peer.Meta.Hostname = fmt.Sprintf("%s-%s", peer.Meta.Hostname, strings.Split(userdata.Email, "@")[0])
}
}
}
// This is a handling for the case when the same machine (with the same WireGuard pub key) tries to register twice.
// Such case is possible when AddPeer function takes long time to finish after AcquireWriteLockByUID (e.g., database is slow)
// and the peer disconnects with a timeout and tries to register again.
// We just check if this machine has been registered before and reject the second registration.
// The connecting peer should be able to recover with a retry.
_, err = am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthShare, peer.Key)
_, err = account.FindPeerByPubKey(peer.Key)
if err == nil {
return nil, nil, nil, status.Errorf(status.PreconditionFailed, "peer has been already registered")
}
opEvent := &activity.Event{
Timestamp: time.Now().UTC(),
AccountID: accountID,
AccountID: account.Id,
}
var newPeer *nbpeer.Peer
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
var groupsToAdd []string
var setupKeyID string
var setupKeyName string
var ephemeral bool
if addedByUser {
user, err := transaction.GetUserByUserID(ctx, LockingStrengthUpdate, userID)
if err != nil {
return fmt.Errorf("failed to get user groups: %w", err)
}
groupsToAdd = user.AutoGroups
opEvent.InitiatorID = userID
opEvent.Activity = activity.PeerAddedByUser
} else {
// Validate the setup key
sk, err := transaction.GetSetupKeyBySecret(ctx, LockingStrengthUpdate, upperKey)
if err != nil {
return fmt.Errorf("failed to get setup key: %w", err)
}
if !sk.IsValid() {
return status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key is invalid")
}
opEvent.InitiatorID = sk.Id
opEvent.Activity = activity.PeerAddedWithSetupKey
groupsToAdd = sk.AutoGroups
ephemeral = sk.Ephemeral
setupKeyID = sk.Id
setupKeyName = sk.Name
}
if strings.ToLower(peer.Meta.Hostname) == "iphone" || strings.ToLower(peer.Meta.Hostname) == "ipad" && userID != "" {
if am.idpManager != nil {
userdata, err := am.idpManager.GetUserDataByID(ctx, userID, idp.AppMetadata{WTAccountID: accountID})
if err == nil && userdata != nil {
peer.Meta.Hostname = fmt.Sprintf("%s-%s", peer.Meta.Hostname, strings.Split(userdata.Email, "@")[0])
}
}
}
freeLabel, err := am.getFreeDNSLabel(ctx, transaction, accountID, peer.Meta.Hostname)
var ephemeral bool
setupKeyName := ""
if !addedByUser {
// validate the setup key if adding with a key
sk, err := account.FindSetupKey(upperKey)
if err != nil {
return fmt.Errorf("failed to get free DNS label: %w", err)
return nil, nil, nil, err
}
freeIP, err := am.getFreeIP(ctx, transaction, accountID)
if err != nil {
return fmt.Errorf("failed to get free IP: %w", err)
if !sk.IsValid() {
return nil, nil, nil, status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key is invalid")
}
registrationTime := time.Now().UTC()
newPeer = &nbpeer.Peer{
ID: xid.New().String(),
AccountID: accountID,
Key: peer.Key,
SetupKey: upperKey,
IP: freeIP,
Meta: peer.Meta,
Name: peer.Meta.Hostname,
DNSLabel: freeLabel,
UserID: userID,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: registrationTime},
SSHEnabled: false,
SSHKey: peer.SSHKey,
LastLogin: registrationTime,
CreatedAt: registrationTime,
LoginExpirationEnabled: addedByUser,
Ephemeral: ephemeral,
Location: peer.Location,
}
opEvent.TargetID = newPeer.ID
opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain())
if !addedByUser {
opEvent.Meta["setup_key_name"] = setupKeyName
}
account.SetupKeys[sk.Key] = sk.IncrementUsage()
opEvent.InitiatorID = sk.Id
opEvent.Activity = activity.PeerAddedWithSetupKey
ephemeral = sk.Ephemeral
setupKeyName = sk.Name
} else {
opEvent.InitiatorID = userID
opEvent.Activity = activity.PeerAddedByUser
}
if am.geo != nil && newPeer.Location.ConnectionIP != nil {
location, err := am.geo.Lookup(newPeer.Location.ConnectionIP)
if err != nil {
log.WithContext(ctx).Warnf("failed to get location for new peer realip: [%s]: %v", newPeer.Location.ConnectionIP.String(), err)
} else {
newPeer.Location.CountryCode = location.Country.ISOCode
newPeer.Location.CityName = location.City.Names.En
newPeer.Location.GeoNameID = location.City.GeonameID
}
}
settings, err := transaction.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return fmt.Errorf("failed to get account settings: %w", err)
}
newPeer = am.integratedPeerValidator.PreparePeer(ctx, accountID, newPeer, groupsToAdd, settings.Extra)
err = transaction.AddPeerToAllGroup(ctx, accountID, newPeer.ID)
if err != nil {
return fmt.Errorf("failed adding peer to All group: %w", err)
}
if len(groupsToAdd) > 0 {
for _, g := range groupsToAdd {
err = transaction.AddPeerToGroup(ctx, accountID, newPeer.ID, g)
if err != nil {
return err
}
}
}
err = transaction.AddPeerToAccount(ctx, newPeer)
if err != nil {
return fmt.Errorf("failed to add peer to account: %w", err)
}
err = transaction.IncrementNetworkSerial(ctx, accountID)
if err != nil {
return fmt.Errorf("failed to increment network serial: %w", err)
}
if addedByUser {
err := transaction.SaveUserLastLogin(ctx, accountID, userID, newPeer.LastLogin)
if err != nil {
return fmt.Errorf("failed to update user last login: %w", err)
}
} else {
err = transaction.IncrementSetupKeyUsage(ctx, setupKeyID)
if err != nil {
return fmt.Errorf("failed to increment setup key usage: %w", err)
}
}
log.WithContext(ctx).Debugf("Peer %s added to account %s", newPeer.ID, accountID)
return nil
})
takenIps := account.getTakenIPs()
existingLabels := account.getPeerDNSLabels()
newLabel, err := getPeerHostLabel(peer.Meta.Hostname, existingLabels)
if err != nil {
return nil, nil, nil, fmt.Errorf("failed to add peer to database: %w", err)
return nil, nil, nil, err
}
if newPeer == nil {
return nil, nil, nil, fmt.Errorf("new peer is nil")
peer.DNSLabel = newLabel
network := account.Network
nextIp, err := AllocatePeerIP(network.Net, takenIps)
if err != nil {
return nil, nil, nil, err
}
am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta)
registrationTime := time.Now().UTC()
newPeer := &nbpeer.Peer{
ID: xid.New().String(),
Key: peer.Key,
SetupKey: upperKey,
IP: nextIp,
Meta: peer.Meta,
Name: peer.Meta.Hostname,
DNSLabel: newLabel,
UserID: userID,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: registrationTime},
SSHEnabled: false,
SSHKey: peer.SSHKey,
LastLogin: registrationTime,
CreatedAt: registrationTime,
LoginExpirationEnabled: addedByUser,
Ephemeral: ephemeral,
Location: peer.Location,
}
if am.geo != nil && newPeer.Location.ConnectionIP != nil {
location, err := am.geo.Lookup(newPeer.Location.ConnectionIP)
if err != nil {
log.WithContext(ctx).Warnf("failed to get location for new peer realip: [%s]: %v", newPeer.Location.ConnectionIP.String(), err)
} else {
newPeer.Location.CountryCode = location.Country.ISOCode
newPeer.Location.CityName = location.City.Names.En
newPeer.Location.GeoNameID = location.City.GeonameID
}
}
// add peer to 'All' group
group, err := account.GetGroupAll()
if err != nil {
return nil, nil, nil, err
}
group.Peers = append(group.Peers, newPeer.ID)
var groupsToAdd []string
if addedByUser {
groupsToAdd, err = account.getUserGroups(userID)
if err != nil {
return nil, nil, nil, err
}
} else {
groupsToAdd, err = account.getSetupKeyGroups(upperKey)
if err != nil {
return nil, nil, nil, err
}
}
if len(groupsToAdd) > 0 {
for _, s := range groupsToAdd {
if g, ok := account.Groups[s]; ok && g.Name != "All" {
g.Peers = append(g.Peers, newPeer.ID)
}
}
}
newPeer = am.integratedPeerValidator.PreparePeer(ctx, account.Id, newPeer, account.GetPeerGroupsList(newPeer.ID), account.Settings.Extra)
if addedByUser {
user, err := account.FindUser(userID)
if err != nil {
return nil, nil, nil, status.Errorf(status.Internal, "couldn't find user")
}
user.updateLastLogin(newPeer.LastLogin)
}
account.Peers[newPeer.ID] = newPeer
account.Network.IncSerial()
err = am.Store.SaveAccount(ctx, account)
if err != nil {
return nil, nil, nil, err
}
// Account is saved, we can release the lock
unlock()
unlock = nil
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return nil, nil, nil, fmt.Errorf("error getting account: %w", err)
opEvent.TargetID = newPeer.ID
opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain())
if !addedByUser {
opEvent.Meta["setup_key_name"] = setupKeyName
}
am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta)
am.updateAccountPeers(ctx, account)
approvedPeersMap, err := am.GetValidatedPeers(account)
@@ -548,31 +536,12 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
return nil, nil, nil, err
}
postureChecks := am.getPeerPostureChecks(account, newPeer)
postureChecks := am.getPeerPostureChecks(account, peer)
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
networkMap := account.GetPeerNetworkMap(ctx, newPeer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics())
return newPeer, networkMap, postureChecks, nil
}
func (am *DefaultAccountManager) getFreeIP(ctx context.Context, store Store, accountID string) (net.IP, error) {
takenIps, err := store.GetTakenIPs(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, fmt.Errorf("failed to get taken IPs: %w", err)
}
network, err := store.GetAccountNetwork(ctx, LockingStrengthUpdate, accountID)
if err != nil {
return nil, fmt.Errorf("failed getting network: %w", err)
}
nextIp, err := AllocatePeerIP(network.Net, takenIps)
if err != nil {
return nil, fmt.Errorf("failed to allocate new peer ip: %w", err)
}
return nextIp, nil
}
// SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible
func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
peer, err := account.FindPeerByPubKey(sync.WireGuardPubKey)
@@ -678,12 +647,12 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
}
}()
peer, err := am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthUpdate, login.WireGuardPubKey)
peer, err := am.Store.GetPeerByPeerPubKey(ctx, login.WireGuardPubKey)
if err != nil {
return nil, nil, nil, err
}
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
settings, err := am.Store.GetAccountSettings(ctx, accountID)
if err != nil {
return nil, nil, nil, err
}
@@ -761,7 +730,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
// with no JWT token and usually no setup-key. As the client can send up to two login request to check if it is expired
// and before starting the engine, we do the checks without an account lock to avoid piling up requests.
func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Context, accountID string, login PeerLogin) error {
peer, err := am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthShare, login.WireGuardPubKey)
peer, err := am.Store.GetPeerByPeerPubKey(ctx, login.WireGuardPubKey)
if err != nil {
return err
}
@@ -772,7 +741,7 @@ func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Co
return nil
}
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
settings, err := am.Store.GetAccountSettings(ctx, accountID)
if err != nil {
return err
}
@@ -817,7 +786,7 @@ func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, user *Us
return err
}
err = am.Store.SaveUserLastLogin(ctx, user.AccountID, user.Id, peer.LastLogin)
err = am.Store.SaveUserLastLogin(user.AccountID, user.Id, peer.LastLogin)
if err != nil {
return err
}
@@ -1000,11 +969,3 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account
wg.Wait()
}
func ConvertSliceToMap(existingLabels []string) map[string]struct{} {
labelMap := make(map[string]struct{}, len(existingLabels))
for _, label := range existingLabels {
labelMap[label] = struct{}{}
}
return labelMap
}

View File

@@ -7,24 +7,20 @@ import (
"net"
"net/netip"
"os"
"runtime"
"testing"
"time"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server/activity"
nbgroup "github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/telemetry"
nbroute "github.com/netbirdio/netbird/route"
)
@@ -999,184 +995,3 @@ func TestToSyncResponse(t *testing.T) {
assert.Equal(t, 1, len(response.Checks))
assert.Equal(t, "/usr/bin/netbird", response.Checks[0].Files[0])
}
func Test_RegisterPeerByUser(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
eventStore := &activity.InMemoryEventStore{}
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
assert.NoError(t, err)
am, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics)
assert.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
existingUserID := "edafee4e-63fb-11ec-90d6-0242ac120003"
_, err = store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
newPeer := &nbpeer.Peer{
ID: xid.New().String(),
AccountID: existingAccountID,
Key: "newPeerKey",
SetupKey: "",
IP: net.IP{123, 123, 123, 123},
Meta: nbpeer.PeerSystemMeta{
Hostname: "newPeer",
GoOS: "linux",
},
Name: "newPeerName",
DNSLabel: "newPeer.test",
UserID: existingUserID,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()},
SSHEnabled: false,
LastLogin: time.Now(),
}
addedPeer, _, _, err := am.AddPeer(context.Background(), "", existingUserID, newPeer)
require.NoError(t, err)
peer, err := store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, addedPeer.Key)
require.NoError(t, err)
assert.Equal(t, peer.AccountID, existingAccountID)
assert.Equal(t, peer.UserID, existingUserID)
account, err := store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
assert.Contains(t, account.Peers, addedPeer.ID)
assert.Equal(t, peer.Meta.Hostname, newPeer.Meta.Hostname)
assert.Contains(t, account.Groups["cfefqs706sqkneg59g3g"].Peers, addedPeer.ID)
assert.Contains(t, account.Groups["cfefqs706sqkneg59g4g"].Peers, addedPeer.ID)
assert.Equal(t, uint64(1), account.Network.Serial)
lastLogin, err := time.Parse("2006-01-02T15:04:05Z", "0001-01-01T00:00:00Z")
assert.NoError(t, err)
assert.NotEqual(t, lastLogin, account.Users[existingUserID].LastLogin)
}
func Test_RegisterPeerBySetupKey(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
eventStore := &activity.InMemoryEventStore{}
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
assert.NoError(t, err)
am, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics)
assert.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
existingSetupKeyID := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"
_, err = store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
newPeer := &nbpeer.Peer{
ID: xid.New().String(),
AccountID: existingAccountID,
Key: "newPeerKey",
SetupKey: "existingSetupKey",
UserID: "",
IP: net.IP{123, 123, 123, 123},
Meta: nbpeer.PeerSystemMeta{
Hostname: "newPeer",
GoOS: "linux",
},
Name: "newPeerName",
DNSLabel: "newPeer.test",
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()},
SSHEnabled: false,
}
addedPeer, _, _, err := am.AddPeer(context.Background(), existingSetupKeyID, "", newPeer)
require.NoError(t, err)
peer, err := store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, newPeer.Key)
require.NoError(t, err)
assert.Equal(t, peer.AccountID, existingAccountID)
assert.Equal(t, peer.SetupKey, existingSetupKeyID)
account, err := store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
assert.Contains(t, account.Peers, addedPeer.ID)
assert.Contains(t, account.Groups["cfefqs706sqkneg59g2g"].Peers, addedPeer.ID)
assert.Contains(t, account.Groups["cfefqs706sqkneg59g4g"].Peers, addedPeer.ID)
assert.Equal(t, uint64(1), account.Network.Serial)
lastUsed, err := time.Parse("2006-01-02T15:04:05Z", "0001-01-01T00:00:00Z")
assert.NoError(t, err)
assert.NotEqual(t, lastUsed, account.SetupKeys[existingSetupKeyID].LastUsed)
assert.Equal(t, 1, account.SetupKeys[existingSetupKeyID].UsedTimes)
}
func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
eventStore := &activity.InMemoryEventStore{}
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
assert.NoError(t, err)
am, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics)
assert.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
faultyKey := "A2C8E62B-38F5-4553-B31E-DD66C696CEBC"
_, err = store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
newPeer := &nbpeer.Peer{
ID: xid.New().String(),
AccountID: existingAccountID,
Key: "newPeerKey",
SetupKey: "existingSetupKey",
UserID: "",
IP: net.IP{123, 123, 123, 123},
Meta: nbpeer.PeerSystemMeta{
Hostname: "newPeer",
GoOS: "linux",
},
Name: "newPeerName",
DNSLabel: "newPeer.test",
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()},
SSHEnabled: false,
}
_, _, _, err = am.AddPeer(context.Background(), faultyKey, "", newPeer)
require.Error(t, err)
_, err = store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, newPeer.Key)
require.Error(t, err)
account, err := store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
assert.NotContains(t, account.Peers, newPeer.ID)
assert.NotContains(t, account.Groups["cfefqs706sqkneg59g3g"].Peers, newPeer.ID)
assert.NotContains(t, account.Groups["cfefqs706sqkneg59g4g"].Peers, newPeer.ID)
assert.Equal(t, uint64(0), account.Network.Serial)
lastUsed, err := time.Parse("2006-01-02T15:04:05Z", "0001-01-01T00:00:00Z")
assert.NoError(t, err)
assert.Equal(t, lastUsed, account.SetupKeys[faultyKey].LastUsed)
assert.Equal(t, 0, account.SetupKeys[faultyKey].UsedTimes)
}

View File

@@ -5,7 +5,6 @@ import (
"encoding/json"
"errors"
"fmt"
"net"
"os"
"path/filepath"
"runtime"
@@ -34,7 +33,6 @@ import (
const (
storeSqliteFileName = "store.db"
idQueryCondition = "id = ?"
keyQueryCondition = "key = ?"
accountAndIDQueryCondition = "account_id = ? and id = ?"
peerNotFoundFMT = "peer %s not found"
)
@@ -417,12 +415,13 @@ func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string)
func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) {
var key SetupKey
result := s.db.WithContext(ctx).Select("account_id").First(&key, keyQueryCondition, strings.ToUpper(setupKey))
result := s.db.Select("account_id").First(&key, "key = ?", strings.ToUpper(setupKey))
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
return nil, status.NewSetupKeyNotFoundError()
log.WithContext(ctx).Errorf("error when getting setup key from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting setup key from store")
}
if key.AccountID == "" {
@@ -475,15 +474,15 @@ func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User,
return &user, nil
}
func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) {
func (s *SqlStore) GetUserByUserID(ctx context.Context, userID string) (*User, error) {
var user User
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&user, idQueryCondition, userID)
result := s.db.First(&user, idQueryCondition, userID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.NewUserNotFoundError(userID)
return nil, status.Errorf(status.NotFound, "user not found: index lookup failed")
}
return nil, status.NewGetUserFromStoreError()
log.WithContext(ctx).Errorf("error when getting user from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting user from store")
}
return &user, nil
@@ -536,7 +535,7 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account,
if result.Error != nil {
log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error)
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.NewAccountNotFoundError(accountID)
return nil, status.Errorf(status.NotFound, "account not found")
}
return nil, status.Errorf(status.Internal, "issue getting account from store")
}
@@ -596,7 +595,7 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account,
func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Account, error) {
var user User
result := s.db.WithContext(ctx).Select("account_id").First(&user, idQueryCondition, userID)
result := s.db.Select("account_id").First(&user, idQueryCondition, userID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
@@ -613,11 +612,12 @@ func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Accoun
func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error) {
var peer nbpeer.Peer
result := s.db.WithContext(ctx).Select("account_id").First(&peer, idQueryCondition, peerID)
result := s.db.Select("account_id").First(&peer, idQueryCondition, peerID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
log.WithContext(ctx).Errorf("error when getting peer from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting account from store")
}
@@ -631,11 +631,12 @@ func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Acco
func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error) {
var peer nbpeer.Peer
result := s.db.WithContext(ctx).Select("account_id").First(&peer, keyQueryCondition, peerKey)
result := s.db.Select("account_id").First(&peer, "key = ?", peerKey)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
log.WithContext(ctx).Errorf("error when getting peer from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting account from store")
}
@@ -649,11 +650,12 @@ func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (
func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error) {
var peer nbpeer.Peer
var accountID string
result := s.db.WithContext(ctx).Model(&peer).Select("account_id").Where(keyQueryCondition, peerKey).First(&accountID)
result := s.db.Model(&peer).Select("account_id").Where("key = ?", peerKey).First(&accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
}
log.WithContext(ctx).Errorf("error when getting peer from the store: %s", result.Error)
return "", status.Errorf(status.Internal, "issue getting account from store")
}
@@ -675,117 +677,61 @@ func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) {
}
func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string) (string, error) {
var key SetupKey
var accountID string
result := s.db.WithContext(ctx).Model(&SetupKey{}).Select("account_id").Where(keyQueryCondition, strings.ToUpper(setupKey)).First(&accountID)
result := s.db.Model(&key).Select("account_id").Where("key = ?", strings.ToUpper(setupKey)).First(&accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
}
return "", status.NewSetupKeyNotFoundError()
}
if accountID == "" {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
log.WithContext(ctx).Errorf("error when getting setup key from the store: %s", result.Error)
return "", status.Errorf(status.Internal, "issue getting setup key from store")
}
return accountID, nil
}
func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountID string) ([]net.IP, error) {
var ipJSONStrings []string
// Fetch the IP addresses as JSON strings
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}).
Where("account_id = ?", accountID).
Pluck("ip", &ipJSONStrings)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "no peers found for the account")
}
return nil, status.Errorf(status.Internal, "issue getting IPs from store")
}
// Convert the JSON strings to net.IP objects
ips := make([]net.IP, len(ipJSONStrings))
for i, ipJSON := range ipJSONStrings {
var ip net.IP
if err := json.Unmarshal([]byte(ipJSON), &ip); err != nil {
return nil, status.Errorf(status.Internal, "issue parsing IP JSON from store")
}
ips[i] = ip
}
return ips, nil
}
func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error) {
var labels []string
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}).
Where("account_id = ?", accountID).
Pluck("dns_label", &labels)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "no peers found for the account")
}
log.WithContext(ctx).Errorf("error when getting dns labels from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting dns labels from store")
}
return labels, nil
}
func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*Network, error) {
var accountNetwork AccountNetwork
if err := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, accountID).First(&accountNetwork).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.NewAccountNotFoundError(accountID)
}
return nil, status.Errorf(status.Internal, "issue getting network from store")
}
return accountNetwork.Network, nil
}
func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) {
func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, peerKey string) (*nbpeer.Peer, error) {
var peer nbpeer.Peer
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).First(&peer, keyQueryCondition, peerKey)
result := s.db.First(&peer, "key = ?", peerKey)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "peer not found")
}
log.WithContext(ctx).Errorf("error when getting peer from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting peer from store")
}
return &peer, nil
}
func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error) {
func (s *SqlStore) GetAccountSettings(ctx context.Context, accountID string) (*Settings, error) {
var accountSettings AccountSettings
if err := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Where(idQueryCondition, accountID).First(&accountSettings).Error; err != nil {
if err := s.db.Model(&Account{}).Where(idQueryCondition, accountID).First(&accountSettings).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "settings not found")
}
log.WithContext(ctx).Errorf("error when getting settings from the store: %s", err)
return nil, status.Errorf(status.Internal, "issue getting settings from store")
}
return accountSettings.Settings, nil
}
// SaveUserLastLogin stores the last login time for a user in DB.
func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error {
func (s *SqlStore) SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error {
var user User
result := s.db.WithContext(ctx).First(&user, accountAndIDQueryCondition, accountID, userID)
result := s.db.First(&user, accountAndIDQueryCondition, accountID, userID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.NewUserNotFoundError(userID)
return status.Errorf(status.NotFound, "user %s not found", userID)
}
return status.NewGetUserFromStoreError()
return status.Errorf(status.Internal, "issue getting user from store")
}
user.LastLogin = lastLogin
return s.db.Save(&user).Error
return s.db.Save(user).Error
}
func (s *SqlStore) GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) {
@@ -904,123 +850,3 @@ func NewPostgresqlStoreFromFileStore(ctx context.Context, fileStore *FileStore,
return store, nil
}
func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) {
var setupKey SetupKey
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&setupKey, keyQueryCondition, strings.ToUpper(key))
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "setup key not found")
}
return nil, status.NewSetupKeyNotFoundError()
}
return &setupKey, nil
}
func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error {
result := s.db.WithContext(ctx).Model(&SetupKey{}).
Where(idQueryCondition, setupKeyID).
Updates(map[string]interface{}{
"used_times": gorm.Expr("used_times + 1"),
"last_used": time.Now(),
})
if result.Error != nil {
return status.Errorf(status.Internal, "issue incrementing setup key usage count: %s", result.Error)
}
if result.RowsAffected == 0 {
return status.Errorf(status.NotFound, "setup key not found")
}
return nil
}
func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error {
var group nbgroup.Group
result := s.db.WithContext(ctx).Where("account_id = ? AND name = ?", accountID, "All").First(&group)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.Errorf(status.NotFound, "group 'All' not found for account")
}
return status.Errorf(status.Internal, "issue finding group 'All'")
}
for _, existingPeerID := range group.Peers {
if existingPeerID == peerID {
return nil
}
}
group.Peers = append(group.Peers, peerID)
if err := s.db.Save(&group).Error; err != nil {
return status.Errorf(status.Internal, "issue updating group 'All'")
}
return nil
}
func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error {
var group nbgroup.Group
result := s.db.WithContext(ctx).Where(accountAndIDQueryCondition, accountId, groupID).First(&group)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.Errorf(status.NotFound, "group not found for account")
}
return status.Errorf(status.Internal, "issue finding group")
}
for _, existingPeerID := range group.Peers {
if existingPeerID == peerId {
return nil
}
}
group.Peers = append(group.Peers, peerId)
if err := s.db.Save(&group).Error; err != nil {
return status.Errorf(status.Internal, "issue updating group")
}
return nil
}
func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error {
if err := s.db.WithContext(ctx).Create(peer).Error; err != nil {
return status.Errorf(status.Internal, "issue adding peer to account")
}
return nil
}
func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error {
result := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
if result.Error != nil {
return status.Errorf(status.Internal, "issue incrementing network serial count")
}
return nil
}
func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(store Store) error) error {
tx := s.db.WithContext(ctx).Begin()
if tx.Error != nil {
return tx.Error
}
repo := s.withTx(tx)
err := operation(repo)
if err != nil {
tx.Rollback()
return err
}
return tx.Commit().Error
}
func (s *SqlStore) withTx(tx *gorm.DB) Store {
return &SqlStore{
db: tx,
}
}

View File

@@ -1003,163 +1003,3 @@ func TestPostgresql_GetUserByTokenID(t *testing.T) {
require.NoError(t, err)
require.Equal(t, id, user.PATs[id].ID)
}
func TestSqlite_GetTakenIPs(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
defer store.Close(context.Background())
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
_, err := store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
takenIPs, err := store.GetTakenIPs(context.Background(), LockingStrengthShare, existingAccountID)
require.NoError(t, err)
assert.Equal(t, []net.IP{}, takenIPs)
peer1 := &nbpeer.Peer{
ID: "peer1",
AccountID: existingAccountID,
IP: net.IP{1, 1, 1, 1},
}
err = store.AddPeerToAccount(context.Background(), peer1)
require.NoError(t, err)
takenIPs, err = store.GetTakenIPs(context.Background(), LockingStrengthShare, existingAccountID)
require.NoError(t, err)
ip1 := net.IP{1, 1, 1, 1}.To16()
assert.Equal(t, []net.IP{ip1}, takenIPs)
peer2 := &nbpeer.Peer{
ID: "peer2",
AccountID: existingAccountID,
IP: net.IP{2, 2, 2, 2},
}
err = store.AddPeerToAccount(context.Background(), peer2)
require.NoError(t, err)
takenIPs, err = store.GetTakenIPs(context.Background(), LockingStrengthShare, existingAccountID)
require.NoError(t, err)
ip2 := net.IP{2, 2, 2, 2}.To16()
assert.Equal(t, []net.IP{ip1, ip2}, takenIPs)
}
func TestSqlite_GetPeerLabelsInAccount(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
defer store.Close(context.Background())
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
_, err := store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
labels, err := store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID)
require.NoError(t, err)
assert.Equal(t, []string{}, labels)
peer1 := &nbpeer.Peer{
ID: "peer1",
AccountID: existingAccountID,
DNSLabel: "peer1.domain.test",
}
err = store.AddPeerToAccount(context.Background(), peer1)
require.NoError(t, err)
labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID)
require.NoError(t, err)
assert.Equal(t, []string{"peer1.domain.test"}, labels)
peer2 := &nbpeer.Peer{
ID: "peer2",
AccountID: existingAccountID,
DNSLabel: "peer2.domain.test",
}
err = store.AddPeerToAccount(context.Background(), peer2)
require.NoError(t, err)
labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID)
require.NoError(t, err)
assert.Equal(t, []string{"peer1.domain.test", "peer2.domain.test"}, labels)
}
func TestSqlite_GetAccountNetwork(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
defer store.Close(context.Background())
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
_, err := store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
network, err := store.GetAccountNetwork(context.Background(), LockingStrengthShare, existingAccountID)
require.NoError(t, err)
ip := net.IP{100, 64, 0, 0}.To16()
assert.Equal(t, ip, network.Net.IP)
assert.Equal(t, net.IPMask{255, 255, 0, 0}, network.Net.Mask)
assert.Equal(t, "", network.Dns)
assert.Equal(t, "af1c8024-ha40-4ce2-9418-34653101fc3c", network.Identifier)
assert.Equal(t, uint64(0), network.Serial)
}
func TestSqlite_GetSetupKeyBySecret(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
defer store.Close(context.Background())
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
_, err := store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
setupKey, err := store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB")
require.NoError(t, err)
assert.Equal(t, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB", setupKey.Key)
assert.Equal(t, "bf1c8084-ba50-4ce7-9439-34653001fc3b", setupKey.AccountID)
assert.Equal(t, "Default key", setupKey.Name)
}
func TestSqlite_incrementSetupKeyUsage(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
defer store.Close(context.Background())
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
_, err := store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
setupKey, err := store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB")
require.NoError(t, err)
assert.Equal(t, 0, setupKey.UsedTimes)
err = store.IncrementSetupKeyUsage(context.Background(), setupKey.Id)
require.NoError(t, err)
setupKey, err = store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB")
require.NoError(t, err)
assert.Equal(t, 1, setupKey.UsedTimes)
err = store.IncrementSetupKeyUsage(context.Background(), setupKey.Id)
require.NoError(t, err)
setupKey, err = store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB")
require.NoError(t, err)
assert.Equal(t, 2, setupKey.UsedTimes)
}

View File

@@ -100,13 +100,3 @@ func NewPeerNotRegisteredError() error {
func NewPeerLoginExpiredError() error {
return Errorf(PermissionDenied, "peer login has expired, please log in once more")
}
// NewSetupKeyNotFoundError creates a new Error with NotFound type for a missing setup key
func NewSetupKeyNotFoundError() error {
return Errorf(NotFound, "setup key not found")
}
// NewGetUserFromStoreError creates a new Error with Internal type for an issue getting user from store
func NewGetUserFromStoreError() error {
return Errorf(Internal, "issue getting user from store")
}

View File

@@ -27,15 +27,6 @@ import (
"github.com/netbirdio/netbird/route"
)
type LockingStrength string
const (
LockingStrengthUpdate LockingStrength = "UPDATE" // Strongest lock, preventing any changes by other transactions until your transaction completes.
LockingStrengthShare LockingStrength = "SHARE" // Allows reading but prevents changes by other transactions.
LockingStrengthNoKeyUpdate LockingStrength = "NO KEY UPDATE" // Similar to UPDATE but allows changes to related rows.
LockingStrengthKeyShare LockingStrength = "KEY SHARE" // Protects against changes to primary/unique keys but allows other updates.
)
type Store interface {
GetAllAccounts(ctx context.Context) []*Account
GetAccount(ctx context.Context, accountID string) (*Account, error)
@@ -50,7 +41,7 @@ type Store interface {
GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error)
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
GetUserByUserID(ctx context.Context, userID string) (*User, error)
GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error)
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
SaveAccount(ctx context.Context, account *Account) error
@@ -69,24 +60,14 @@ type Store interface {
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error
SavePeerLocation(accountID string, peer *nbpeer.Peer) error
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error
// Close should close the store persisting all unsaved data.
Close(ctx context.Context) error
// GetStoreEngine should return StoreEngine of the current store implementation.
// This is also a method of metrics.DataSource interface.
GetStoreEngine() StoreEngine
GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error)
GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error)
GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error)
GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error)
IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error
AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error)
AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error
AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error
IncrementNetworkSerial(ctx context.Context, accountId string) error
GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error)
ExecuteInTransaction(ctx context.Context, f func(store Store) error) error
GetPeerByPeerPubKey(ctx context.Context, peerKey string) (*nbpeer.Peer, error)
GetAccountSettings(ctx context.Context, accountID string) (*Settings, error)
}
type StoreEngine string

View File

@@ -1,120 +0,0 @@
{
"Accounts": {
"bf1c8084-ba50-4ce7-9439-34653001fc3b": {
"Id": "bf1c8084-ba50-4ce7-9439-34653001fc3b",
"CreatedBy": "",
"Domain": "test.com",
"DomainCategory": "private",
"IsDomainPrimaryAccount": true,
"SetupKeys": {
"A2C8E62B-38F5-4553-B31E-DD66C696CEBB": {
"Id": "A2C8E62B-38F5-4553-B31E-DD66C696CEBB",
"AccountID": "",
"Key": "A2C8E62B-38F5-4553-B31E-DD66C696CEBB",
"Name": "Default key",
"Type": "reusable",
"CreatedAt": "2021-08-19T20:46:20.005936822+02:00",
"ExpiresAt": "2321-09-18T20:46:20.005936822+02:00",
"UpdatedAt": "0001-01-01T00:00:00Z",
"Revoked": false,
"UsedTimes": 0,
"LastUsed": "0001-01-01T00:00:00Z",
"AutoGroups": ["cfefqs706sqkneg59g2g"],
"UsageLimit": 0,
"Ephemeral": false
},
"A2C8E62B-38F5-4553-B31E-DD66C696CEBC": {
"Id": "A2C8E62B-38F5-4553-B31E-DD66C696CEBC",
"AccountID": "",
"Key": "A2C8E62B-38F5-4553-B31E-DD66C696CEBC",
"Name": "Faulty key with non existing group",
"Type": "reusable",
"CreatedAt": "2021-08-19T20:46:20.005936822+02:00",
"ExpiresAt": "2321-09-18T20:46:20.005936822+02:00",
"UpdatedAt": "0001-01-01T00:00:00Z",
"Revoked": false,
"UsedTimes": 0,
"LastUsed": "0001-01-01T00:00:00Z",
"AutoGroups": ["abcd"],
"UsageLimit": 0,
"Ephemeral": false
}
},
"Network": {
"id": "af1c8024-ha40-4ce2-9418-34653101fc3c",
"Net": {
"IP": "100.64.0.0",
"Mask": "//8AAA=="
},
"Dns": "",
"Serial": 0
},
"Peers": {},
"Users": {
"edafee4e-63fb-11ec-90d6-0242ac120003": {
"Id": "edafee4e-63fb-11ec-90d6-0242ac120003",
"AccountID": "",
"Role": "admin",
"IsServiceUser": false,
"ServiceUserName": "",
"AutoGroups": ["cfefqs706sqkneg59g3g"],
"PATs": {},
"Blocked": false,
"LastLogin": "0001-01-01T00:00:00Z"
},
"f4f6d672-63fb-11ec-90d6-0242ac120003": {
"Id": "f4f6d672-63fb-11ec-90d6-0242ac120003",
"AccountID": "",
"Role": "user",
"IsServiceUser": false,
"ServiceUserName": "",
"AutoGroups": null,
"PATs": {
"9dj38s35-63fb-11ec-90d6-0242ac120003": {
"ID": "9dj38s35-63fb-11ec-90d6-0242ac120003",
"UserID": "",
"Name": "",
"HashedToken": "SoMeHaShEdToKeN",
"ExpirationDate": "2023-02-27T00:00:00Z",
"CreatedBy": "user",
"CreatedAt": "2023-01-01T00:00:00Z",
"LastUsed": "2023-02-01T00:00:00Z"
}
},
"Blocked": false,
"LastLogin": "0001-01-01T00:00:00Z"
}
},
"Groups": {
"cfefqs706sqkneg59g4g": {
"ID": "cfefqs706sqkneg59g4g",
"Name": "All",
"Peers": []
},
"cfefqs706sqkneg59g3g": {
"ID": "cfefqs706sqkneg59g3g",
"Name": "AwesomeGroup1",
"Peers": []
},
"cfefqs706sqkneg59g2g": {
"ID": "cfefqs706sqkneg59g2g",
"Name": "AwesomeGroup2",
"Peers": []
}
},
"Rules": null,
"Policies": [],
"Routes": null,
"NameServerGroups": null,
"DNSSettings": null,
"Settings": {
"PeerLoginExpirationEnabled": false,
"PeerLoginExpiration": 86400000000000,
"GroupsPropagationEnabled": false,
"JWTGroupsEnabled": false,
"JWTGroupsClaimName": ""
}
}
},
"InstallationID": ""
}

View File

@@ -89,6 +89,10 @@ func (u *User) LastDashboardLoginChanged(LastLogin time.Time) bool {
return LastLogin.After(u.LastLogin) && !u.LastLogin.IsZero()
}
func (u *User) updateLastLogin(login time.Time) {
u.LastLogin = login
}
// HasAdminPower returns true if the user has admin or owner roles, false otherwise
func (u *User) HasAdminPower() bool {
return u.Role == UserRoleAdmin || u.Role == UserRoleOwner
@@ -382,7 +386,7 @@ func (am *DefaultAccountManager) GetUser(ctx context.Context, claims jwtclaims.A
// server when user authenticates a device. And we need to separate the Dashboard login event from the Device login event.
newLogin := user.LastDashboardLoginChanged(claims.LastLogin)
err = am.Store.SaveUserLastLogin(ctx, account.Id, claims.UserId, claims.LastLogin)
err = am.Store.SaveUserLastLogin(account.Id, claims.UserId, claims.LastLogin)
if err != nil {
log.WithContext(ctx).Errorf("failed saving user last login: %v", err)
}

View File

@@ -58,65 +58,37 @@ func (m *Msg) Free() {
m.bufPool.Put(m.bufPtr)
}
// connContainer is a container for the connection to the peer. It is responsible for managing the messages from the
// server and forwarding them to the upper layer content reader.
type connContainer struct {
log *log.Entry
conn *Conn
messages chan Msg
msgChanLock sync.Mutex
closed bool // flag to check if channel is closed
ctx context.Context
cancel context.CancelFunc
}
func newConnContainer(log *log.Entry, conn *Conn, messages chan Msg) *connContainer {
ctx, cancel := context.WithCancel(context.Background())
func newConnContainer(conn *Conn, messages chan Msg) *connContainer {
return &connContainer{
log: log,
conn: conn,
messages: messages,
ctx: ctx,
cancel: cancel,
}
}
func (cc *connContainer) writeMsg(msg Msg) {
cc.msgChanLock.Lock()
defer cc.msgChanLock.Unlock()
if cc.closed {
msg.Free()
return
}
select {
case cc.messages <- msg:
case <-cc.ctx.Done():
msg.Free()
default:
msg.Free()
cc.log.Infof("message queue is full")
// todo consider to close the connection
}
cc.messages <- msg
}
func (cc *connContainer) close() {
cc.cancel()
cc.msgChanLock.Lock()
defer cc.msgChanLock.Unlock()
if cc.closed {
return
}
cc.closed = true
close(cc.messages)
for msg := range cc.messages {
msg.Free()
}
cc.closed = true
}
// Client is a client for the relay server. It is responsible for establishing a connection to the relay server and
@@ -148,8 +120,8 @@ type Client struct {
// NewClient creates a new client for the relay server. The client is not connected to the server until the Connect
func NewClient(ctx context.Context, serverURL string, authTokenStore *auth.TokenStore, peerID string) *Client {
hashedID, hashedStringId := messages.HashID(peerID)
c := &Client{
log: log.WithFields(log.Fields{"relay": serverURL}),
return &Client{
log: log.WithField("client_id", hashedStringId),
parentCtx: ctx,
connectionURL: serverURL,
authTokenStore: authTokenStore,
@@ -162,13 +134,11 @@ func NewClient(ctx context.Context, serverURL string, authTokenStore *auth.Token
},
conns: make(map[string]*connContainer),
}
c.log.Infof("create new relay connection: local peerID: %s, local peer hashedID: %s", peerID, hashedStringId)
return c
}
// Connect establishes a connection to the relay server. It blocks until the connection is established or an error occurs.
func (c *Client) Connect() error {
c.log.Infof("connecting to relay server")
c.log.Infof("connecting to relay server: %s", c.connectionURL)
c.readLoopMutex.Lock()
defer c.readLoopMutex.Unlock()
@@ -189,7 +159,7 @@ func (c *Client) Connect() error {
c.wgReadLoop.Add(1)
go c.readLoop(c.relayConn)
c.log.Infof("relay connection established")
c.log.Infof("relay connection established with: %s", c.connectionURL)
return nil
}
@@ -211,11 +181,11 @@ func (c *Client) OpenConn(dstPeerID string) (net.Conn, error) {
return nil, ErrConnAlreadyExists
}
c.log.Infof("open connection to peer: %s", hashedStringID)
msgChannel := make(chan Msg, 100)
log.Infof("open connection to peer: %s", hashedStringID)
msgChannel := make(chan Msg, 2)
conn := NewConn(c, hashedID, hashedStringID, msgChannel, c.instanceURL)
c.conns[hashedStringID] = newConnContainer(c.log, conn, msgChannel)
c.conns[hashedStringID] = newConnContainer(conn, msgChannel)
return conn, nil
}
@@ -259,7 +229,7 @@ func (c *Client) connect() error {
if err != nil {
cErr := conn.Close()
if cErr != nil {
c.log.Errorf("failed to close connection: %s", cErr)
log.Errorf("failed to close connection: %s", cErr)
}
return err
}
@@ -270,19 +240,19 @@ func (c *Client) connect() error {
func (c *Client) handShake() error {
msg, err := messages.MarshalAuthMsg(c.hashedID, c.authTokenStore.TokenBinary())
if err != nil {
c.log.Errorf("failed to marshal auth message: %s", err)
log.Errorf("failed to marshal auth message: %s", err)
return err
}
_, err = c.relayConn.Write(msg)
if err != nil {
c.log.Errorf("failed to send auth message: %s", err)
log.Errorf("failed to send auth message: %s", err)
return err
}
buf := make([]byte, messages.MaxHandshakeRespSize)
n, err := c.readWithTimeout(buf)
if err != nil {
c.log.Errorf("failed to read auth response: %s", err)
log.Errorf("failed to read auth response: %s", err)
return err
}
@@ -293,12 +263,12 @@ func (c *Client) handShake() error {
msgType, err := messages.DetermineServerMessageType(buf[messages.SizeOfVersionByte:n])
if err != nil {
c.log.Errorf("failed to determine message type: %s", err)
log.Errorf("failed to determine message type: %s", err)
return err
}
if msgType != messages.MsgTypeAuthResponse {
c.log.Errorf("unexpected message type: %s", msgType)
log.Errorf("unexpected message type: %s", msgType)
return fmt.Errorf("unexpected message type")
}
@@ -315,7 +285,7 @@ func (c *Client) handShake() error {
func (c *Client) readLoop(relayConn net.Conn) {
internallyStoppedFlag := newInternalStopFlag()
hc := healthcheck.NewReceiver(c.log)
hc := healthcheck.NewReceiver()
go c.listenForStopEvents(hc, relayConn, internallyStoppedFlag)
var (
@@ -327,7 +297,6 @@ func (c *Client) readLoop(relayConn net.Conn) {
buf := *bufPtr
n, errExit = relayConn.Read(buf)
if errExit != nil {
c.log.Infof("start to Relay read loop exit")
c.mu.Lock()
if c.serviceIsRunning && !internallyStoppedFlag.isSet() {
c.log.Debugf("failed to read message from relay server: %s", errExit)
@@ -374,7 +343,7 @@ func (c *Client) handleMsg(msgType messages.MsgType, buf []byte, bufPtr *[]byte,
case messages.MsgTypeTransport:
return c.handleTransportMsg(buf, bufPtr, internallyStoppedFlag)
case messages.MsgTypeClose:
c.log.Debugf("relay connection close by server")
log.Debugf("relay connection close by server")
c.bufPool.Put(bufPtr)
return false
}
@@ -443,14 +412,14 @@ func (c *Client) writeTo(connReference *Conn, id string, dstID []byte, payload [
// todo: use buffer pool instead of create new transport msg.
msg, err := messages.MarshalTransportMsg(dstID, payload)
if err != nil {
c.log.Errorf("failed to marshal transport message: %s", err)
log.Errorf("failed to marshal transport message: %s", err)
return 0, err
}
// the write always return with 0 length because the underling does not support the size feedback.
_, err = c.relayConn.Write(msg)
if err != nil {
c.log.Errorf("failed to write transport message: %s", err)
log.Errorf("failed to write transport message: %s", err)
}
return len(payload), err
}
@@ -464,15 +433,12 @@ func (c *Client) listenForStopEvents(hc *healthcheck.Receiver, conn net.Conn, in
}
c.log.Errorf("health check timeout")
internalStopFlag.set()
if err := conn.Close(); err != nil {
// ignore the err handling because the readLoop will handle it
c.log.Warnf("failed to close connection: %s", err)
}
_ = conn.Close() // ignore the err because the readLoop will handle it
return
case <-c.parentCtx.Done():
err := c.close(true)
if err != nil {
c.log.Errorf("failed to teardown connection: %s", err)
log.Errorf("failed to teardown connection: %s", err)
}
return
}
@@ -498,9 +464,8 @@ func (c *Client) closeConn(connReference *Conn, id string) error {
if container.conn != connReference {
return fmt.Errorf("conn reference mismatch")
}
c.log.Infof("free up connection to peer: %s", id)
delete(c.conns, id)
container.close()
delete(c.conns, id)
return nil
}
@@ -513,12 +478,10 @@ func (c *Client) close(gracefullyExit bool) error {
var err error
if !c.serviceIsRunning {
c.mu.Unlock()
c.log.Warn("relay connection was already marked as not running")
return nil
}
c.serviceIsRunning = false
c.log.Infof("closing all peer connections")
c.closeAllConns()
if gracefullyExit {
c.writeCloseMsg()
@@ -526,9 +489,8 @@ func (c *Client) close(gracefullyExit bool) error {
err = c.relayConn.Close()
c.mu.Unlock()
c.log.Infof("waiting for read loop to close")
c.wgReadLoop.Wait()
c.log.Infof("relay connection closed")
c.log.Infof("relay connection closed with: %s", c.connectionURL)
return err
}

View File

@@ -618,87 +618,6 @@ func TestCloseByClient(t *testing.T) {
}
}
func TestCloseNotDrainedChannel(t *testing.T) {
ctx := context.Background()
idAlice := "alice"
idBob := "bob"
srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv, err := server.NewServer(otel.Meter(""), serverURL, false, av)
if err != nil {
t.Fatalf("failed to create server: %s", err)
}
errChan := make(chan error, 1)
go func() {
err := srv.Listen(srvCfg)
if err != nil {
errChan <- err
}
}()
defer func() {
err := srv.Shutdown(ctx)
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}()
// wait for servers to start
if err := waitForServerToStart(errChan); err != nil {
t.Fatalf("failed to start server: %s", err)
}
clientAlice := NewClient(ctx, serverURL, hmacTokenStore, idAlice)
err = clientAlice.Connect()
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
}
defer func() {
err := clientAlice.Close()
if err != nil {
t.Errorf("failed to close Alice client: %s", err)
}
}()
clientBob := NewClient(ctx, serverURL, hmacTokenStore, idBob)
err = clientBob.Connect()
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
}
defer func() {
err := clientBob.Close()
if err != nil {
t.Errorf("failed to close Bob client: %s", err)
}
}()
connAliceToBob, err := clientAlice.OpenConn(idBob)
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
connBobToAlice, err := clientBob.OpenConn(idAlice)
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
payload := "hello bob, I am alice"
// the internal channel buffer size is 2. So we should overflow it
for i := 0; i < 5; i++ {
_, err = connAliceToBob.Write([]byte(payload))
if err != nil {
t.Fatalf("failed to write to channel: %s", err)
}
}
// wait for delivery
time.Sleep(1 * time.Second)
err = connBobToAlice.Close()
if err != nil {
t.Errorf("failed to close channel: %s", err)
}
}
func waitForServerToStart(errChan chan error) error {
select {
case err := <-errChan:

View File

@@ -3,6 +3,7 @@ package client
import (
"container/list"
"context"
"errors"
"fmt"
"net"
"reflect"
@@ -16,6 +17,8 @@ import (
var (
relayCleanupInterval = 60 * time.Second
connectionTimeout = 30 * time.Second
maxConcurrentServers = 7
ErrRelayClientNotConnected = fmt.Errorf("relay client not connected")
)
@@ -89,23 +92,67 @@ func (m *Manager) Serve() error {
}
log.Debugf("starting relay client manager with %v relay servers", m.serverURLs)
sp := ServerPicker{
TokenStore: m.tokenStore,
PeerID: m.peerID,
totalServers := len(m.serverURLs)
successChan := make(chan *Client, 1)
errChan := make(chan error, len(m.serverURLs))
ctx, cancel := context.WithTimeout(m.ctx, connectionTimeout)
defer cancel()
sem := make(chan struct{}, maxConcurrentServers)
for _, url := range m.serverURLs {
sem <- struct{}{}
go func(url string) {
defer func() { <-sem }()
m.connect(m.ctx, url, successChan, errChan)
}(url)
}
client, err := sp.PickServer(m.ctx, m.serverURLs)
if err != nil {
return err
}
m.relayClient = client
var errCount int
m.reconnectGuard = NewGuard(m.ctx, m.relayClient)
m.relayClient.SetOnDisconnectListener(func() {
m.onServerDisconnected(client.connectionURL)
})
m.startCleanupLoop()
return nil
for {
select {
case client := <-successChan:
log.Infof("Successfully connected to relay server: %s", client.connectionURL)
m.relayClient = client
m.reconnectGuard = NewGuard(m.ctx, m.relayClient)
m.relayClient.SetOnDisconnectListener(func() {
m.onServerDisconnected(client.connectionURL)
})
m.startCleanupLoop()
return nil
case err := <-errChan:
errCount++
log.Warnf("Connection attempt failed: %v", err)
if errCount == totalServers {
return errors.New("failed to connect to any relay server: all attempts failed")
}
case <-ctx.Done():
return fmt.Errorf("failed to connect to any relay server: %w", ctx.Err())
}
}
}
func (m *Manager) connect(ctx context.Context, serverURL string, successChan chan<- *Client, errChan chan<- error) {
// TODO: abort the connection if another connection was successful
relayClient := NewClient(ctx, serverURL, m.tokenStore, m.peerID)
if err := relayClient.Connect(); err != nil {
errChan <- fmt.Errorf("failed to connect to %s: %w", serverURL, err)
return
}
select {
case successChan <- relayClient:
// This client was the first to connect successfully
default:
if err := relayClient.Close(); err != nil {
log.Debugf("failed to close relay client: %s", err)
}
}
}
// OpenConn opens a connection to the given peer key. If the peer is on the same relay server, the connection will be

View File

@@ -1,98 +0,0 @@
package client
import (
"context"
"errors"
"fmt"
"time"
log "github.com/sirupsen/logrus"
auth "github.com/netbirdio/netbird/relay/auth/hmac"
)
const (
connectionTimeout = 30 * time.Second
maxConcurrentServers = 7
)
type connResult struct {
RelayClient *Client
Url string
Err error
}
type ServerPicker struct {
TokenStore *auth.TokenStore
PeerID string
}
func (sp *ServerPicker) PickServer(parentCtx context.Context, urls []string) (*Client, error) {
ctx, cancel := context.WithTimeout(parentCtx, connectionTimeout)
defer cancel()
totalServers := len(urls)
connResultChan := make(chan connResult, totalServers)
successChan := make(chan connResult, 1)
concurrentLimiter := make(chan struct{}, maxConcurrentServers)
for _, url := range urls {
// todo check if we have a successful connection so we do not need to connect to other servers
concurrentLimiter <- struct{}{}
go func(url string) {
defer func() {
<-concurrentLimiter
}()
sp.startConnection(parentCtx, connResultChan, url)
}(url)
}
go sp.processConnResults(connResultChan, successChan)
select {
case cr, ok := <-successChan:
if !ok {
return nil, errors.New("failed to connect to any relay server: all attempts failed")
}
log.Infof("chosen home Relay server: %s", cr.Url)
return cr.RelayClient, nil
case <-ctx.Done():
return nil, fmt.Errorf("failed to connect to any relay server: %w", ctx.Err())
}
}
func (sp *ServerPicker) startConnection(ctx context.Context, resultChan chan connResult, url string) {
log.Infof("try to connecting to relay server: %s", url)
relayClient := NewClient(ctx, url, sp.TokenStore, sp.PeerID)
err := relayClient.Connect()
resultChan <- connResult{
RelayClient: relayClient,
Url: url,
Err: err,
}
}
func (sp *ServerPicker) processConnResults(resultChan chan connResult, successChan chan connResult) {
var hasSuccess bool
for numOfResults := 0; numOfResults < cap(resultChan); numOfResults++ {
cr := <-resultChan
if cr.Err != nil {
log.Debugf("failed to connect to Relay server: %s: %v", cr.Url, cr.Err)
continue
}
log.Infof("connected to Relay server: %s", cr.Url)
if hasSuccess {
log.Infof("closing unnecessary Relay connection to: %s", cr.Url)
if err := cr.RelayClient.Close(); err != nil {
log.Errorf("failed to close connection to %s: %v", cr.Url, err)
}
continue
}
hasSuccess = true
successChan <- cr
}
close(successChan)
}

View File

@@ -1,31 +0,0 @@
package client
import (
"context"
"errors"
"testing"
"time"
)
func TestServerPicker_UnavailableServers(t *testing.T) {
sp := ServerPicker{
TokenStore: nil,
PeerID: "test",
}
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
go func() {
_, err := sp.PickServer(ctx, []string{"rel://dummy1", "rel://dummy2"})
if err == nil {
t.Error(err)
}
cancel()
}()
<-ctx.Done()
if errors.Is(ctx.Err(), context.DeadlineExceeded) {
t.Errorf("PickServer() took too long to complete")
}
}

View File

@@ -23,12 +23,15 @@ import (
"github.com/netbirdio/netbird/util"
)
const (
metricsPort = 9090
)
type Config struct {
ListenAddress string
// in HA every peer connect to a common domain, the instance domain has been distributed during the p2p connection
// it is a domain:port or ip:port
ExposedAddress string
MetricsPort int
LetsencryptEmail string
LetsencryptDataDir string
LetsencryptDomains []string
@@ -77,7 +80,6 @@ func init() {
cobraConfig = &Config{}
rootCmd.PersistentFlags().StringVarP(&cobraConfig.ListenAddress, "listen-address", "l", ":443", "listen address")
rootCmd.PersistentFlags().StringVarP(&cobraConfig.ExposedAddress, "exposed-address", "e", "", "instance domain address (or ip) and port, it will be distributes between peers")
rootCmd.PersistentFlags().IntVar(&cobraConfig.MetricsPort, "metrics-port", 9090, "metrics endpoint http port. Metrics are accessible under host:metrics-port/metrics")
rootCmd.PersistentFlags().StringVarP(&cobraConfig.LetsencryptDataDir, "letsencrypt-data-dir", "d", "", "a directory to store Let's Encrypt data. Required if Let's Encrypt is enabled.")
rootCmd.PersistentFlags().StringSliceVarP(&cobraConfig.LetsencryptDomains, "letsencrypt-domains", "a", nil, "list of domains to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS")
rootCmd.PersistentFlags().StringVar(&cobraConfig.LetsencryptEmail, "letsencrypt-email", "", "email address to use for Let's Encrypt certificate registration")
@@ -114,7 +116,7 @@ func execute(cmd *cobra.Command, args []string) error {
return fmt.Errorf("failed to initialize log: %s", err)
}
metricsServer, err := metrics.NewServer(cobraConfig.MetricsPort, "")
metricsServer, err := metrics.NewServer(metricsPort, "")
if err != nil {
log.Debugf("setup metrics: %v", err)
return fmt.Errorf("setup metrics: %v", err)

View File

@@ -3,12 +3,10 @@ package healthcheck
import (
"context"
"time"
log "github.com/sirupsen/logrus"
)
var (
heartbeatTimeout = healthCheckInterval + 10*time.Second
heartbeatTimeout = healthCheckInterval + 3*time.Second
)
// Receiver is a healthcheck receiver
@@ -16,26 +14,23 @@ var (
// If the heartbeat is not received in a certain time, it will send a timeout signal and stop to work
// The heartbeat timeout is a bit longer than the sender's healthcheck interval
type Receiver struct {
OnTimeout chan struct{}
log *log.Entry
ctx context.Context
ctxCancel context.CancelFunc
heartbeat chan struct{}
alive bool
attemptThreshold int
OnTimeout chan struct{}
ctx context.Context
ctxCancel context.CancelFunc
heartbeat chan struct{}
alive bool
}
// NewReceiver creates a new healthcheck receiver and start the timer in the background
func NewReceiver(log *log.Entry) *Receiver {
func NewReceiver() *Receiver {
ctx, ctxCancel := context.WithCancel(context.Background())
r := &Receiver{
OnTimeout: make(chan struct{}, 1),
log: log,
ctx: ctx,
ctxCancel: ctxCancel,
heartbeat: make(chan struct{}, 1),
attemptThreshold: getAttemptThresholdFromEnv(),
OnTimeout: make(chan struct{}, 1),
ctx: ctx,
ctxCancel: ctxCancel,
heartbeat: make(chan struct{}, 1),
}
go r.waitForHealthcheck()
@@ -61,23 +56,16 @@ func (r *Receiver) waitForHealthcheck() {
defer r.ctxCancel()
defer close(r.OnTimeout)
failureCounter := 0
for {
select {
case <-r.heartbeat:
r.alive = true
failureCounter = 0
case <-ticker.C:
if r.alive {
r.alive = false
continue
}
failureCounter++
if failureCounter < r.attemptThreshold {
r.log.Warnf("healthcheck failed, attempt %d", failureCounter)
continue
}
r.notifyTimeout()
return
case <-r.ctx.Done():

View File

@@ -1,18 +1,13 @@
package healthcheck
import (
"context"
"fmt"
"os"
"testing"
"time"
log "github.com/sirupsen/logrus"
)
func TestNewReceiver(t *testing.T) {
heartbeatTimeout = 5 * time.Second
r := NewReceiver(log.WithContext(context.Background()))
r := NewReceiver()
select {
case <-r.OnTimeout:
@@ -24,7 +19,7 @@ func TestNewReceiver(t *testing.T) {
func TestNewReceiverNotReceive(t *testing.T) {
heartbeatTimeout = 1 * time.Second
r := NewReceiver(log.WithContext(context.Background()))
r := NewReceiver()
select {
case <-r.OnTimeout:
@@ -35,7 +30,7 @@ func TestNewReceiverNotReceive(t *testing.T) {
func TestNewReceiverAck(t *testing.T) {
heartbeatTimeout = 2 * time.Second
r := NewReceiver(log.WithContext(context.Background()))
r := NewReceiver()
r.Heartbeat()
@@ -45,53 +40,3 @@ func TestNewReceiverAck(t *testing.T) {
case <-time.After(3 * time.Second):
}
}
func TestReceiverHealthCheckAttemptThreshold(t *testing.T) {
testsCases := []struct {
name string
threshold int
resetCounterOnce bool
}{
{"Default attempt threshold", defaultAttemptThreshold, false},
{"Custom attempt threshold", 3, false},
{"Should reset threshold once", 2, true},
}
for _, tc := range testsCases {
t.Run(tc.name, func(t *testing.T) {
originalInterval := healthCheckInterval
originalTimeout := heartbeatTimeout
healthCheckInterval = 1 * time.Second
heartbeatTimeout = healthCheckInterval + 500*time.Millisecond
defer func() {
healthCheckInterval = originalInterval
heartbeatTimeout = originalTimeout
}()
//nolint:tenv
os.Setenv(defaultAttemptThresholdEnv, fmt.Sprintf("%d", tc.threshold))
defer os.Unsetenv(defaultAttemptThresholdEnv)
receiver := NewReceiver(log.WithField("test_name", tc.name))
testTimeout := heartbeatTimeout*time.Duration(tc.threshold) + healthCheckInterval
if tc.resetCounterOnce {
receiver.Heartbeat()
t.Logf("reset counter once")
}
select {
case <-receiver.OnTimeout:
if tc.resetCounterOnce {
t.Fatalf("should not have timed out before %s", testTimeout)
}
case <-time.After(testTimeout):
if tc.resetCounterOnce {
return
}
t.Fatalf("should have timed out before %s", testTimeout)
}
})
}
}

View File

@@ -2,21 +2,12 @@ package healthcheck
import (
"context"
"os"
"strconv"
"time"
log "github.com/sirupsen/logrus"
)
const (
defaultAttemptThreshold = 1
defaultAttemptThresholdEnv = "NB_RELAY_HC_ATTEMPT_THRESHOLD"
)
var (
healthCheckInterval = 25 * time.Second
healthCheckTimeout = 20 * time.Second
healthCheckTimeout = 5 * time.Second
)
// Sender is a healthcheck sender
@@ -24,25 +15,20 @@ var (
// If the receiver does not receive the signal in a certain time, it will send a timeout signal and stop to work
// It will also stop if the context is canceled
type Sender struct {
log *log.Entry
// HealthCheck is a channel to send health check signal to the peer
HealthCheck chan struct{}
// Timeout is a channel to the health check signal is not received in a certain time
Timeout chan struct{}
ack chan struct{}
alive bool
attemptThreshold int
ack chan struct{}
}
// NewSender creates a new healthcheck sender
func NewSender(log *log.Entry) *Sender {
func NewSender() *Sender {
hc := &Sender{
log: log,
HealthCheck: make(chan struct{}, 1),
Timeout: make(chan struct{}, 1),
ack: make(chan struct{}, 1),
attemptThreshold: getAttemptThresholdFromEnv(),
HealthCheck: make(chan struct{}, 1),
Timeout: make(chan struct{}, 1),
ack: make(chan struct{}, 1),
}
return hc
@@ -60,51 +46,23 @@ func (hc *Sender) StartHealthCheck(ctx context.Context) {
ticker := time.NewTicker(healthCheckInterval)
defer ticker.Stop()
timeoutTicker := time.NewTicker(hc.getTimeoutTime())
defer timeoutTicker.Stop()
timeoutTimer := time.NewTimer(healthCheckInterval + healthCheckTimeout)
defer timeoutTimer.Stop()
defer close(hc.HealthCheck)
defer close(hc.Timeout)
failureCounter := 0
for {
select {
case <-ticker.C:
hc.HealthCheck <- struct{}{}
case <-timeoutTicker.C:
if hc.alive {
hc.alive = false
continue
}
failureCounter++
if failureCounter < hc.attemptThreshold {
hc.log.Warnf("Health check failed attempt %d.", failureCounter)
continue
}
case <-timeoutTimer.C:
hc.Timeout <- struct{}{}
return
case <-hc.ack:
failureCounter = 0
hc.alive = true
timeoutTimer.Reset(healthCheckInterval + healthCheckTimeout)
case <-ctx.Done():
return
}
}
}
func (hc *Sender) getTimeoutTime() time.Duration {
return healthCheckInterval + healthCheckTimeout
}
func getAttemptThresholdFromEnv() int {
if attemptThreshold := os.Getenv(defaultAttemptThresholdEnv); attemptThreshold != "" {
threshold, err := strconv.ParseInt(attemptThreshold, 10, 64)
if err != nil {
log.Errorf("Failed to parse attempt threshold from environment variable \"%s\" should be an integer. Using default value", attemptThreshold)
return defaultAttemptThreshold
}
return int(threshold)
}
return defaultAttemptThreshold
}

View File

@@ -2,12 +2,9 @@ package healthcheck
import (
"context"
"fmt"
"os"
"testing"
"time"
log "github.com/sirupsen/logrus"
)
func TestMain(m *testing.M) {
@@ -21,7 +18,7 @@ func TestMain(m *testing.M) {
func TestNewHealthPeriod(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
hc := NewSender(log.WithContext(ctx))
hc := NewSender()
go hc.StartHealthCheck(ctx)
iterations := 0
@@ -41,7 +38,7 @@ func TestNewHealthPeriod(t *testing.T) {
func TestNewHealthFailed(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
hc := NewSender(log.WithContext(ctx))
hc := NewSender()
go hc.StartHealthCheck(ctx)
select {
@@ -53,7 +50,7 @@ func TestNewHealthFailed(t *testing.T) {
func TestNewHealthcheckStop(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
hc := NewSender(log.WithContext(ctx))
hc := NewSender()
go hc.StartHealthCheck(ctx)
time.Sleep(100 * time.Millisecond)
@@ -78,7 +75,7 @@ func TestNewHealthcheckStop(t *testing.T) {
func TestTimeoutReset(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
hc := NewSender(log.WithContext(ctx))
hc := NewSender()
go hc.StartHealthCheck(ctx)
iterations := 0
@@ -104,102 +101,3 @@ func TestTimeoutReset(t *testing.T) {
t.Fatalf("is not exited")
}
}
func TestSenderHealthCheckAttemptThreshold(t *testing.T) {
testsCases := []struct {
name string
threshold int
resetCounterOnce bool
}{
{"Default attempt threshold", defaultAttemptThreshold, false},
{"Custom attempt threshold", 3, false},
{"Should reset threshold once", 2, true},
}
for _, tc := range testsCases {
t.Run(tc.name, func(t *testing.T) {
originalInterval := healthCheckInterval
originalTimeout := healthCheckTimeout
healthCheckInterval = 1 * time.Second
healthCheckTimeout = 500 * time.Millisecond
defer func() {
healthCheckInterval = originalInterval
healthCheckTimeout = originalTimeout
}()
//nolint:tenv
os.Setenv(defaultAttemptThresholdEnv, fmt.Sprintf("%d", tc.threshold))
defer os.Unsetenv(defaultAttemptThresholdEnv)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
sender := NewSender(log.WithField("test_name", tc.name))
go sender.StartHealthCheck(ctx)
go func() {
responded := false
for {
select {
case <-ctx.Done():
return
case _, ok := <-sender.HealthCheck:
if !ok {
return
}
if tc.resetCounterOnce && !responded {
responded = true
sender.OnHCResponse()
}
}
}
}()
testTimeout := sender.getTimeoutTime()*time.Duration(tc.threshold) + healthCheckInterval
select {
case <-sender.Timeout:
if tc.resetCounterOnce {
t.Fatalf("should not have timed out before %s", testTimeout)
}
case <-time.After(testTimeout):
if tc.resetCounterOnce {
return
}
t.Fatalf("should have timed out before %s", testTimeout)
}
})
}
}
//nolint:tenv
func TestGetAttemptThresholdFromEnv(t *testing.T) {
tests := []struct {
name string
envValue string
expected int
}{
{"Default attempt threshold when env is not set", "", defaultAttemptThreshold},
{"Custom attempt threshold when env is set to a valid integer", "3", 3},
{"Default attempt threshold when env is set to an invalid value", "invalid", defaultAttemptThreshold},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.envValue == "" {
os.Unsetenv(defaultAttemptThresholdEnv)
} else {
os.Setenv(defaultAttemptThresholdEnv, tt.envValue)
}
result := getAttemptThresholdFromEnv()
if result != tt.expected {
t.Fatalf("Expected %d, got %d", tt.expected, result)
}
os.Unsetenv(defaultAttemptThresholdEnv)
})
}
}

View File

@@ -49,7 +49,7 @@ func (p *Peer) Work() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
hc := healthcheck.NewSender(p.log)
hc := healthcheck.NewSender()
go hc.StartHealthCheck(ctx)
go p.handleHealthcheckEvents(ctx, hc)
@@ -115,7 +115,6 @@ func (p *Peer) Write(b []byte) (int, error) {
// connection.
func (p *Peer) CloseGracefully(ctx context.Context) {
p.connMu.Lock()
defer p.connMu.Unlock()
err := p.writeWithTimeout(ctx, messages.MarshalCloseMsg())
if err != nil {
p.log.Errorf("failed to send close message to peer: %s", p.String())
@@ -125,15 +124,8 @@ func (p *Peer) CloseGracefully(ctx context.Context) {
if err != nil {
p.log.Errorf("failed to close connection to peer: %s", err)
}
}
func (p *Peer) Close() {
p.connMu.Lock()
defer p.connMu.Unlock()
if err := p.conn.Close(); err != nil {
p.log.Errorf("failed to close connection to peer: %s", err)
}
}
// String returns the peer ID
@@ -175,7 +167,6 @@ func (p *Peer) handleHealthcheckEvents(ctx context.Context, hc *healthcheck.Send
if err != nil {
p.log.Errorf("failed to close connection to peer: %s", err)
}
p.log.Info("peer connection closed due healthcheck timeout")
return
case <-ctx.Done():
return

View File

@@ -19,14 +19,10 @@ func NewStore() *Store {
}
// AddPeer adds a peer to the store
// todo: consider to close peer conn if the peer already exists
func (s *Store) AddPeer(peer *Peer) {
s.peersLock.Lock()
defer s.peersLock.Unlock()
odlPeer, ok := s.peers[peer.String()]
if ok {
odlPeer.Close()
}
s.peers[peer.String()] = peer
}

View File

@@ -2,57 +2,13 @@ package server
import (
"context"
"net"
"testing"
"time"
"go.opentelemetry.io/otel"
"github.com/netbirdio/netbird/relay/metrics"
)
type mockConn struct {
}
func (m mockConn) Read(b []byte) (n int, err error) {
//TODO implement me
panic("implement me")
}
func (m mockConn) Write(b []byte) (n int, err error) {
//TODO implement me
panic("implement me")
}
func (m mockConn) Close() error {
return nil
}
func (m mockConn) LocalAddr() net.Addr {
//TODO implement me
panic("implement me")
}
func (m mockConn) RemoteAddr() net.Addr {
//TODO implement me
panic("implement me")
}
func (m mockConn) SetDeadline(t time.Time) error {
//TODO implement me
panic("implement me")
}
func (m mockConn) SetReadDeadline(t time.Time) error {
//TODO implement me
panic("implement me")
}
func (m mockConn) SetWriteDeadline(t time.Time) error {
//TODO implement me
panic("implement me")
}
func TestStore_DeletePeer(t *testing.T) {
s := NewStore()
@@ -71,9 +27,8 @@ func TestStore_DeleteDeprecatedPeer(t *testing.T) {
m, _ := metrics.NewMetrics(context.Background(), otel.Meter(""))
conn := &mockConn{}
p1 := NewPeer(m, []byte("peer_id"), conn, nil)
p2 := NewPeer(m, []byte("peer_id"), conn, nil)
p1 := NewPeer(m, []byte("peer_id"), nil, nil)
p2 := NewPeer(m, []byte("peer_id"), nil, nil)
s.AddPeer(p1)
s.AddPeer(p2)

View File

@@ -300,13 +300,11 @@ install_netbird() {
echo "package_manager=$PACKAGE_MANAGER" | ${SUDO} tee "$CONFIG_FILE" > /dev/null
# Load and start netbird service
if [ "$PACKAGE_MANAGER" != "rpm-ostree" ]; then
if ! ${SUDO} netbird service install 2>&1; then
echo "NetBird service has already been loaded"
fi
if ! ${SUDO} netbird service start 2>&1; then
echo "NetBird service has already been started"
fi
if ! ${SUDO} netbird service install 2>&1; then
echo "NetBird service has already been loaded"
fi
if ! ${SUDO} netbird service start 2>&1; then
echo "NetBird service has already been started"
fi

View File

@@ -29,9 +29,12 @@ import (
"google.golang.org/grpc/keepalive"
)
const (
metricsPort = 9090
)
var (
signalPort int
metricsPort int
signalLetsencryptDomain string
signalSSLDir string
defaultSignalSSLDir string
@@ -285,7 +288,6 @@ func loadTLSConfig(certFile string, certKey string) (*tls.Config, error) {
func init() {
runCmd.PersistentFlags().IntVar(&signalPort, "port", 80, "Server port to listen on (defaults to 443 if TLS is enabled, 80 otherwise")
runCmd.Flags().IntVar(&metricsPort, "metrics-port", 9090, "metrics endpoint http port. Metrics are accessible under host:metrics-port/metrics")
runCmd.Flags().StringVar(&signalSSLDir, "ssl-dir", defaultSignalSSLDir, "server ssl directory location. *Required only for Let's Encrypt certificates.")
runCmd.Flags().StringVar(&signalLetsencryptDomain, "letsencrypt-domain", "", "a domain to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS")
runCmd.Flags().StringVar(&signalCertFile, "cert-file", "", "Location of your SSL certificate. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect")

View File

@@ -82,11 +82,8 @@ func (registry *Registry) Register(peer *Peer) {
log.Warnf("peer [%s] is already registered [new streamID %d, previous StreamID %d]. Will override stream.",
peer.Id, peer.StreamID, pp.StreamID)
registry.Peers.Store(peer.Id, peer)
return
}
log.Debugf("peer registered [%s]", peer.Id)
registry.metrics.ActivePeers.Add(context.Background(), 1)
// record time as milliseconds
registry.metrics.RegistrationDelay.Record(context.Background(), float64(time.Since(start).Nanoseconds())/1e6)
@@ -108,8 +105,8 @@ func (registry *Registry) Deregister(peer *Peer) {
peer.Id, pp.StreamID, peer.StreamID)
return
}
registry.metrics.ActivePeers.Add(context.Background(), -1)
log.Debugf("peer deregistered [%s]", peer.Id)
registry.metrics.Deregistrations.Add(context.Background(), 1)
}
log.Debugf("peer deregistered [%s]", peer.Id)
registry.metrics.Deregistrations.Add(context.Background(), 1)
}

View File

@@ -133,6 +133,8 @@ func (s *Server) RegisterPeer(stream proto.SignalExchange_ConnectStreamServer) (
s.registry.Register(p)
s.dispatcher.ListenForMessages(stream.Context(), p.Id, s.forwardMessageToPeer)
s.metrics.ActivePeers.Add(stream.Context(), 1)
return p, nil
} else {
s.metrics.RegistrationFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelError, labelErrorMissingId)))
@@ -149,6 +151,7 @@ func (s *Server) DeregisterPeer(p *peer.Peer) {
s.registry.Deregister(p)
s.metrics.PeerConnectionDuration.Record(p.Stream.Context(), int64(time.Since(p.RegisteredAt).Seconds()))
s.metrics.ActivePeers.Add(context.Background(), -1)
}
func (s *Server) forwardMessageToPeer(ctx context.Context, msg *proto.EncryptedMessage) {

View File

@@ -10,30 +10,51 @@ import (
log "github.com/sirupsen/logrus"
)
// WriteJsonWithRestrictedPermission writes JSON config object to a file. Enforces permission on the parent directory
func WriteJsonWithRestrictedPermission(file string, obj interface{}) error {
configDir, configFileName, err := prepareConfigFileDir(file)
if err != nil {
return err
}
err = EnforcePermission(file)
if err != nil {
return err
}
return writeJson(file, obj, configDir, configFileName)
}
// WriteJson writes JSON config object to a file creating parent directories if required
// The output JSON is pretty-formatted
func WriteJson(file string, obj interface{}) error {
configDir, configFileName, err := prepareConfigFileDir(file)
if err != nil {
return err
}
return writeJson(file, obj, configDir, configFileName)
// make it pretty
bs, err := json.MarshalIndent(obj, "", " ")
if err != nil {
return err
}
tempFile, err := os.CreateTemp(configDir, ".*"+configFileName)
if err != nil {
return err
}
tempFileName := tempFile.Name()
// closing file ops as windows doesn't allow to move it
err = tempFile.Close()
if err != nil {
return err
}
defer func() {
_, err = os.Stat(tempFileName)
if err == nil {
os.Remove(tempFileName)
}
}()
err = os.WriteFile(tempFileName, bs, 0600)
if err != nil {
return err
}
err = os.Rename(tempFileName, file)
if err != nil {
return err
}
return nil
}
// DirectWriteJson writes JSON config object to a file creating parent directories if required without creating a temporary file
@@ -75,46 +96,6 @@ func DirectWriteJson(ctx context.Context, file string, obj interface{}) error {
return nil
}
func writeJson(file string, obj interface{}, configDir string, configFileName string) error {
// make it pretty
bs, err := json.MarshalIndent(obj, "", " ")
if err != nil {
return err
}
tempFile, err := os.CreateTemp(configDir, ".*"+configFileName)
if err != nil {
return err
}
tempFileName := tempFile.Name()
// closing file ops as windows doesn't allow to move it
err = tempFile.Close()
if err != nil {
return err
}
defer func() {
_, err = os.Stat(tempFileName)
if err == nil {
os.Remove(tempFileName)
}
}()
err = os.WriteFile(tempFileName, bs, 0600)
if err != nil {
return err
}
err = os.Rename(tempFileName, file)
if err != nil {
return err
}
return nil
}
func openOrCreateFile(file string) (*os.File, error) {
s, err := os.Stat(file)
if err == nil {
@@ -191,9 +172,5 @@ func prepareConfigFileDir(file string) (string, string, error) {
}
err := os.MkdirAll(configDir, 0750)
if err != nil {
return "", "", err
}
return configDir, configFileName, err
}

View File

@@ -5,7 +5,6 @@ import (
"os"
"path/filepath"
"slices"
"strconv"
log "github.com/sirupsen/logrus"
"gopkg.in/natefinch/lumberjack.v2"
@@ -13,8 +12,6 @@ import (
"github.com/netbirdio/netbird/formatter"
)
const defaultLogSize = 5
// InitLog parses and sets log-level input
func InitLog(logLevel string, logPath string) error {
level, err := log.ParseLevel(logLevel)
@@ -22,14 +19,13 @@ func InitLog(logLevel string, logPath string) error {
log.Errorf("Failed parsing log-level %s: %s", logLevel, err)
return err
}
customOutputs := []string{"console", "syslog"}
customOutputs := []string{"console", "syslog"};
if logPath != "" && !slices.Contains(customOutputs, logPath) {
maxLogSize := getLogMaxSize()
lumberjackLogger := &lumberjack.Logger{
// Log file absolute path, os agnostic
Filename: filepath.ToSlash(logPath),
MaxSize: maxLogSize, // MB
MaxSize: 5, // MB
MaxBackups: 10,
MaxAge: 30, // days
Compress: true,
@@ -50,18 +46,3 @@ func InitLog(logLevel string, logPath string) error {
log.SetLevel(level)
return nil
}
func getLogMaxSize() int {
if sizeVar, ok := os.LookupEnv("NB_LOG_MAX_SIZE_MB"); ok {
size, err := strconv.ParseInt(sizeVar, 10, 64)
if err != nil {
log.Errorf("Failed parsing log-size %s: %s. Should be just an integer", sizeVar, err)
return defaultLogSize
}
log.Infof("Setting log file max size to %d MB", size)
return int(size)
}
return defaultLogSize
}

View File

@@ -1,7 +0,0 @@
//go:build !windows
package util
func EnforcePermission(dirPath string) error {
return nil
}

View File

@@ -1,86 +0,0 @@
package util
import (
"path/filepath"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows"
)
const (
securityFlags = windows.OWNER_SECURITY_INFORMATION |
windows.GROUP_SECURITY_INFORMATION |
windows.DACL_SECURITY_INFORMATION |
windows.PROTECTED_DACL_SECURITY_INFORMATION
)
func EnforcePermission(file string) error {
dirPath := filepath.Dir(file)
user, group, err := sids()
if err != nil {
return err
}
adminGroupSid, err := windows.CreateWellKnownSid(windows.WinBuiltinAdministratorsSid)
if err != nil {
return err
}
explicitAccess := []windows.EXPLICIT_ACCESS{
{
AccessPermissions: windows.GENERIC_ALL,
AccessMode: windows.SET_ACCESS,
Inheritance: windows.SUB_CONTAINERS_AND_OBJECTS_INHERIT,
Trustee: windows.TRUSTEE{
MultipleTrusteeOperation: windows.NO_MULTIPLE_TRUSTEE,
TrusteeForm: windows.TRUSTEE_IS_SID,
TrusteeType: windows.TRUSTEE_IS_USER,
TrusteeValue: windows.TrusteeValueFromSID(user),
},
},
{
AccessPermissions: windows.GENERIC_ALL,
AccessMode: windows.SET_ACCESS,
Inheritance: windows.SUB_CONTAINERS_AND_OBJECTS_INHERIT,
Trustee: windows.TRUSTEE{
MultipleTrusteeOperation: windows.NO_MULTIPLE_TRUSTEE,
TrusteeForm: windows.TRUSTEE_IS_SID,
TrusteeType: windows.TRUSTEE_IS_WELL_KNOWN_GROUP,
TrusteeValue: windows.TrusteeValueFromSID(adminGroupSid),
},
},
}
dacl, err := windows.ACLFromEntries(explicitAccess, nil)
if err != nil {
return err
}
return windows.SetNamedSecurityInfo(dirPath, windows.SE_FILE_OBJECT, securityFlags, user, group, dacl, nil)
}
func sids() (*windows.SID, *windows.SID, error) {
var token windows.Token
err := windows.OpenProcessToken(windows.CurrentProcess(), windows.TOKEN_QUERY, &token)
if err != nil {
return nil, nil, err
}
defer func() {
if err := token.Close(); err != nil {
log.Errorf("failed to close process token: %v", err)
}
}()
tu, err := token.GetTokenUser()
if err != nil {
return nil, nil, err
}
pg, err := token.GetTokenPrimaryGroup()
if err != nil {
return nil, nil, err
}
return tu.User.Sid, pg.PrimaryGroup, nil
}