Compare commits

...

41 Commits

Author SHA1 Message Date
bcmmbaga
bf4767211a Merge branch 'refs/heads/feature/optimize_sqlite_save' into deploy/posture-check-sqlite 2024-04-18 11:05:06 +03:00
Misha Bragin
515ce9e3af Update management/server/sqlite_store.go 2024-04-17 20:55:32 +02:00
Misha Bragin
89383b7f01 Update management/server/sqlite_store.go 2024-04-17 20:55:01 +02:00
Misha Bragin
db34162733 Update management/server/sqlite_store.go 2024-04-17 20:54:14 +02:00
Misha Bragin
bd761e2177 Update management/server/sqlite_store.go 2024-04-17 20:53:32 +02:00
Misha Bragin
4e1b95a4c6 Update management/server/sqlite_store.go 2024-04-17 20:53:24 +02:00
Misha Bragin
05993af7bf Update management/server/sqlite_store.go 2024-04-17 20:53:11 +02:00
braginini
9d1cb00570 Fix setup keys test 2024-04-17 20:27:55 +02:00
braginini
543731df45 Fix setup keys test 2024-04-17 19:58:24 +02:00
braginini
e6628ec231 Fix setup keys 2024-04-17 19:48:09 +02:00
braginini
41d4dd2aff reduce log level of scheduler to trace 2024-04-17 19:34:59 +02:00
braginini
30bed57711 Fix account deletion 2024-04-17 19:12:53 +02:00
braginini
6960b68322 Add pats to test save account 2024-04-17 19:07:17 +02:00
braginini
3b3aa18148 Store setup keys and ns groups in a batch 2024-04-17 18:32:13 +02:00
braginini
93045f3e3a Fix rand lint issue 2024-04-17 18:07:02 +02:00
braginini
fd3c1dea8e Add save large account test 2024-04-17 18:02:10 +02:00
braginini
48aff7a26e Fix test compilation errors 2024-04-17 17:39:28 +02:00
braginini
83dfe8e3a3 Fix test compilation errors 2024-04-17 17:27:23 +02:00
braginini
38e10af2d9 Add accountID reference 2024-04-17 17:16:56 +02:00
braginini
99854a126a Add comments 2024-04-17 17:08:01 +02:00
braginini
a75f982fcd Copy account when storing to avoid reference issues 2024-04-17 17:03:21 +02:00
bcmmbaga
7745ed7eb0 Merge branch 'refs/heads/main' into add-process-posture-check 2024-04-17 16:37:29 +03:00
braginini
e7a6483912 Optimize all other objects storing in SQLite 2024-04-17 12:35:41 +02:00
braginini
30ede299b8 Optimize peer storing in SQLite 2024-04-17 11:50:33 +02:00
bcmmbaga
6bfd1b2886 fix merge conflicts 2024-04-15 16:18:41 +03:00
bcmmbaga
8aa32a2da5 Merge branch 'refs/heads/main' into add-process-posture-check
# Conflicts:
#	management/server/peer.go
2024-04-15 16:14:21 +03:00
Bethuel Mmbaga
c6ab215d9d Extend management to sync meta and posture checks with peer (#1727)
* Add method to retrieve peer's applied posture checks

* Add posture checks in server response and update proto messages

* Refactor

* Extends peer metadata synchronization through SyncRequest and propagate posture changes on syncResponse

* Remove account lock

* Pass system info on sync

* Fix tests

* Refactor

* resolve merge

* Evaluate process check on client (#1749)

* implement  server and client sync peer meta alongside mocks

* wip: add check file and process

* Add files to peer metadata for process check

* wip: update peer meta on first sync

* Add files to peer's metadata

* Evaluate process check using files from peer metadata

* Fix panic and append windows path to files

* Fix check network address and files equality

* Evaluate active process on darwin

* Evaluate active process on linux

* Skip processing processes if no paths are set

* Return network map on peer meta-sync and update account peer's

* Update client network map on meta sync

* Get system info with applied checks

* Add windows package

* Remove a network map from sync meta-response

* Update checks proto message

* Keep client checks state and sync meta on checks change

* Evaluate a running process

* skip build for android and ios

* skip check file and process for android and ios

* bump gopsutil version

* fix tests

* move process check to separate os file

* refactor

* evaluate info with checks on receiving management events

* skip meta-update for an old client with no meta-sync support

* Check if peer meta is empty without reflection
2024-04-15 16:00:57 +03:00
bcmmbaga
36582d13aa Merge branch 'refs/heads/main' into add-process-posture-check 2024-04-10 17:58:46 +03:00
bcmmbaga
2727680123 Merge branch 'main' into add-process-posture-check 2024-03-21 21:30:40 +03:00
bcmmbaga
9dcaa51b68 Merge branch 'main' into add-process-posture-check 2024-03-18 18:41:38 +03:00
Bethuel Mmbaga
180f5a122e Refactor posture check validations (#1705)
* Add posture checks validation

* Refactor code to incorporate posture checks validation directly into management.

* Add posture checks validation for geolocation, OS version, network, process, and NB-version

* Fix tests
2024-03-14 20:16:50 +00:00
bcmmbaga
90ab2f7c89 Fix linters 2024-03-14 16:06:50 +03:00
bcmmbaga
4ab993c933 Fix tests 2024-03-14 15:52:15 +03:00
bcmmbaga
1a5d59be1d Refactor 2024-03-14 14:35:21 +03:00
bcmmbaga
9db450d599 Add single Unix/Windows path check in process tests 2024-03-14 14:32:55 +03:00
bcmmbaga
cc60df7805 Allow set of single unix or windows path check 2024-03-14 14:32:40 +03:00
bcmmbaga
60f9f08ecb fix tests 2024-03-13 11:02:47 +03:00
bcmmbaga
41348bb39b Add process validation for peer metadata 2024-03-12 19:24:08 +03:00
bcmmbaga
e66e39cc70 Extend peer metadata with processes 2024-03-12 19:23:57 +03:00
bcmmbaga
9f41a1f20f add process posture check to posture checks handlers 2024-03-12 15:20:00 +03:00
bcmmbaga
5f0eec0add wip: add process check posture 2024-03-12 15:19:22 +03:00
57 changed files with 2574 additions and 867 deletions

View File

@@ -237,7 +237,10 @@ func runClient(
return wrapErr(err)
}
engine := NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, engineConfig, mobileDependency, statusRecorder, mgmProbe, signalProbe, relayProbe, wgProbe)
checks := loginResp.GetChecks()
engine := NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, engineConfig,
mobileDependency, statusRecorder, mgmProbe, signalProbe, relayProbe, wgProbe, checks)
err = engine.Start()
if err != nil {
log.Errorf("error while starting Netbird Connection Engine: %s", err)

View File

@@ -8,6 +8,7 @@ import (
"net/netip"
"reflect"
"runtime"
"slices"
"strings"
"sync"
"time"
@@ -27,6 +28,7 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/internal/wgproxy"
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/iface/bind"
@@ -138,6 +140,9 @@ type Engine struct {
signalProbe *Probe
relayProbe *Probe
wgProbe *Probe
// checks are the client-applied posture checks that need to be evaluated on the client
checks []*mgmProto.Checks
}
// Peer is an instance of the Connection Peer
@@ -155,6 +160,7 @@ func NewEngine(
config *EngineConfig,
mobileDep MobileDependency,
statusRecorder *peer.Status,
checks []*mgmProto.Checks,
) *Engine {
return NewEngineWithProbes(
ctx,
@@ -168,6 +174,7 @@ func NewEngine(
nil,
nil,
nil,
checks,
)
}
@@ -184,6 +191,7 @@ func NewEngineWithProbes(
signalProbe *Probe,
relayProbe *Probe,
wgProbe *Probe,
checks []*mgmProto.Checks,
) *Engine {
return &Engine{
ctx: ctx,
@@ -204,6 +212,7 @@ func NewEngineWithProbes(
signalProbe: signalProbe,
relayProbe: relayProbe,
wgProbe: wgProbe,
checks: checks,
}
}
@@ -486,6 +495,10 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
// todo update signal
}
if err := e.updateChecksIfNew(update.Checks); err != nil {
return err
}
if update.GetNetworkMap() != nil {
// only apply new changes and ignore old ones
err := e.updateNetworkMap(update.GetNetworkMap())
@@ -493,7 +506,27 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
return err
}
}
return nil
}
// updateChecksIfNew updates checks if there are changes and sync new meta with management
func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
// if checks are equal, we skip the update
if isChecksEqual(e.checks, checks) {
return nil
}
e.checks = checks
info, err := system.GetInfoWithChecks(e.ctx, checks)
if err != nil {
log.Warnf("failed to get system info with checks: %v", err)
info = system.GetInfo(e.ctx)
}
if err := e.mgmClient.SyncMeta(info); err != nil {
log.Errorf("could not sync meta: error %s", err)
return err
}
return nil
}
@@ -583,7 +616,13 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
// E.g. when a new peer has been registered and we are allowed to connect to it.
func (e *Engine) receiveManagementEvents() {
go func() {
err := e.mgmClient.Sync(e.handleSync)
info, err := system.GetInfoWithChecks(e.ctx, e.checks)
if err != nil {
log.Warnf("failed to get system info with checks: %v", err)
info = system.GetInfo(e.ctx)
}
err = e.mgmClient.Sync(info, e.handleSync)
if err != nil {
// happens if management is unavailable for a long time.
// We want to cancel the operation of the whole client
@@ -1151,7 +1190,8 @@ func (e *Engine) close() {
}
func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, error) {
netMap, err := e.mgmClient.GetNetworkMap()
info := system.GetInfo(e.ctx)
netMap, err := e.mgmClient.GetNetworkMap(info)
if err != nil {
return nil, nil, err
}
@@ -1329,3 +1369,10 @@ func (e *Engine) probeSTUNs() []relay.ProbeResult {
func (e *Engine) probeTURNs() []relay.ProbeResult {
return relay.ProbeAll(e.ctx, relay.ProbeTURN, e.TURNs)
}
// isChecksEqual checks if two slices of checks are equal.
func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool {
return slices.EqualFunc(checks, oChecks, func(checks, oChecks *mgmProto.Checks) bool {
return slices.Equal(checks.Files, oChecks.Files)
})
}

View File

@@ -76,7 +76,7 @@ func TestEngine_SSH(t *testing.T) {
WgPrivateKey: key,
WgPort: 33100,
ServerSSHAllowed: true,
}, MobileDependency{}, peer.NewRecorder("https://mgm"))
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
engine.dnsServer = &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
@@ -210,7 +210,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
WgAddr: "100.64.0.1/24",
WgPrivateKey: key,
WgPort: 33100,
}, MobileDependency{}, peer.NewRecorder("https://mgm"))
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
@@ -391,7 +391,7 @@ func TestEngine_Sync(t *testing.T) {
// feed updates to Engine via mocked Management client
updates := make(chan *mgmtProto.SyncResponse)
defer close(updates)
syncFunc := func(msgHandler func(msg *mgmtProto.SyncResponse) error) error {
syncFunc := func(info *system.Info, msgHandler func(msg *mgmtProto.SyncResponse) error) error {
for msg := range updates {
err := msgHandler(msg)
if err != nil {
@@ -406,7 +406,7 @@ func TestEngine_Sync(t *testing.T) {
WgAddr: "100.64.0.1/24",
WgPrivateKey: key,
WgPort: 33100,
}, MobileDependency{}, peer.NewRecorder("https://mgm"))
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
engine.dnsServer = &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
@@ -564,7 +564,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
WgAddr: wgAddr,
WgPrivateKey: key,
WgPort: 33100,
}, MobileDependency{}, peer.NewRecorder("https://mgm"))
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
@@ -733,7 +733,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
WgAddr: wgAddr,
WgPrivateKey: key,
WgPort: 33100,
}, MobileDependency{}, peer.NewRecorder("https://mgm"))
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
@@ -1002,7 +1002,7 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
WgPort: wgPort,
}
return NewEngine(ctx, cancel, signalClient, mgmtClient, conf, MobileDependency{}, peer.NewRecorder("https://mgm")), nil
return NewEngine(ctx, cancel, signalClient, mgmtClient, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil), nil
}
func startSignal() (*grpc.Server, string, error) {

View File

@@ -8,6 +8,7 @@ import (
"google.golang.org/grpc/metadata"
"github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/version"
)
@@ -30,6 +31,12 @@ type Environment struct {
Platform string
}
type File struct {
Path string
Exist bool
ProcessIsRunning bool
}
// Info is an object that contains machine information
// Most of the code is taken from https://github.com/matishsiao/goInfo
type Info struct {
@@ -48,6 +55,7 @@ type Info struct {
SystemProductName string
SystemManufacturer string
Environment Environment
Files []File
}
// extractUserAgent extracts Netbird's agent (client) name and version from the outgoing context
@@ -129,3 +137,21 @@ func isDuplicated(addresses []NetworkAddress, addr NetworkAddress) bool {
}
return false
}
// GetInfoWithChecks retrieves and parses the system information with applied checks.
func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks) (*Info, error) {
processCheckPaths := make([]string, 0)
for _, check := range checks {
processCheckPaths = append(processCheckPaths, check.GetFiles()...)
}
files, err := checkFileAndProcess(processCheckPaths)
if err != nil {
return nil, err
}
info := GetInfo(ctx)
info.Files = files
return info, nil
}

View File

@@ -36,6 +36,11 @@ func GetInfo(ctx context.Context) *Info {
return gio
}
// checkFileAndProcess checks if the file path exists and if a process is running at that path.
func checkFileAndProcess(paths []string) ([]File, error) {
return []File{}, nil
}
func uname() []string {
res := run("/system/bin/uname", "-a")
return strings.Split(res, " ")

View File

@@ -25,6 +25,11 @@ func GetInfo(ctx context.Context) *Info {
return gio
}
// checkFileAndProcess checks if the file path exists and if a process is running at that path.
func checkFileAndProcess(paths []string) ([]File, error) {
return []File{}, nil
}
// extractOsVersion extracts operating system version from context or returns the default
func extractOsVersion(ctx context.Context, defaultName string) string {
v, ok := ctx.Value(OsVersionCtxKey).(string)

58
client/system/process.go Normal file
View File

@@ -0,0 +1,58 @@
//go:build windows || (linux && !android) || (darwin && !ios)
package system
import (
"os"
"slices"
"github.com/shirou/gopsutil/v3/process"
)
// getRunningProcesses returns a list of running process paths.
func getRunningProcesses() ([]string, error) {
processes, err := process.Processes()
if err != nil {
return nil, err
}
processMap := make(map[string]bool)
for _, p := range processes {
path, _ := p.Exe()
if path != "" {
processMap[path] = true
}
}
uniqueProcesses := make([]string, 0, len(processMap))
for p := range processMap {
uniqueProcesses = append(uniqueProcesses, p)
}
return uniqueProcesses, nil
}
// checkFileAndProcess checks if the file path exists and if a process is running at that path.
func checkFileAndProcess(paths []string) ([]File, error) {
files := make([]File, len(paths))
if len(paths) == 0 {
return files, nil
}
runningProcesses, err := getRunningProcesses()
if err != nil {
return nil, err
}
for i, path := range paths {
file := File{Path: path}
_, err := os.Stat(path)
file.Exist = !os.IsNotExist(err)
file.ProcessIsRunning = slices.Contains(runningProcesses, path)
files[i] = file
}
return files, nil
}

14
go.mod
View File

@@ -22,7 +22,7 @@ require (
github.com/spf13/pflag v1.0.5
github.com/vishvananda/netlink v1.1.1-0.20211118161826-650dca95af54
golang.org/x/crypto v0.18.0
golang.org/x/sys v0.16.0
golang.org/x/sys v0.18.0
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
golang.zx2c4.com/wireguard/windows v0.5.3
@@ -44,7 +44,7 @@ require (
github.com/gliderlabs/ssh v0.3.4
github.com/godbus/dbus/v5 v5.1.0
github.com/golang/mock v1.6.0
github.com/google/go-cmp v0.5.9
github.com/google/go-cmp v0.6.0
github.com/google/gopacket v1.1.19
github.com/google/martian/v3 v3.0.0
github.com/google/nftables v0.0.0-20220808154552-2eca00135732
@@ -70,10 +70,11 @@ require (
github.com/pion/turn/v3 v3.0.1
github.com/prometheus/client_golang v1.14.0
github.com/rs/xid v1.3.0
github.com/shirou/gopsutil/v3 v3.24.3
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
github.com/stretchr/testify v1.8.4
github.com/stretchr/testify v1.9.0
github.com/things-go/go-socks5 v0.0.4
github.com/yusufpapurcu/wmi v1.2.3
github.com/yusufpapurcu/wmi v1.2.4
github.com/zcalusic/sysinfo v1.0.2
go.opentelemetry.io/otel v1.11.1
go.opentelemetry.io/otel/exporters/prometheus v0.33.0
@@ -131,6 +132,7 @@ require (
github.com/jinzhu/now v1.1.5 // indirect
github.com/josharian/native v1.1.0 // indirect
github.com/kelseyhightower/envconfig v1.4.0 // indirect
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect
github.com/mdlayher/genetlink v1.3.2 // indirect
github.com/mdlayher/netlink v1.7.2 // indirect
@@ -142,12 +144,16 @@ require (
github.com/pion/randutil v0.1.0 // indirect
github.com/pion/transport/v2 v2.2.4 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
github.com/prometheus/client_model v0.3.0 // indirect
github.com/prometheus/common v0.37.0 // indirect
github.com/prometheus/procfs v0.8.0 // indirect
github.com/shoenig/go-m1cpu v0.1.6 // indirect
github.com/spf13/cast v1.5.0 // indirect
github.com/srwiley/oksvg v0.0.0-20200311192757-870daf9aa564 // indirect
github.com/srwiley/rasterx v0.0.0-20200120212402-85cb7272f5e9 // indirect
github.com/tklauser/go-sysconf v0.3.13 // indirect
github.com/tklauser/numcpus v0.7.0 // indirect
github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74 // indirect
github.com/yuin/goldmark v1.4.13 // indirect
go.opencensus.io v0.24.0 // indirect

32
go.sum
View File

@@ -247,9 +247,11 @@ github.com/google/go-cmp v0.5.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/gofuzz v0.0.0-20161122191042-44d81051d367/go.mod h1:HP5RmnzzSNb993RKQDq4+1A4ia9nllfqcQFTQJedwGI=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8=
@@ -348,6 +350,8 @@ github.com/leodido/go-urn v1.1.0/go.mod h1:+cyI34gQWZcE1eQU7NVgKkkzdXDQHr1dBMtdA
github.com/libp2p/go-netroute v0.2.1 h1:V8kVrpD8GK0Riv15/7VN6RbUQ3URNZVosw7H2v9tksU=
github.com/libp2p/go-netroute v0.2.1/go.mod h1:hraioZr0fhBjG0ZRXJJ6Zj2IVEVNx6tDTFQfSmcq7mQ=
github.com/lucor/goinfo v0.0.0-20210802170112-c078a2b0f08b/go.mod h1:PRq09yoB+Q2OJReAmwzKivcYyremnibWGbK7WfftHzc=
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ81pIr0yLvtUWk2if982qA3F3QD6H4=
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
github.com/magiconair/properties v1.8.5 h1:b6kJs+EmPFMYGkow9GiUyCyOvIwYetYJ3fSaWak/Gls=
github.com/magiconair/properties v1.8.5/go.mod h1:y3VJvCyxH9uVvJTWEGAELF3aiYNyPKd5NZ3oSwXrF60=
github.com/mailru/easyjson v0.0.0-20160728113105-d5b7844b561a/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc=
@@ -449,6 +453,8 @@ github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINE
github.com/pmezard/go-difflib v0.0.0-20151028094244-d8ed2627bdf0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw=
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw=
github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo=
github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M=
@@ -485,6 +491,12 @@ github.com/rs/xid v1.3.0 h1:6NjYksEUlhurdVehpc7S7dk6DAmcKv8V9gG0FsVN2U4=
github.com/rs/xid v1.3.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/shirou/gopsutil/v3 v3.24.3 h1:eoUGJSmdfLzJ3mxIhmOAhgKEKgQkeOwKpz1NbhVnuPE=
github.com/shirou/gopsutil/v3 v3.24.3/go.mod h1:JpND7O217xa72ewWz9zN2eIIkPWsDN/3pl0H8Qt0uwg=
github.com/shoenig/go-m1cpu v0.1.6 h1:nxdKQNcEB6vzgA2E2bvzKIYRuNj7XNJ4S/aRSwKzFtM=
github.com/shoenig/go-m1cpu v0.1.6/go.mod h1:1JJMcUBvfNwpq05QDQVAnx3gUHr9IYF7GNg9SUEw2VQ=
github.com/shoenig/test v0.6.4 h1:kVTaSd7WLz5WZ2IaoM0RSzRsUD+m8wRR+5qvntpn4LU=
github.com/shoenig/test v0.6.4/go.mod h1:byHiCGXqrVaflBLAMq/srcZIHynQPQgeyvkvXnjqq0k=
github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc=
github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
@@ -514,6 +526,7 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/stretchr/testify v0.0.0-20151208002404-e3a8ff8ce365/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
@@ -524,10 +537,17 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/things-go/go-socks5 v0.0.4 h1:jMQjIc+qhD4z9cITOMnBiwo9dDmpGuXmBlkRFrl/qD0=
github.com/things-go/go-socks5 v0.0.4/go.mod h1:sh4K6WHrmHZpjxLTCHyYtXYH8OUuD+yZun41NomR1IQ=
github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI=
github.com/tklauser/go-sysconf v0.3.13 h1:GBUpcahXSpR2xN01jhkNAbTLRk2Yzgggk8IM08lq3r4=
github.com/tklauser/go-sysconf v0.3.13/go.mod h1:zwleP4Q4OehZHGn4CYZDipCgg9usW5IJePewFCGVEa0=
github.com/tklauser/numcpus v0.6.1/go.mod h1:1XfjsgE2zo8GVw7POkMbHENHzVg3GzmoZ9fESEdAacY=
github.com/tklauser/numcpus v0.7.0 h1:yjuerZP127QG9m5Zh/mSO4wqurYil27tHrqwRoRjpr4=
github.com/tklauser/numcpus v0.7.0/go.mod h1:bb6dMVcj8A42tSE7i32fsIUCbQNllK5iDguyOZRUzAY=
github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw=
github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY=
github.com/urfave/cli/v2 v2.3.0/go.mod h1:LJmUH05zAU44vOAcrfzZQKsZbVcdbOG8rtL3/XcUArI=
@@ -544,8 +564,8 @@ github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1
github.com/yuin/goldmark v1.3.8/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
github.com/yuin/goldmark v1.4.13 h1:fVcFKWvrslecOb/tg+Cc05dkeYx540o0FuFt3nUVDoE=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
github.com/yusufpapurcu/wmi v1.2.3 h1:E1ctvB7uKFMOJw3fdOW32DwGE9I7t++CRUEMKvFoFiw=
github.com/yusufpapurcu/wmi v1.2.3/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
github.com/zcalusic/sysinfo v1.0.2 h1:nwTTo2a+WQ0NXwo0BGRojOJvJ/5XKvQih+2RrtWqfxc=
github.com/zcalusic/sysinfo v1.0.2/go.mod h1:kluzTYflRWo6/tXVMJPdEjShsbPpsFRyy+p1mBQPC30=
go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU=
@@ -740,6 +760,7 @@ golang.org/x/sys v0.0.0-20200803210538-64077c9b5642/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201015000850-e3ed0017c211/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210112080510-489259a85091/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210303074136-134d130e1a04/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -761,8 +782,9 @@ golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.16.0 h1:xWw16ngr6ZMtmxDyKyIgsE93KNKz5HKmMa3b8ALHidU=
golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4=
golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=

View File

@@ -3,19 +3,21 @@ package client
import (
"io"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/management/proto"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
type Client interface {
io.Closer
Sync(msgHandler func(msg *proto.SyncResponse) error) error
Sync(sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error
GetServerPublicKey() (*wgtypes.Key, error)
Register(serverKey wgtypes.Key, setupKey string, jwtToken string, sysInfo *system.Info, sshKey []byte) (*proto.LoginResponse, error)
Login(serverKey wgtypes.Key, sysInfo *system.Info, sshKey []byte) (*proto.LoginResponse, error)
GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error)
GetPKCEAuthorizationFlow(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error)
GetNetworkMap() (*proto.NetworkMap, error)
GetNetworkMap(sysInfo *system.Info) (*proto.NetworkMap, error)
IsHealthy() bool
SyncMeta(sysInfo *system.Info) error
}

View File

@@ -255,7 +255,7 @@ func TestClient_Sync(t *testing.T) {
ch := make(chan *mgmtProto.SyncResponse, 1)
go func() {
err = client.Sync(func(msg *mgmtProto.SyncResponse) error {
err = client.Sync(info, func(msg *mgmtProto.SyncResponse) error {
ch <- msg
return nil
})

View File

@@ -113,7 +113,7 @@ func (c *GrpcClient) ready() bool {
// Sync wraps the real client's Sync endpoint call and takes care of retries and encryption/decryption of messages
// Blocking request. The result will be sent via msgHandler callback function
func (c *GrpcClient) Sync(msgHandler func(msg *proto.SyncResponse) error) error {
func (c *GrpcClient) Sync(sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error {
backOff := defaultBackoff(c.ctx)
operation := func() error {
@@ -135,7 +135,7 @@ func (c *GrpcClient) Sync(msgHandler func(msg *proto.SyncResponse) error) error
ctx, cancelStream := context.WithCancel(c.ctx)
defer cancelStream()
stream, err := c.connectToStream(ctx, *serverPubKey)
stream, err := c.connectToStream(ctx, *serverPubKey, sysInfo)
if err != nil {
log.Debugf("failed to open Management Service stream: %s", err)
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.PermissionDenied {
@@ -177,7 +177,7 @@ func (c *GrpcClient) Sync(msgHandler func(msg *proto.SyncResponse) error) error
}
// GetNetworkMap return with the network map
func (c *GrpcClient) GetNetworkMap() (*proto.NetworkMap, error) {
func (c *GrpcClient) GetNetworkMap(sysInfo *system.Info) (*proto.NetworkMap, error) {
serverPubKey, err := c.GetServerPublicKey()
if err != nil {
log.Debugf("failed getting Management Service public key: %s", err)
@@ -186,7 +186,7 @@ func (c *GrpcClient) GetNetworkMap() (*proto.NetworkMap, error) {
ctx, cancelStream := context.WithCancel(c.ctx)
defer cancelStream()
stream, err := c.connectToStream(ctx, *serverPubKey)
stream, err := c.connectToStream(ctx, *serverPubKey, sysInfo)
if err != nil {
log.Debugf("failed to open Management Service stream: %s", err)
return nil, err
@@ -219,8 +219,8 @@ func (c *GrpcClient) GetNetworkMap() (*proto.NetworkMap, error) {
return decryptedResp.GetNetworkMap(), nil
}
func (c *GrpcClient) connectToStream(ctx context.Context, serverPubKey wgtypes.Key) (proto.ManagementService_SyncClient, error) {
req := &proto.SyncRequest{}
func (c *GrpcClient) connectToStream(ctx context.Context, serverPubKey wgtypes.Key, sysInfo *system.Info) (proto.ManagementService_SyncClient, error) {
req := &proto.SyncRequest{Meta: infoToMetaData(sysInfo)}
myPrivateKey := c.key
myPublicKey := myPrivateKey.PublicKey()
@@ -430,6 +430,35 @@ func (c *GrpcClient) GetPKCEAuthorizationFlow(serverKey wgtypes.Key) (*proto.PKC
return flowInfoResp, nil
}
// SyncMeta sends updated system metadata to the Management Service.
// It should be used if there is changes on peer posture check after initial sync.
func (c *GrpcClient) SyncMeta(sysInfo *system.Info) error {
if !c.ready() {
return fmt.Errorf("no connection to management")
}
serverPubKey, err := c.GetServerPublicKey()
if err != nil {
log.Debugf("failed getting Management Service public key: %s", err)
return err
}
syncMetaReq, err := encryption.EncryptMessage(*serverPubKey, c.key, &proto.SyncMetaRequest{Meta: infoToMetaData(sysInfo)})
if err != nil {
log.Errorf("failed to encrypt message: %s", err)
return err
}
mgmCtx, cancel := context.WithTimeout(c.ctx, ConnectTimeout)
defer cancel()
_, err = c.realClient.SyncMeta(mgmCtx, &proto.EncryptedMessage{
WgPubKey: c.key.PublicKey().String(),
Body: syncMetaReq,
})
return err
}
func (c *GrpcClient) notifyDisconnected(err error) {
c.connStateCallbackLock.RLock()
defer c.connStateCallbackLock.RUnlock()
@@ -463,6 +492,15 @@ func infoToMetaData(info *system.Info) *proto.PeerSystemMeta {
})
}
files := make([]*proto.File, 0, len(info.Files))
for _, file := range info.Files {
files = append(files, &proto.File{
Path: file.Path,
Exist: file.Exist,
ProcessIsRunning: file.ProcessIsRunning,
})
}
return &proto.PeerSystemMeta{
Hostname: info.Hostname,
GoOS: info.GoOS,
@@ -482,5 +520,6 @@ func infoToMetaData(info *system.Info) *proto.PeerSystemMeta {
Cloud: info.Environment.Cloud,
Platform: info.Environment.Platform,
},
Files: files,
}
}

View File

@@ -9,12 +9,13 @@ import (
type MockClient struct {
CloseFunc func() error
SyncFunc func(msgHandler func(msg *proto.SyncResponse) error) error
SyncFunc func(sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error
GetServerPublicKeyFunc func() (*wgtypes.Key, error)
RegisterFunc func(serverKey wgtypes.Key, setupKey string, jwtToken string, info *system.Info, sshKey []byte) (*proto.LoginResponse, error)
LoginFunc func(serverKey wgtypes.Key, info *system.Info, sshKey []byte) (*proto.LoginResponse, error)
GetDeviceAuthorizationFlowFunc func(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error)
GetPKCEAuthorizationFlowFunc func(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error)
SyncMetaFunc func(sysInfo *system.Info) error
}
func (m *MockClient) IsHealthy() bool {
@@ -28,11 +29,11 @@ func (m *MockClient) Close() error {
return m.CloseFunc()
}
func (m *MockClient) Sync(msgHandler func(msg *proto.SyncResponse) error) error {
func (m *MockClient) Sync(sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error {
if m.SyncFunc == nil {
return nil
}
return m.SyncFunc(msgHandler)
return m.SyncFunc(sysInfo, msgHandler)
}
func (m *MockClient) GetServerPublicKey() (*wgtypes.Key, error) {
@@ -71,6 +72,13 @@ func (m *MockClient) GetPKCEAuthorizationFlow(serverKey wgtypes.Key) (*proto.PKC
}
// GetNetworkMap mock implementation of GetNetworkMap from mgm.Client interface
func (m *MockClient) GetNetworkMap() (*proto.NetworkMap, error) {
func (m *MockClient) GetNetworkMap(_ *system.Info) (*proto.NetworkMap, error) {
return nil, nil
}
func (m *MockClient) SyncMeta(sysInfo *system.Info) error {
if m.SyncMetaFunc == nil {
return nil
}
return m.SyncMetaFunc(sysInfo)
}

File diff suppressed because it is too large Load Diff

View File

@@ -38,6 +38,12 @@ service ManagementService {
// EncryptedMessage of the request has a body of PKCEAuthorizationFlowRequest.
// EncryptedMessage of the response has a body of PKCEAuthorizationFlow.
rpc GetPKCEAuthorizationFlow(EncryptedMessage) returns (EncryptedMessage) {}
// SyncMeta is used to sync metadata of the peer.
// After sync the peer if there is a change in peer posture check which needs to be evaluated by the client,
// sync meta will evaluate the checks and update the peer meta with the result.
// EncryptedMessage of the request has a body of Empty.
rpc SyncMeta(EncryptedMessage) returns (Empty) {}
}
message EncryptedMessage {
@@ -50,7 +56,10 @@ message EncryptedMessage {
int32 version = 3;
}
message SyncRequest {}
message SyncRequest {
// Meta data of the peer
PeerSystemMeta meta = 1;
}
// SyncResponse represents a state that should be applied to the local peer (e.g. Wiretrustee servers config as well as local peer and remote peers configs)
message SyncResponse {
@@ -69,6 +78,14 @@ message SyncResponse {
bool remotePeersIsEmpty = 4;
NetworkMap NetworkMap = 5;
// Posture checks to be evaluated by client
repeated Checks Checks = 6;
}
message SyncMetaRequest {
// Meta data of the peer
PeerSystemMeta meta = 1;
}
message LoginRequest {
@@ -82,6 +99,7 @@ message LoginRequest {
PeerKeys peerKeys = 4;
}
// PeerKeys is additional peer info like SSH pub key and WireGuard public key.
// This message is sent on Login or register requests, or when a key rotation has to happen.
message PeerKeys {
@@ -100,6 +118,16 @@ message Environment {
string platform = 2;
}
// File represents a file on the system.
message File {
// path is the path to the file.
string path = 1;
// exist indicate whether the file exists.
bool exist = 2;
// processIsRunning indicates whether the file is a running process or not.
bool processIsRunning = 3;
}
// PeerSystemMeta is machine meta data like OS and version.
message PeerSystemMeta {
string hostname = 1;
@@ -117,6 +145,7 @@ message PeerSystemMeta {
string sysProductName = 13;
string sysManufacturer = 14;
Environment environment = 15;
repeated File files = 16;
}
message LoginResponse {
@@ -124,6 +153,8 @@ message LoginResponse {
WiretrusteeConfig wiretrusteeConfig = 1;
// Peer local config
PeerConfig peerConfig = 2;
// Posture checks to be evaluated by client
repeated Checks Checks = 3;
}
message ServerKeyResponse {
@@ -371,3 +402,7 @@ message NetworkAddress {
string netIP = 1;
string mac = 2;
}
message Checks {
repeated string Files= 1;
}

View File

@@ -43,6 +43,11 @@ type ManagementServiceClient interface {
// EncryptedMessage of the request has a body of PKCEAuthorizationFlowRequest.
// EncryptedMessage of the response has a body of PKCEAuthorizationFlow.
GetPKCEAuthorizationFlow(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*EncryptedMessage, error)
// SyncMeta is used to sync metadata of the peer.
// After sync the peer if there is a change in peer posture check which needs to be evaluated by the client,
// sync meta will evaluate the checks and update the peer meta with the result.
// EncryptedMessage of the request has a body of Empty.
SyncMeta(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*Empty, error)
}
type managementServiceClient struct {
@@ -130,6 +135,15 @@ func (c *managementServiceClient) GetPKCEAuthorizationFlow(ctx context.Context,
return out, nil
}
func (c *managementServiceClient) SyncMeta(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*Empty, error) {
out := new(Empty)
err := c.cc.Invoke(ctx, "/management.ManagementService/SyncMeta", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// ManagementServiceServer is the server API for ManagementService service.
// All implementations must embed UnimplementedManagementServiceServer
// for forward compatibility
@@ -159,6 +173,11 @@ type ManagementServiceServer interface {
// EncryptedMessage of the request has a body of PKCEAuthorizationFlowRequest.
// EncryptedMessage of the response has a body of PKCEAuthorizationFlow.
GetPKCEAuthorizationFlow(context.Context, *EncryptedMessage) (*EncryptedMessage, error)
// SyncMeta is used to sync metadata of the peer.
// After sync the peer if there is a change in peer posture check which needs to be evaluated by the client,
// sync meta will evaluate the checks and update the peer meta with the result.
// EncryptedMessage of the request has a body of Empty.
SyncMeta(context.Context, *EncryptedMessage) (*Empty, error)
mustEmbedUnimplementedManagementServiceServer()
}
@@ -184,6 +203,9 @@ func (UnimplementedManagementServiceServer) GetDeviceAuthorizationFlow(context.C
func (UnimplementedManagementServiceServer) GetPKCEAuthorizationFlow(context.Context, *EncryptedMessage) (*EncryptedMessage, error) {
return nil, status.Errorf(codes.Unimplemented, "method GetPKCEAuthorizationFlow not implemented")
}
func (UnimplementedManagementServiceServer) SyncMeta(context.Context, *EncryptedMessage) (*Empty, error) {
return nil, status.Errorf(codes.Unimplemented, "method SyncMeta not implemented")
}
func (UnimplementedManagementServiceServer) mustEmbedUnimplementedManagementServiceServer() {}
// UnsafeManagementServiceServer may be embedded to opt out of forward compatibility for this service.
@@ -308,6 +330,24 @@ func _ManagementService_GetPKCEAuthorizationFlow_Handler(srv interface{}, ctx co
return interceptor(ctx, in, info, handler)
}
func _ManagementService_SyncMeta_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(EncryptedMessage)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(ManagementServiceServer).SyncMeta(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/management.ManagementService/SyncMeta",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(ManagementServiceServer).SyncMeta(ctx, req.(*EncryptedMessage))
}
return interceptor(ctx, in, info, handler)
}
// ManagementService_ServiceDesc is the grpc.ServiceDesc for ManagementService service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
@@ -335,6 +375,10 @@ var ManagementService_ServiceDesc = grpc.ServiceDesc{
MethodName: "GetPKCEAuthorizationFlow",
Handler: _ManagementService_GetPKCEAuthorizationFlow_Handler,
},
{
MethodName: "SyncMeta",
Handler: _ManagementService_SyncMeta_Handler,
},
},
Streams: []grpc.StreamDesc{
{

View File

@@ -114,6 +114,7 @@ type AccountManager interface {
GetDNSSettings(accountID string, userID string) (*DNSSettings, error)
SaveDNSSettings(accountID string, userID string, dnsSettingsToSave *DNSSettings) error
GetPeer(accountID, peerID, userID string) (*nbpeer.Peer, error)
GetPeerAppliedPostureChecks(peerKey string) ([]posture.Checks, error)
UpdateAccountSettings(accountID, userID string, newSettings *Settings) (*Account, error)
LoginPeer(login PeerLogin) (*nbpeer.Peer, *NetworkMap, error) // used by peer gRPC API
SyncPeer(sync PeerSync) (*nbpeer.Peer, *NetworkMap, error) // used by peer gRPC API
@@ -1473,7 +1474,7 @@ func (am *DefaultAccountManager) handleNewUserAccount(domainAcc *Account, claims
// if domain already has a primary account, add regular user
if domainAcc != nil {
account = domainAcc
account.Users[claims.UserId] = NewRegularUser(claims.UserId)
account.Users[claims.UserId] = NewRegularUser(claims.UserId, account.Id)
err = am.Store.SaveAccount(account)
if err != nil {
return nil, err
@@ -1862,9 +1863,10 @@ func (am *DefaultAccountManager) onPeersInvalidated(accountID string) {
func addAllGroup(account *Account) error {
if len(account.Groups) == 0 {
allGroup := &nbgroup.Group{
ID: xid.New().String(),
Name: "All",
Issued: nbgroup.GroupIssuedAPI,
ID: xid.New().String(),
Name: "All",
Issued: nbgroup.GroupIssuedAPI,
AccountID: account.Id,
}
for _, peer := range account.Peers {
allGroup.Peers = append(allGroup.Peers, peer.ID)
@@ -1908,7 +1910,7 @@ func newAccountWithId(accountID, userID, domain string) *Account {
routes := make(map[string]*route.Route)
setupKeys := map[string]*SetupKey{}
nameServersGroups := make(map[string]*nbdns.NameServerGroup)
users[userID] = NewOwnerUser(userID)
users[userID] = NewOwnerUser(userID, accountID)
dnsSettings := DNSSettings{
DisabledManagementGroups: make([]string, 0),
}

View File

@@ -134,7 +134,14 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
return err
}
peer, netMap, err := s.accountManager.SyncPeer(PeerSync{WireGuardPubKey: peerKey.String()})
if syncReq.GetMeta() == nil {
log.Tracef("peer system meta has to be provided on sync. Peer %s, remote addr %s", peerKey.String(), realIP)
}
peer, netMap, err := s.accountManager.SyncPeer(PeerSync{
WireGuardPubKey: peerKey.String(),
Meta: extractPeerMeta(syncReq.GetMeta()),
})
if err != nil {
return mapError(err)
}
@@ -255,14 +262,18 @@ func mapError(err error) error {
return status.Errorf(codes.Internal, "failed handling request")
}
func extractPeerMeta(loginReq *proto.LoginRequest) nbpeer.PeerSystemMeta {
osVersion := loginReq.GetMeta().GetOSVersion()
if osVersion == "" {
osVersion = loginReq.GetMeta().GetCore()
func extractPeerMeta(meta *proto.PeerSystemMeta) nbpeer.PeerSystemMeta {
if meta == nil {
return nbpeer.PeerSystemMeta{}
}
networkAddresses := make([]nbpeer.NetworkAddress, 0, len(loginReq.GetMeta().GetNetworkAddresses()))
for _, addr := range loginReq.GetMeta().GetNetworkAddresses() {
osVersion := meta.GetOSVersion()
if osVersion == "" {
osVersion = meta.GetCore()
}
networkAddresses := make([]nbpeer.NetworkAddress, 0, len(meta.GetNetworkAddresses()))
for _, addr := range meta.GetNetworkAddresses() {
netAddr, err := netip.ParsePrefix(addr.GetNetIP())
if err != nil {
log.Warnf("failed to parse netip address, %s: %v", addr.GetNetIP(), err)
@@ -274,24 +285,34 @@ func extractPeerMeta(loginReq *proto.LoginRequest) nbpeer.PeerSystemMeta {
})
}
files := make([]nbpeer.File, 0, len(meta.GetFiles()))
for _, file := range meta.GetFiles() {
files = append(files, nbpeer.File{
Path: file.GetPath(),
Exist: file.GetExist(),
ProcessIsRunning: file.GetProcessIsRunning(),
})
}
return nbpeer.PeerSystemMeta{
Hostname: loginReq.GetMeta().GetHostname(),
GoOS: loginReq.GetMeta().GetGoOS(),
Kernel: loginReq.GetMeta().GetKernel(),
Platform: loginReq.GetMeta().GetPlatform(),
OS: loginReq.GetMeta().GetOS(),
Hostname: meta.GetHostname(),
GoOS: meta.GetGoOS(),
Kernel: meta.GetKernel(),
Platform: meta.GetPlatform(),
OS: meta.GetOS(),
OSVersion: osVersion,
WtVersion: loginReq.GetMeta().GetWiretrusteeVersion(),
UIVersion: loginReq.GetMeta().GetUiVersion(),
KernelVersion: loginReq.GetMeta().GetKernelVersion(),
WtVersion: meta.GetWiretrusteeVersion(),
UIVersion: meta.GetUiVersion(),
KernelVersion: meta.GetKernelVersion(),
NetworkAddresses: networkAddresses,
SystemSerialNumber: loginReq.GetMeta().GetSysSerialNumber(),
SystemProductName: loginReq.GetMeta().GetSysProductName(),
SystemManufacturer: loginReq.GetMeta().GetSysManufacturer(),
SystemSerialNumber: meta.GetSysSerialNumber(),
SystemProductName: meta.GetSysProductName(),
SystemManufacturer: meta.GetSysManufacturer(),
Environment: nbpeer.Environment{
Cloud: loginReq.GetMeta().GetEnvironment().GetCloud(),
Platform: loginReq.GetMeta().GetEnvironment().GetPlatform(),
Cloud: meta.GetEnvironment().GetCloud(),
Platform: meta.GetEnvironment().GetPlatform(),
},
Files: files,
}
}
@@ -366,7 +387,7 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
peer, netMap, err := s.accountManager.LoginPeer(PeerLogin{
WireGuardPubKey: peerKey.String(),
SSHKey: string(sshKey),
Meta: extractPeerMeta(loginReq),
Meta: extractPeerMeta(loginReq.GetMeta()),
UserID: userID,
SetupKey: loginReq.GetSetupKey(),
ConnectionIP: realIP,
@@ -386,6 +407,7 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
loginResp := &proto.LoginResponse{
WiretrusteeConfig: toWiretrusteeConfig(s.config, nil),
PeerConfig: toPeerConfig(peer, netMap.Network, s.accountManager.GetDNSDomain()),
Checks: toPeerChecks(s.accountManager, peerKey.String()),
}
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, loginResp)
if err != nil {
@@ -482,7 +504,7 @@ func toRemotePeerConfig(peers []*nbpeer.Peer, dnsName string) []*proto.RemotePee
return remotePeers
}
func toSyncResponse(config *Config, peer *nbpeer.Peer, turnCredentials *TURNCredentials, networkMap *NetworkMap, dnsName string) *proto.SyncResponse {
func toSyncResponse(accountManager AccountManager, config *Config, peer *nbpeer.Peer, turnCredentials *TURNCredentials, networkMap *NetworkMap, dnsName string) *proto.SyncResponse {
wtConfig := toWiretrusteeConfig(config, turnCredentials)
pConfig := toPeerConfig(peer, networkMap.Network, dnsName)
@@ -513,6 +535,7 @@ func toSyncResponse(config *Config, peer *nbpeer.Peer, turnCredentials *TURNCred
FirewallRules: firewallRules,
FirewallRulesIsEmpty: len(firewallRules) == 0,
},
Checks: toPeerChecks(accountManager, peer.Key),
}
}
@@ -531,7 +554,7 @@ func (s *GRPCServer) sendInitialSync(peerKey wgtypes.Key, peer *nbpeer.Peer, net
} else {
turnCredentials = nil
}
plainResp := toSyncResponse(s.config, peer, turnCredentials, networkMap, s.accountManager.GetDNSDomain())
plainResp := toSyncResponse(s.accountManager, s.config, peer, turnCredentials, networkMap, s.accountManager.GetDNSDomain())
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp)
if err != nil {
@@ -648,3 +671,62 @@ func (s *GRPCServer) GetPKCEAuthorizationFlow(_ context.Context, req *proto.Encr
Body: encryptedResp,
}, nil
}
// SyncMeta endpoint is used to synchronize peer's system metadata and notifies the connected,
// peer's under the same account of any updates.
func (s *GRPCServer) SyncMeta(ctx context.Context, req *proto.EncryptedMessage) (*proto.Empty, error) {
realIP := getRealIP(ctx)
log.Debugf("Sync meta request from peer [%s] [%s]", req.WgPubKey, realIP.String())
syncMetaReq := &proto.SyncMetaRequest{}
peerKey, err := s.parseRequest(req, syncMetaReq)
if err != nil {
return nil, err
}
if syncMetaReq.GetMeta() == nil {
msg := status.Errorf(codes.FailedPrecondition,
"peer system meta has to be provided on sync. Peer %s, remote addr %s", peerKey.String(), realIP)
log.Warn(msg)
return nil, msg
}
_, _, err = s.accountManager.SyncPeer(PeerSync{
WireGuardPubKey: peerKey.String(),
Meta: extractPeerMeta(syncMetaReq.GetMeta()),
UpdateAccountPeers: true,
})
if err != nil {
return nil, mapError(err)
}
return &proto.Empty{}, nil
}
// toPeerChecks returns posture checks for the peer that needs to be evaluated on the client side.
func toPeerChecks(accountManager AccountManager, peerKey string) []*proto.Checks {
postureChecks, err := accountManager.GetPeerAppliedPostureChecks(peerKey)
if err != nil {
log.Errorf("failed getting peer's: %s posture checks: %v", peerKey, err)
return nil
}
protoChecks := make([]*proto.Checks, 0)
for _, postureCheck := range postureChecks {
protoCheck := &proto.Checks{}
if check := postureCheck.Checks.ProcessCheck; check != nil {
for _, process := range check.Processes {
if process.Path != "" {
protoCheck.Files = append(protoCheck.Files, process.Path)
}
if process.WindowsPath != "" {
protoCheck.Files = append(protoCheck.Files, process.WindowsPath)
}
}
}
protoChecks = append(protoChecks, protoCheck)
}
return protoChecks
}

View File

@@ -54,7 +54,7 @@ func initAccountsTestData(account *server.Account, admin *server.User) *Accounts
func TestAccounts_AccountsHandler(t *testing.T) {
accountID := "test_account"
adminUser := server.NewAdminUser("test_user")
adminUser := server.NewAdminUser("test_user", "account_id")
sr := func(v string) *string { return &v }
br := func(v bool) *bool { return &v }

View File

@@ -812,6 +812,8 @@ components:
$ref: '#/components/schemas/GeoLocationCheck'
peer_network_range_check:
$ref: '#/components/schemas/PeerNetworkRangeCheck'
process_check:
$ref: '#/components/schemas/ProcessCheck'
NBVersionCheck:
description: Posture check for the version of NetBird
type: object
@@ -900,6 +902,28 @@ components:
required:
- ranges
- action
ProcessCheck:
description: Posture Check for binaries exist and are running in the peers system
type: object
properties:
processes:
type: array
items:
$ref: '#/components/schemas/Process'
required:
- processes
Process:
description: Describes the operational activity within a peer's system.
type: object
properties:
path:
description: Path to the process executable file in a Unix-like operating system
type: string
example: "/usr/local/bin/netbird"
windows_path:
description: Path to the process executable file in a Windows operating system
type: string
example: "C:\ProgramData\NetBird\netbird.exe"
Location:
description: Describe geographical location information
type: object

View File

@@ -225,6 +225,9 @@ type Checks struct {
// PeerNetworkRangeCheck Posture check for allow or deny access based on peer local network addresses
PeerNetworkRangeCheck *PeerNetworkRangeCheck `json:"peer_network_range_check,omitempty"`
// ProcessCheck Posture Check for binaries exist and are running in the peers system
ProcessCheck *ProcessCheck `json:"process_check,omitempty"`
}
// City Describe city geographical location information
@@ -940,6 +943,20 @@ type PostureCheckUpdate struct {
Name string `json:"name"`
}
// Process Describes the operational activity within a peer's system.
type Process struct {
// Path Path to the process executable file in a Unix-like operating system
Path *string `json:"path,omitempty"`
// WindowsPath Path to the process executable file in a Windows operating system
WindowsPath *string `json:"windows_path,omitempty"`
}
// ProcessCheck Posture Check for binaries exist and are running in the peers system
type ProcessCheck struct {
Processes []Process `json:"processes"`
}
// Route defines model for Route.
type Route struct {
// Description Route description

View File

@@ -34,7 +34,7 @@ var testingDNSSettingsAccount = &server.Account{
Id: testDNSSettingsAccountID,
Domain: "hotmail.com",
Users: map[string]*server.User{
testDNSSettingsUserID: server.NewAdminUser("test_user"),
testDNSSettingsUserID: server.NewAdminUser("test_user", "account_id"),
},
DNSSettings: baseExistingDNSSettings,
}

View File

@@ -196,7 +196,7 @@ func TestEvents_GetEvents(t *testing.T) {
},
}
accountID := "test_account"
adminUser := server.NewAdminUser("test_user")
adminUser := server.NewAdminUser("test_user", "account_id")
events := generateEvents(accountID, adminUser.Id)
handler := initEventsTestData(accountID, adminUser, events...)

View File

@@ -42,7 +42,7 @@ func initGeolocationTestData(t *testing.T) *GeolocationsHandler {
return &GeolocationsHandler{
accountManager: &mock_server.MockAccountManager{
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
user := server.NewAdminUser("test_user")
user := server.NewAdminUser("test_user", "account_id")
return &server.Account{
Id: claims.AccountId,
Users: map[string]*server.User{

View File

@@ -2,6 +2,7 @@ package http
import (
"net/http"
"regexp"
"github.com/gorilla/mux"
@@ -13,6 +14,10 @@ import (
"github.com/netbirdio/netbird/management/server/status"
)
var (
countryCodeRegex = regexp.MustCompile("^[a-zA-Z]{2}$")
)
// GeolocationsHandler is a handler that returns locations.
type GeolocationsHandler struct {
accountManager server.AccountManager
@@ -73,8 +78,8 @@ func (l *GeolocationsHandler) GetCitiesByCountry(w http.ResponseWriter, r *http.
}
if l.geolocationManager == nil {
// TODO: update error message to include geo db self hosted doc link when ready
util.WriteError(status.Errorf(status.PreconditionFailed, "Geo location database is not initialized"), w)
util.WriteError(status.Errorf(status.PreconditionFailed, "Geo location database is not initialized. "+
"Check the self-hosted Geo database documentation at https://docs.netbird.io/selfhosted/geo-support"), w)
return
}

View File

@@ -124,7 +124,7 @@ func TestGetGroup(t *testing.T) {
Name: "Group",
}
adminUser := server.NewAdminUser("test_user")
adminUser := server.NewAdminUser("test_user", "account_id")
p := initGroupTestData(adminUser, group)
for _, tc := range tt {
@@ -246,7 +246,7 @@ func TestWriteGroup(t *testing.T) {
},
}
adminUser := server.NewAdminUser("test_user")
adminUser := server.NewAdminUser("test_user", "account_id")
p := initGroupTestData(adminUser)
for _, tc := range tt {
@@ -324,7 +324,7 @@ func TestDeleteGroup(t *testing.T) {
},
}
adminUser := server.NewAdminUser("test_user")
adminUser := server.NewAdminUser("test_user", "account_id")
p := initGroupTestData(adminUser)
for _, tc := range tt {

View File

@@ -32,7 +32,7 @@ var testingNSAccount = &server.Account{
Id: testNSGroupAccountID,
Domain: "hotmail.com",
Users: map[string]*server.User{
"test_user": server.NewAdminUser("test_user"),
"test_user": server.NewAdminUser("test_user", "account_id"),
},
}

View File

@@ -59,7 +59,7 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
return "netbird.selfhosted"
},
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
user := server.NewAdminUser("test_user")
user := server.NewAdminUser("test_user", "account_id")
return &server.Account{
Id: claims.AccountId,
Domain: "hotmail.com",

View File

@@ -45,7 +45,7 @@ func initPoliciesTestData(policies ...*server.Policy) *Policies {
return nil
},
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
user := server.NewAdminUser("test_user")
user := server.NewAdminUser("test_user", "account_id")
return &server.Account{
Id: claims.AccountId,
Domain: "hotmail.com",

View File

@@ -4,8 +4,6 @@ import (
"encoding/json"
"net/http"
"net/netip"
"regexp"
"slices"
"github.com/gorilla/mux"
"github.com/rs/xid"
@@ -19,10 +17,6 @@ import (
"github.com/netbirdio/netbird/management/server/status"
)
var (
countryCodeRegex = regexp.MustCompile("^[a-zA-Z]{2}$")
)
// PostureChecksHandler is a handler that returns posture checks of the account.
type PostureChecksHandler struct {
accountManager server.AccountManager
@@ -165,19 +159,16 @@ func (p *PostureChecksHandler) savePostureChecks(
user *server.User,
postureChecksID string,
) {
var (
err error
req api.PostureCheckUpdate
)
var req api.PostureCheckUpdate
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
if err = json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
err := validatePostureChecksUpdate(req)
if err != nil {
util.WriteErrorResponse(err.Error(), http.StatusBadRequest, w)
return
}
if postureChecksID == "" {
postureChecksID = xid.New().String()
}
@@ -206,8 +197,8 @@ func (p *PostureChecksHandler) savePostureChecks(
if geoLocationCheck := req.Checks.GeoLocationCheck; geoLocationCheck != nil {
if p.geolocationManager == nil {
// TODO: update error message to include geo db self hosted doc link when ready
util.WriteError(status.Errorf(status.PreconditionFailed, "Geo location database is not initialized"), w)
util.WriteError(status.Errorf(status.PreconditionFailed, "Geo location database is not initialized. "+
"Check the self-hosted Geo database documentation at https://docs.netbird.io/selfhosted/geo-support"), w)
return
}
postureChecks.Checks.GeoLocationCheck = toPostureGeoLocationCheck(geoLocationCheck)
@@ -221,6 +212,10 @@ func (p *PostureChecksHandler) savePostureChecks(
}
}
if processCheck := req.Checks.ProcessCheck; processCheck != nil {
postureChecks.Checks.ProcessCheck = toProcessCheck(processCheck)
}
if err := p.accountManager.SavePostureChecks(account.Id, user.Id, &postureChecks); err != nil {
util.WriteError(err, w)
return
@@ -229,72 +224,6 @@ func (p *PostureChecksHandler) savePostureChecks(
util.WriteJSONObject(w, toPostureChecksResponse(&postureChecks))
}
func validatePostureChecksUpdate(req api.PostureCheckUpdate) error {
if req.Name == "" {
return status.Errorf(status.InvalidArgument, "posture checks name shouldn't be empty")
}
if req.Checks == nil || (req.Checks.NbVersionCheck == nil && req.Checks.OsVersionCheck == nil &&
req.Checks.GeoLocationCheck == nil && req.Checks.PeerNetworkRangeCheck == nil) {
return status.Errorf(status.InvalidArgument, "posture checks shouldn't be empty")
}
if req.Checks.NbVersionCheck != nil && req.Checks.NbVersionCheck.MinVersion == "" {
return status.Errorf(status.InvalidArgument, "minimum version for NetBird's version check shouldn't be empty")
}
if osVersionCheck := req.Checks.OsVersionCheck; osVersionCheck != nil {
emptyOS := osVersionCheck.Android == nil && osVersionCheck.Darwin == nil && osVersionCheck.Ios == nil &&
osVersionCheck.Linux == nil && osVersionCheck.Windows == nil
emptyMinVersion := osVersionCheck.Android != nil && osVersionCheck.Android.MinVersion == "" ||
osVersionCheck.Darwin != nil && osVersionCheck.Darwin.MinVersion == "" ||
osVersionCheck.Ios != nil && osVersionCheck.Ios.MinVersion == "" ||
osVersionCheck.Linux != nil && osVersionCheck.Linux.MinKernelVersion == "" ||
osVersionCheck.Windows != nil && osVersionCheck.Windows.MinKernelVersion == ""
if emptyOS || emptyMinVersion {
return status.Errorf(status.InvalidArgument,
"minimum version for at least one OS in the OS version check shouldn't be empty")
}
}
if geoLocationCheck := req.Checks.GeoLocationCheck; geoLocationCheck != nil {
if geoLocationCheck.Action == "" {
return status.Errorf(status.InvalidArgument, "action for geolocation check shouldn't be empty")
}
allowedActions := []api.GeoLocationCheckAction{api.GeoLocationCheckActionAllow, api.GeoLocationCheckActionDeny}
if !slices.Contains(allowedActions, geoLocationCheck.Action) {
return status.Errorf(status.InvalidArgument, "action for geolocation check is not valid value")
}
if len(geoLocationCheck.Locations) == 0 {
return status.Errorf(status.InvalidArgument, "locations for geolocation check shouldn't be empty")
}
for _, loc := range geoLocationCheck.Locations {
if loc.CountryCode == "" {
return status.Errorf(status.InvalidArgument, "country code for geolocation check shouldn't be empty")
}
if !countryCodeRegex.MatchString(loc.CountryCode) {
return status.Errorf(status.InvalidArgument, "country code must be 2 letters (ISO 3166-1 alpha-2 format)")
}
}
}
if peerNetworkRangeCheck := req.Checks.PeerNetworkRangeCheck; peerNetworkRangeCheck != nil {
if peerNetworkRangeCheck.Action == "" {
return status.Errorf(status.InvalidArgument, "action for peer network range check shouldn't be empty")
}
allowedActions := []api.PeerNetworkRangeCheckAction{api.PeerNetworkRangeCheckActionAllow, api.PeerNetworkRangeCheckActionDeny}
if !slices.Contains(allowedActions, peerNetworkRangeCheck.Action) {
return status.Errorf(status.InvalidArgument, "action for peer network range check is not valid value")
}
if len(peerNetworkRangeCheck.Ranges) == 0 {
return status.Errorf(status.InvalidArgument, "network ranges for peer network range check shouldn't be empty")
}
}
return nil
}
func toPostureChecksResponse(postureChecks *posture.Checks) *api.PostureCheck {
var checks api.Checks
@@ -322,6 +251,10 @@ func toPostureChecksResponse(postureChecks *posture.Checks) *api.PostureCheck {
checks.PeerNetworkRangeCheck = toPeerNetworkRangeCheckResponse(postureChecks.Checks.PeerNetworkRangeCheck)
}
if postureChecks.Checks.ProcessCheck != nil {
checks.ProcessCheck = toProcessCheckResponse(postureChecks.Checks.ProcessCheck)
}
return &api.PostureCheck{
Id: postureChecks.ID,
Name: postureChecks.Name,
@@ -332,11 +265,10 @@ func toPostureChecksResponse(postureChecks *posture.Checks) *api.PostureCheck {
func toGeoLocationCheckResponse(geoLocationCheck *posture.GeoLocationCheck) *api.GeoLocationCheck {
locations := make([]api.Location, 0, len(geoLocationCheck.Locations))
for _, loc := range geoLocationCheck.Locations {
l := loc // make G601 happy
for i, loc := range geoLocationCheck.Locations {
var cityName *string
if loc.CityName != "" {
cityName = &l.CityName
cityName = &geoLocationCheck.Locations[i].CityName
}
locations = append(locations, api.Location{
CityName: cityName,
@@ -396,3 +328,36 @@ func toPeerNetworkRangeCheck(check *api.PeerNetworkRangeCheck) (*posture.PeerNet
Action: string(check.Action),
}, nil
}
func toProcessCheckResponse(check *posture.ProcessCheck) *api.ProcessCheck {
processes := make([]api.Process, 0, len(check.Processes))
for i := range check.Processes {
processes = append(processes, api.Process{
Path: &check.Processes[i].Path,
WindowsPath: &check.Processes[i].WindowsPath,
})
}
return &api.ProcessCheck{
Processes: processes,
}
}
func toProcessCheck(check *api.ProcessCheck) *posture.ProcessCheck {
processes := make([]posture.Process, 0, len(check.Processes))
for _, process := range check.Processes {
var p posture.Process
if process.Path != nil {
p.Path = *process.Path
}
if process.WindowsPath != nil {
p.WindowsPath = *process.WindowsPath
}
processes = append(processes, p)
}
return &posture.ProcessCheck{
Processes: processes,
}
}

View File

@@ -43,6 +43,11 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH
SavePostureChecksFunc: func(accountID, userID string, postureChecks *posture.Checks) error {
postureChecks.ID = "postureCheck"
testPostureChecks[postureChecks.ID] = postureChecks
if err := postureChecks.Validate(); err != nil {
return status.Errorf(status.InvalidArgument, err.Error())
}
return nil
},
DeletePostureChecksFunc: func(accountID, postureChecksID, userID string) error {
@@ -62,7 +67,7 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH
return accountPostureChecks, nil
},
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
user := server.NewAdminUser("test_user")
user := server.NewAdminUser("test_user", "account_id")
return &server.Account{
Id: claims.AccountId,
Users: map[string]*server.User{
@@ -433,6 +438,43 @@ func TestPostureCheckUpdate(t *testing.T) {
handler.geolocationManager = nil
},
},
{
name: "Create Posture Checks Process Check",
requestType: http.MethodPost,
requestPath: "/api/posture-checks",
requestBody: bytes.NewBuffer(
[]byte(`{
"name": "default",
"description": "default",
"checks": {
"process_check": {
"processes": [
{
"path": "/usr/local/bin/netbird",
"windows_path": "C:\\ProgramData\\NetBird\\netbird.exe"
}
]
}
}
}`)),
expectedStatus: http.StatusOK,
expectedBody: true,
expectedPostureCheck: &api.PostureCheck{
Id: "postureCheck",
Name: "default",
Description: str("default"),
Checks: api.Checks{
ProcessCheck: &api.ProcessCheck{
Processes: []api.Process{
{
Path: str("/usr/local/bin/netbird"),
WindowsPath: str("C:\\ProgramData\\NetBird\\netbird.exe"),
},
},
},
},
},
},
{
name: "Create Posture Checks Invalid Check",
requestType: http.MethodPost,
@@ -446,7 +488,7 @@ func TestPostureCheckUpdate(t *testing.T) {
}
}
}`)),
expectedStatus: http.StatusBadRequest,
expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false,
},
{
@@ -461,7 +503,7 @@ func TestPostureCheckUpdate(t *testing.T) {
}
}
}`)),
expectedStatus: http.StatusBadRequest,
expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false,
},
{
@@ -475,7 +517,7 @@ func TestPostureCheckUpdate(t *testing.T) {
"nb_version_check": {}
}
}`)),
expectedStatus: http.StatusBadRequest,
expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false,
},
{
@@ -489,7 +531,7 @@ func TestPostureCheckUpdate(t *testing.T) {
"geo_location_check": {}
}
}`)),
expectedStatus: http.StatusBadRequest,
expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false,
},
{
@@ -663,11 +705,8 @@ func TestPostureCheckUpdate(t *testing.T) {
}
}
}`)),
expectedStatus: http.StatusBadRequest,
expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false,
setupHandlerFunc: func(handler *PostureChecksHandler) {
handler.geolocationManager = nil
},
},
{
name: "Update Posture Checks Invalid Check",
@@ -682,7 +721,7 @@ func TestPostureCheckUpdate(t *testing.T) {
}
}
}`)),
expectedStatus: http.StatusBadRequest,
expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false,
},
{
@@ -697,7 +736,7 @@ func TestPostureCheckUpdate(t *testing.T) {
}
}
}`)),
expectedStatus: http.StatusBadRequest,
expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false,
},
{
@@ -711,7 +750,7 @@ func TestPostureCheckUpdate(t *testing.T) {
"nb_version_check": {}
}
}`)),
expectedStatus: http.StatusBadRequest,
expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false,
},
{
@@ -841,100 +880,3 @@ func TestPostureCheckUpdate(t *testing.T) {
})
}
}
func TestPostureCheck_validatePostureChecksUpdate(t *testing.T) {
// empty name
err := validatePostureChecksUpdate(api.PostureCheckUpdate{})
assert.Error(t, err)
// empty checks
err = validatePostureChecksUpdate(api.PostureCheckUpdate{Name: "Default"})
assert.Error(t, err)
err = validatePostureChecksUpdate(api.PostureCheckUpdate{Name: "Default", Checks: &api.Checks{}})
assert.Error(t, err)
// not valid NbVersionCheck
nbVersionCheck := api.NBVersionCheck{}
err = validatePostureChecksUpdate(api.PostureCheckUpdate{Name: "Default", Checks: &api.Checks{NbVersionCheck: &nbVersionCheck}})
assert.Error(t, err)
// valid NbVersionCheck
nbVersionCheck = api.NBVersionCheck{MinVersion: "1.0"}
err = validatePostureChecksUpdate(api.PostureCheckUpdate{Name: "Default", Checks: &api.Checks{NbVersionCheck: &nbVersionCheck}})
assert.NoError(t, err)
// not valid OsVersionCheck
osVersionCheck := api.OSVersionCheck{}
err = validatePostureChecksUpdate(api.PostureCheckUpdate{Name: "Default", Checks: &api.Checks{OsVersionCheck: &osVersionCheck}})
assert.Error(t, err)
// not valid OsVersionCheck
osVersionCheck = api.OSVersionCheck{Linux: &api.MinKernelVersionCheck{}}
err = validatePostureChecksUpdate(api.PostureCheckUpdate{Name: "Default", Checks: &api.Checks{OsVersionCheck: &osVersionCheck}})
assert.Error(t, err)
// not valid OsVersionCheck
osVersionCheck = api.OSVersionCheck{Linux: &api.MinKernelVersionCheck{}, Darwin: &api.MinVersionCheck{MinVersion: "14.2"}}
err = validatePostureChecksUpdate(api.PostureCheckUpdate{Name: "Default", Checks: &api.Checks{OsVersionCheck: &osVersionCheck}})
assert.Error(t, err)
// valid OsVersionCheck
osVersionCheck = api.OSVersionCheck{Linux: &api.MinKernelVersionCheck{MinKernelVersion: "6.0"}}
err = validatePostureChecksUpdate(api.PostureCheckUpdate{Name: "Default", Checks: &api.Checks{OsVersionCheck: &osVersionCheck}})
assert.NoError(t, err)
// valid OsVersionCheck
osVersionCheck = api.OSVersionCheck{
Linux: &api.MinKernelVersionCheck{MinKernelVersion: "6.0"},
Darwin: &api.MinVersionCheck{MinVersion: "14.2"},
}
err = validatePostureChecksUpdate(api.PostureCheckUpdate{Name: "Default", Checks: &api.Checks{OsVersionCheck: &osVersionCheck}})
assert.NoError(t, err)
// valid peer network range check
peerNetworkRangeCheck := api.PeerNetworkRangeCheck{
Action: api.PeerNetworkRangeCheckActionAllow,
Ranges: []string{
"192.168.1.0/24", "10.0.0.0/8",
},
}
err = validatePostureChecksUpdate(
api.PostureCheckUpdate{
Name: "Default",
Checks: &api.Checks{
PeerNetworkRangeCheck: &peerNetworkRangeCheck,
},
},
)
assert.NoError(t, err)
// invalid peer network range check
peerNetworkRangeCheck = api.PeerNetworkRangeCheck{
Action: api.PeerNetworkRangeCheckActionDeny,
Ranges: []string{},
}
err = validatePostureChecksUpdate(
api.PostureCheckUpdate{
Name: "Default",
Checks: &api.Checks{
PeerNetworkRangeCheck: &peerNetworkRangeCheck,
},
},
)
assert.Error(t, err)
// invalid peer network range check
peerNetworkRangeCheck = api.PeerNetworkRangeCheck{
Action: "unknownAction",
Ranges: []string{},
}
err = validatePostureChecksUpdate(
api.PostureCheckUpdate{
Name: "Default",
Checks: &api.Checks{
PeerNetworkRangeCheck: &peerNetworkRangeCheck,
},
},
)
assert.Error(t, err)
}

View File

@@ -75,7 +75,7 @@ var testingAccount = &server.Account{
},
},
Users: map[string]*server.User{
"test_user": server.NewAdminUser("test_user"),
"test_user": server.NewAdminUser("test_user", "account_id"),
},
}

View File

@@ -97,7 +97,7 @@ func TestSetupKeysHandlers(t *testing.T) {
defaultSetupKey := server.GenerateDefaultSetupKey()
defaultSetupKey.Id = existingSetupKeyID
adminUser := server.NewAdminUser("test_user")
adminUser := server.NewAdminUser("test_user", "account_id")
newSetupKey := server.GenerateSetupKey(newSetupKeyName, server.SetupKeyReusable, 0, []string{"group-1"},
server.SetupKeyUnlimitedUsage, true)

View File

@@ -134,7 +134,8 @@ func Test_SyncProtocol(t *testing.T) {
// take the first registered peer as a base for the test. Total four.
key := *peers[0]
message, err := encryption.EncryptMessage(*serverKey, key, &mgmtProto.SyncRequest{})
syncReq := &mgmtProto.SyncRequest{Meta: &mgmtProto.PeerSystemMeta{}}
message, err := encryption.EncryptMessage(*serverKey, key, syncReq)
if err != nil {
t.Fatal(err)
return

View File

@@ -93,7 +93,8 @@ var _ = Describe("Management service", func() {
key, _ := wgtypes.GenerateKey()
loginPeerWithValidSetupKey(serverPubKey, key, client)
encryptedBytes, err := encryption.EncryptMessage(serverPubKey, key, &mgmtProto.SyncRequest{})
syncReq := &mgmtProto.SyncRequest{Meta: &mgmtProto.PeerSystemMeta{}}
encryptedBytes, err := encryption.EncryptMessage(serverPubKey, key, syncReq)
Expect(err).NotTo(HaveOccurred())
sync, err := client.Sync(context.TODO(), &mgmtProto.EncryptedMessage{
@@ -143,7 +144,7 @@ var _ = Describe("Management service", func() {
loginPeerWithValidSetupKey(serverPubKey, key1, client)
loginPeerWithValidSetupKey(serverPubKey, key2, client)
messageBytes, err := pb.Marshal(&mgmtProto.SyncRequest{})
messageBytes, err := pb.Marshal(&mgmtProto.SyncRequest{Meta: &mgmtProto.PeerSystemMeta{}})
Expect(err).NotTo(HaveOccurred())
encryptedBytes, err := encryption.Encrypt(messageBytes, serverPubKey, key)
Expect(err).NotTo(HaveOccurred())
@@ -176,7 +177,7 @@ var _ = Describe("Management service", func() {
key, _ := wgtypes.GenerateKey()
loginPeerWithValidSetupKey(serverPubKey, key, client)
messageBytes, err := pb.Marshal(&mgmtProto.SyncRequest{})
messageBytes, err := pb.Marshal(&mgmtProto.SyncRequest{Meta: &mgmtProto.PeerSystemMeta{}})
Expect(err).NotTo(HaveOccurred())
encryptedBytes, err := encryption.Encrypt(messageBytes, serverPubKey, key)
Expect(err).NotTo(HaveOccurred())
@@ -329,7 +330,7 @@ var _ = Describe("Management service", func() {
var clients []mgmtProto.ManagementService_SyncClient
for _, peer := range peers {
messageBytes, err := pb.Marshal(&mgmtProto.SyncRequest{})
messageBytes, err := pb.Marshal(&mgmtProto.SyncRequest{Meta: &mgmtProto.PeerSystemMeta{}})
Expect(err).NotTo(HaveOccurred())
encryptedBytes, err := encryption.Encrypt(messageBytes, serverPubKey, peer)
Expect(err).NotTo(HaveOccurred())
@@ -394,7 +395,8 @@ var _ = Describe("Management service", func() {
defer GinkgoRecover()
key, _ := wgtypes.GenerateKey()
loginPeerWithValidSetupKey(serverPubKey, key, client)
encryptedBytes, err := encryption.EncryptMessage(serverPubKey, key, &mgmtProto.SyncRequest{})
syncReq := &mgmtProto.SyncRequest{Meta: &mgmtProto.PeerSystemMeta{}}
encryptedBytes, err := encryption.EncryptMessage(serverPubKey, key, syncReq)
Expect(err).NotTo(HaveOccurred())
// open stream

View File

@@ -80,6 +80,7 @@ type MockAccountManager struct {
GetDNSSettingsFunc func(accountID, userID string) (*server.DNSSettings, error)
SaveDNSSettingsFunc func(accountID, userID string, dnsSettingsToSave *server.DNSSettings) error
GetPeerFunc func(accountID, peerID, userID string) (*nbpeer.Peer, error)
GetPeerAppliedPostureChecksFunc func(peerKey string) ([]posture.Checks, error)
UpdateAccountSettingsFunc func(accountID, userID string, newSettings *server.Settings) (*server.Account, error)
LoginPeerFunc func(login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, error)
SyncPeerFunc func(sync server.PeerSync) (*nbpeer.Peer, *server.NetworkMap, error)
@@ -609,6 +610,14 @@ func (am *MockAccountManager) GetPeer(accountID, peerID, userID string) (*nbpeer
return nil, status.Errorf(codes.Unimplemented, "method GetPeer is not implemented")
}
// GetPeerAppliedPostureChecks mocks GetPeerAppliedPostureChecks of the AccountManager interface
func (am *MockAccountManager) GetPeerAppliedPostureChecks(peerKey string) ([]posture.Checks, error) {
if am.GetPeerAppliedPostureChecksFunc != nil {
return am.GetPeerAppliedPostureChecksFunc(peerKey)
}
return nil, status.Errorf(codes.Unimplemented, "method GetPeerAppliedPostureChecks is not implemented")
}
// UpdateAccountSettings mocks UpdateAccountSettings of the AccountManager interface
func (am *MockAccountManager) UpdateAccountSettings(accountID, userID string, newSettings *server.Settings) (*server.Account, error) {
if am.UpdateAccountSettingsFunc != nil {

View File

@@ -3,9 +3,10 @@ package mock_server
import (
"context"
"github.com/netbirdio/netbird/management/proto"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/management/proto"
)
type ManagementServiceServerMock struct {
@@ -17,6 +18,7 @@ type ManagementServiceServerMock struct {
IsHealthyFunc func(context.Context, *proto.Empty) (*proto.Empty, error)
GetDeviceAuthorizationFlowFunc func(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error)
GetPKCEAuthorizationFlowFunc func(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error)
SyncMetaFunc func(ctx context.Context, req *proto.EncryptedMessage) (*proto.Empty, error)
}
func (m ManagementServiceServerMock) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
@@ -60,3 +62,10 @@ func (m ManagementServiceServerMock) GetPKCEAuthorizationFlow(ctx context.Contex
}
return nil, status.Errorf(codes.Unimplemented, "method GetPKCEAuthorizationFlow not implemented")
}
func (m ManagementServiceServerMock) SyncMeta(ctx context.Context, req *proto.EncryptedMessage) (*proto.Empty, error) {
if m.SyncMetaFunc != nil {
return m.SyncMetaFunc(ctx, req)
}
return nil, status.Errorf(codes.Unimplemented, "method SyncMeta not implemented")
}

View File

@@ -3,6 +3,7 @@ package server
import (
"fmt"
"net"
"slices"
"strings"
"time"
@@ -12,6 +13,7 @@ import (
"github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server/activity"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/status"
)
@@ -19,6 +21,11 @@ import (
type PeerSync struct {
// WireGuardPubKey is a peers WireGuard public key
WireGuardPubKey string
// Meta is the system information passed by peer, must be always present
Meta nbpeer.PeerSystemMeta
// UpdateAccountPeers indicate updating account peers,
// which occurs when the peer's metadata is updated
UpdateAccountPeers bool
}
// PeerLogin used as a data object between the gRPC API and AccountManager on Login request.
@@ -551,6 +558,18 @@ func (am *DefaultAccountManager) SyncPeer(sync PeerSync) (*nbpeer.Peer, *Network
return nil, nil, status.Errorf(status.PermissionDenied, "peer login has expired, please log in once more")
}
peer, updated := updatePeerMeta(peer, sync.Meta, account)
if updated {
err = am.Store.SaveAccount(account)
if err != nil {
return nil, nil, err
}
if sync.UpdateAccountPeers {
am.updateAccountPeers(account)
}
}
peerNotValid, isStatusChanged := am.integratedPeerValidator.IsNotValidPeer(account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra)
if peerNotValid {
emptyMap := &NetworkMap{
@@ -866,7 +885,65 @@ func (am *DefaultAccountManager) updateAccountPeers(account *Account) {
}
for _, peer := range peers {
remotePeerNetworkMap := account.GetPeerNetworkMap(peer.ID, am.dnsDomain, approvedPeersMap)
update := toSyncResponse(nil, peer, nil, remotePeerNetworkMap, am.GetDNSDomain())
update := toSyncResponse(am, nil, peer, nil, remotePeerNetworkMap, am.GetDNSDomain())
am.peersUpdateManager.SendUpdate(peer.ID, &UpdateMessage{Update: update})
}
}
// GetPeerAppliedPostureChecks returns posture checks that are applied to the peer.
func (am *DefaultAccountManager) GetPeerAppliedPostureChecks(peerKey string) ([]posture.Checks, error) {
account, err := am.Store.GetAccountByPeerPubKey(peerKey)
if err != nil {
log.Errorf("failed while getting peer %s: %v", peerKey, err)
return nil, err
}
peer, err := account.FindPeerByPubKey(peerKey)
if err != nil {
return nil, status.Errorf(status.NotFound, "peer is not registered")
}
if peer == nil {
return nil, nil
}
peerPostureChecks := make(map[string]posture.Checks)
for _, policy := range account.Policies {
if !policy.Enabled {
continue
}
outerLoop:
for _, rule := range policy.Rules {
if !rule.Enabled {
continue
}
for _, sourceGroup := range rule.Sources {
group, ok := account.Groups[sourceGroup]
if !ok {
continue
}
// check if peer is in the rule source group
if slices.Contains(group.Peers, peer.ID) {
for _, sourcePostureCheckID := range policy.SourcePostureChecks {
for _, postureChecks := range account.PostureChecks {
if postureChecks.ID == sourcePostureCheckID {
peerPostureChecks[sourcePostureCheckID] = *postureChecks
}
}
}
break outerLoop
}
}
}
}
postureChecksList := make([]posture.Checks, 0, len(peerPostureChecks))
for _, check := range peerPostureChecks {
postureChecksList = append(postureChecksList, check)
}
return postureChecksList, nil
}

View File

@@ -4,6 +4,7 @@ import (
"fmt"
"net"
"net/netip"
"slices"
"time"
)
@@ -79,6 +80,13 @@ type Environment struct {
Platform string
}
// File is a file on the system.
type File struct {
Path string
Exist bool
ProcessIsRunning bool
}
// PeerSystemMeta is a metadata of a Peer machine system
type PeerSystemMeta struct { //nolint:revive
Hostname string
@@ -96,24 +104,22 @@ type PeerSystemMeta struct { //nolint:revive
SystemProductName string
SystemManufacturer string
Environment Environment `gorm:"serializer:json"`
Files []File `gorm:"serializer:json"`
}
func (p PeerSystemMeta) isEqual(other PeerSystemMeta) bool {
if len(p.NetworkAddresses) != len(other.NetworkAddresses) {
equalNetworkAddresses := slices.EqualFunc(p.NetworkAddresses, other.NetworkAddresses, func(addr NetworkAddress, oAddr NetworkAddress) bool {
return addr.Mac == oAddr.Mac && addr.NetIP == oAddr.NetIP
})
if !equalNetworkAddresses {
return false
}
for _, addr := range p.NetworkAddresses {
var found bool
for _, oAddr := range other.NetworkAddresses {
if addr.Mac == oAddr.Mac && addr.NetIP == oAddr.NetIP {
found = true
continue
}
}
if !found {
return false
}
equalFiles := slices.EqualFunc(p.Files, other.Files, func(file File, oFile File) bool {
return file.Path == oFile.Path && file.Exist == oFile.Exist && file.ProcessIsRunning == oFile.ProcessIsRunning
})
if !equalFiles {
return false
}
return p.Hostname == other.Hostname &&
@@ -133,6 +139,26 @@ func (p PeerSystemMeta) isEqual(other PeerSystemMeta) bool {
p.Environment.Platform == other.Environment.Platform
}
func (p PeerSystemMeta) isEmpty() bool {
return p.Hostname == "" &&
p.GoOS == "" &&
p.Kernel == "" &&
p.Core == "" &&
p.Platform == "" &&
p.OS == "" &&
p.OSVersion == "" &&
p.WtVersion == "" &&
p.UIVersion == "" &&
p.KernelVersion == "" &&
len(p.NetworkAddresses) == 0 &&
p.SystemSerialNumber == "" &&
p.SystemProductName == "" &&
p.SystemManufacturer == "" &&
p.Environment.Cloud == "" &&
p.Environment.Platform == "" &&
len(p.Files) == 0
}
// AddedWithSSOLogin indicates whether this peer has been added with an SSO login by a user.
func (p *Peer) AddedWithSSOLogin() bool {
return p.UserID != ""
@@ -168,6 +194,10 @@ func (p *Peer) Copy() *Peer {
// UpdateMetaIfNew updates peer's system metadata if new information is provided
// returns true if meta was updated, false otherwise
func (p *Peer) UpdateMetaIfNew(meta PeerSystemMeta) bool {
if meta.isEmpty() {
return false
}
// Avoid overwriting UIVersion if the update was triggered sole by the CLI client
if meta.UIVersion == "" {
meta.UIVersion = p.Meta.UIVersion

View File

@@ -1,8 +1,9 @@
package posture
import (
"fmt"
"errors"
"net/netip"
"regexp"
"github.com/hashicorp/go-version"
@@ -14,15 +15,21 @@ const (
OSVersionCheckName = "OSVersionCheck"
GeoLocationCheckName = "GeoLocationCheck"
PeerNetworkRangeCheckName = "PeerNetworkRangeCheck"
ProcessCheckName = "ProcessCheck"
CheckActionAllow string = "allow"
CheckActionDeny string = "deny"
)
var (
countryCodeRegex = regexp.MustCompile("^[a-zA-Z]{2}$")
)
// Check represents an interface for performing a check on a peer.
type Check interface {
Check(peer nbpeer.Peer) (bool, error)
Name() string
Check(peer nbpeer.Peer) (bool, error)
Validate() error
}
type Checks struct {
@@ -48,6 +55,7 @@ type ChecksDefinition struct {
OSVersionCheck *OSVersionCheck `json:",omitempty"`
GeoLocationCheck *GeoLocationCheck `json:",omitempty"`
PeerNetworkRangeCheck *PeerNetworkRangeCheck `json:",omitempty"`
ProcessCheck *ProcessCheck `json:",omitempty"`
}
// Copy returns a copy of a checks definition.
@@ -93,6 +101,13 @@ func (cd ChecksDefinition) Copy() ChecksDefinition {
}
copy(cdCopy.PeerNetworkRangeCheck.Ranges, peerNetRangeCheck.Ranges)
}
if cd.ProcessCheck != nil {
processCheck := cd.ProcessCheck
cdCopy.ProcessCheck = &ProcessCheck{
Processes: make([]Process, len(processCheck.Processes)),
}
copy(cdCopy.ProcessCheck.Processes, processCheck.Processes)
}
return cdCopy
}
@@ -133,50 +148,49 @@ func (pc *Checks) GetChecks() []Check {
if pc.Checks.PeerNetworkRangeCheck != nil {
checks = append(checks, pc.Checks.PeerNetworkRangeCheck)
}
if pc.Checks.ProcessCheck != nil {
checks = append(checks, pc.Checks.ProcessCheck)
}
return checks
}
// Validate checks the validity of a posture checks.
func (pc *Checks) Validate() error {
if check := pc.Checks.NBVersionCheck; check != nil {
if !isVersionValid(check.MinVersion) {
return fmt.Errorf("%s version: %s is not valid", check.Name(), check.MinVersion)
}
if pc.Name == "" {
return errors.New("posture checks name shouldn't be empty")
}
if osCheck := pc.Checks.OSVersionCheck; osCheck != nil {
if osCheck.Android != nil {
if !isVersionValid(osCheck.Android.MinVersion) {
return fmt.Errorf("%s android version: %s is not valid", osCheck.Name(), osCheck.Android.MinVersion)
}
}
if osCheck.Ios != nil {
if !isVersionValid(osCheck.Ios.MinVersion) {
return fmt.Errorf("%s ios version: %s is not valid", osCheck.Name(), osCheck.Ios.MinVersion)
}
}
if osCheck.Darwin != nil {
if !isVersionValid(osCheck.Darwin.MinVersion) {
return fmt.Errorf("%s darwin version: %s is not valid", osCheck.Name(), osCheck.Darwin.MinVersion)
}
}
if osCheck.Linux != nil {
if !isVersionValid(osCheck.Linux.MinKernelVersion) {
return fmt.Errorf("%s linux kernel version: %s is not valid", osCheck.Name(),
osCheck.Linux.MinKernelVersion)
}
}
if osCheck.Windows != nil {
if !isVersionValid(osCheck.Windows.MinKernelVersion) {
return fmt.Errorf("%s windows kernel version: %s is not valid", osCheck.Name(),
osCheck.Windows.MinKernelVersion)
}
}
// posture check should contain at least one check
if pc.Checks.NBVersionCheck == nil && pc.Checks.OSVersionCheck == nil &&
pc.Checks.GeoLocationCheck == nil && pc.Checks.PeerNetworkRangeCheck == nil && pc.Checks.ProcessCheck == nil {
return errors.New("posture checks shouldn't be empty")
}
if pc.Checks.NBVersionCheck != nil {
if err := pc.Checks.NBVersionCheck.Validate(); err != nil {
return err
}
}
if pc.Checks.OSVersionCheck != nil {
if err := pc.Checks.OSVersionCheck.Validate(); err != nil {
return err
}
}
if pc.Checks.GeoLocationCheck != nil {
if err := pc.Checks.GeoLocationCheck.Validate(); err != nil {
return err
}
}
if pc.Checks.PeerNetworkRangeCheck != nil {
if err := pc.Checks.PeerNetworkRangeCheck.Validate(); err != nil {
return err
}
}
if pc.Checks.ProcessCheck != nil {
if err := pc.Checks.ProcessCheck.Validate(); err != nil {
return err
}
}
return nil
}

View File

@@ -150,9 +150,23 @@ func TestChecks_Validate(t *testing.T) {
checks Checks
expectedError bool
}{
{
name: "Empty name",
checks: Checks{},
expectedError: true,
},
{
name: "Empty checks",
checks: Checks{
Name: "Default",
Checks: ChecksDefinition{},
},
expectedError: true,
},
{
name: "Valid checks version",
checks: Checks{
Name: "default",
Checks: ChecksDefinition{
NBVersionCheck: &NBVersionCheck{
MinVersion: "0.25.0",
@@ -261,6 +275,14 @@ func TestChecks_Copy(t *testing.T) {
},
Action: CheckActionDeny,
},
ProcessCheck: &ProcessCheck{
Processes: []Process{
{
Path: "/Applications/NetBird.app/Contents/MacOS/netbird",
WindowsPath: "C:\\ProgramData\\NetBird\\netbird.exe",
},
},
},
},
}
checkCopy := check.Copy()

View File

@@ -2,6 +2,7 @@ package posture
import (
"fmt"
"slices"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
)
@@ -60,3 +61,28 @@ func (g *GeoLocationCheck) Check(peer nbpeer.Peer) (bool, error) {
func (g *GeoLocationCheck) Name() string {
return GeoLocationCheckName
}
func (g *GeoLocationCheck) Validate() error {
if g.Action == "" {
return fmt.Errorf("%s action shouldn't be empty", g.Name())
}
allowedActions := []string{CheckActionAllow, CheckActionDeny}
if !slices.Contains(allowedActions, g.Action) {
return fmt.Errorf("%s action is not valid", g.Name())
}
if len(g.Locations) == 0 {
return fmt.Errorf("%s locations shouldn't be empty", g.Name())
}
for _, loc := range g.Locations {
if loc.CountryCode == "" {
return fmt.Errorf("%s country code shouldn't be empty", g.Name())
}
if !countryCodeRegex.MatchString(loc.CountryCode) {
return fmt.Errorf("%s country code must be 2 letters (ISO 3166-1 alpha-2 format)", g.Name())
}
}
return nil
}

View File

@@ -236,3 +236,81 @@ func TestGeoLocationCheck_Check(t *testing.T) {
})
}
}
func TestGeoLocationCheck_Validate(t *testing.T) {
testCases := []struct {
name string
check GeoLocationCheck
expectedError bool
}{
{
name: "Valid location list",
check: GeoLocationCheck{
Action: CheckActionAllow,
Locations: []Location{
{
CountryCode: "DE",
CityName: "Berlin",
},
},
},
expectedError: false,
},
{
name: "Invalid empty location list",
check: GeoLocationCheck{
Action: CheckActionDeny,
Locations: []Location{},
},
expectedError: true,
},
{
name: "Invalid empty country name",
check: GeoLocationCheck{
Action: CheckActionDeny,
Locations: []Location{
{
CityName: "Los Angeles",
},
},
},
expectedError: true,
},
{
name: "Invalid check action",
check: GeoLocationCheck{
Action: "unknownAction",
Locations: []Location{
{
CountryCode: "DE",
CityName: "Berlin",
},
},
},
expectedError: true,
},
{
name: "Invalid country code",
check: GeoLocationCheck{
Action: CheckActionAllow,
Locations: []Location{
{
CountryCode: "USA",
},
},
},
expectedError: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := tc.check.Validate()
if tc.expectedError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}

View File

@@ -1,6 +1,8 @@
package posture
import (
"fmt"
"github.com/hashicorp/go-version"
log "github.com/sirupsen/logrus"
@@ -37,3 +39,13 @@ func (n *NBVersionCheck) Check(peer nbpeer.Peer) (bool, error) {
func (n *NBVersionCheck) Name() string {
return NBVersionCheckName
}
func (n *NBVersionCheck) Validate() error {
if n.MinVersion == "" {
return fmt.Errorf("%s minimum version shouldn't be empty", n.Name())
}
if !isVersionValid(n.MinVersion) {
return fmt.Errorf("%s version: %s is not valid", n.Name(), n.MinVersion)
}
return nil
}

View File

@@ -108,3 +108,33 @@ func TestNBVersionCheck_Check(t *testing.T) {
})
}
}
func TestNBVersionCheck_Validate(t *testing.T) {
testCases := []struct {
name string
check NBVersionCheck
expectedError bool
}{
{
name: "Valid NBVersionCheck",
check: NBVersionCheck{MinVersion: "1.0"},
expectedError: false,
},
{
name: "Invalid NBVersionCheck",
check: NBVersionCheck{},
expectedError: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := tc.check.Validate()
if tc.expectedError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}

View File

@@ -6,6 +6,7 @@ import (
"slices"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status"
)
type PeerNetworkRangeCheck struct {
@@ -52,3 +53,19 @@ func (p *PeerNetworkRangeCheck) Check(peer nbpeer.Peer) (bool, error) {
func (p *PeerNetworkRangeCheck) Name() string {
return PeerNetworkRangeCheckName
}
func (p *PeerNetworkRangeCheck) Validate() error {
if p.Action == "" {
return status.Errorf(status.InvalidArgument, "action for peer network range check shouldn't be empty")
}
allowedActions := []string{CheckActionAllow, CheckActionDeny}
if !slices.Contains(allowedActions, p.Action) {
return fmt.Errorf("%s action is not valid", p.Name())
}
if len(p.Ranges) == 0 {
return fmt.Errorf("%s network ranges shouldn't be empty", p.Name())
}
return nil
}

View File

@@ -147,3 +147,52 @@ func TestPeerNetworkRangeCheck_Check(t *testing.T) {
})
}
}
func TestNetworkCheck_Validate(t *testing.T) {
testCases := []struct {
name string
check PeerNetworkRangeCheck
expectedError bool
}{
{
name: "Valid network range",
check: PeerNetworkRangeCheck{
Action: CheckActionAllow,
Ranges: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("10.0.0.0/8"),
},
},
expectedError: false,
},
{
name: "Invalid empty network range",
check: PeerNetworkRangeCheck{
Action: CheckActionDeny,
Ranges: []netip.Prefix{},
},
expectedError: true,
},
{
name: "Invalid check action",
check: PeerNetworkRangeCheck{
Action: "unknownAction",
Ranges: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/8"),
},
},
expectedError: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := tc.check.Validate()
if tc.expectedError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}

View File

@@ -1,11 +1,13 @@
package posture
import (
"fmt"
"strings"
"github.com/hashicorp/go-version"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
log "github.com/sirupsen/logrus"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
)
type MinVersionCheck struct {
@@ -48,6 +50,40 @@ func (c *OSVersionCheck) Name() string {
return OSVersionCheckName
}
func (c *OSVersionCheck) Validate() error {
emptyOS := c.Android == nil && c.Darwin == nil && c.Ios == nil &&
c.Linux == nil && c.Windows == nil
emptyMinVersion := c.Android != nil && c.Android.MinVersion == "" || c.Darwin != nil && c.Darwin.MinVersion == "" ||
c.Ios != nil && c.Ios.MinVersion == "" || c.Linux != nil && c.Linux.MinKernelVersion == "" || c.Windows != nil &&
c.Windows.MinKernelVersion == ""
if emptyOS || emptyMinVersion {
return fmt.Errorf("%s minimum version for at least one OS shouldn't be empty", c.Name())
}
if c.Android != nil && !isVersionValid(c.Android.MinVersion) {
return fmt.Errorf("%s android version: %s is not valid", c.Name(), c.Android.MinVersion)
}
if c.Ios != nil && !isVersionValid(c.Ios.MinVersion) {
return fmt.Errorf("%s ios version: %s is not valid", c.Name(), c.Ios.MinVersion)
}
if c.Darwin != nil && !isVersionValid(c.Darwin.MinVersion) {
return fmt.Errorf("%s darwin version: %s is not valid", c.Name(), c.Darwin.MinVersion)
}
if c.Linux != nil && !isVersionValid(c.Linux.MinKernelVersion) {
return fmt.Errorf("%s linux kernel version: %s is not valid", c.Name(),
c.Linux.MinKernelVersion)
}
if c.Windows != nil && !isVersionValid(c.Windows.MinKernelVersion) {
return fmt.Errorf("%s windows kernel version: %s is not valid", c.Name(),
c.Windows.MinKernelVersion)
}
return nil
}
func checkMinVersion(peerGoOS, peerVersion string, check *MinVersionCheck) (bool, error) {
if check == nil {
log.Debugf("peer %s OS is not allowed in the check", peerGoOS)

View File

@@ -150,3 +150,79 @@ func TestOSVersionCheck_Check(t *testing.T) {
})
}
}
func TestOSVersionCheck_Validate(t *testing.T) {
testCases := []struct {
name string
check OSVersionCheck
expectedError bool
}{
{
name: "Valid linux kernel version",
check: OSVersionCheck{
Linux: &MinKernelVersionCheck{MinKernelVersion: "6.0"},
},
expectedError: false,
},
{
name: "Valid linux and darwin version",
check: OSVersionCheck{
Linux: &MinKernelVersionCheck{MinKernelVersion: "6.0"},
Darwin: &MinVersionCheck{MinVersion: "14.2"},
},
expectedError: false,
},
{
name: "Invalid empty check",
check: OSVersionCheck{},
expectedError: true,
},
{
name: "Invalid empty linux kernel version",
check: OSVersionCheck{
Linux: &MinKernelVersionCheck{},
},
expectedError: true,
},
{
name: "Invalid empty linux kernel version with correct darwin version",
check: OSVersionCheck{
Linux: &MinKernelVersionCheck{},
Darwin: &MinVersionCheck{MinVersion: "14.2"},
},
expectedError: true,
},
{
name: "Valid windows kernel version",
check: OSVersionCheck{
Windows: &MinKernelVersionCheck{MinKernelVersion: "10.0"},
},
expectedError: false,
},
{
name: "Valid ios minimum version",
check: OSVersionCheck{
Ios: &MinVersionCheck{MinVersion: "13.0"},
},
expectedError: false,
},
{
name: "Invalid empty window version with valid ios minimum version",
check: OSVersionCheck{
Windows: &MinKernelVersionCheck{},
Ios: &MinVersionCheck{MinVersion: "13.0"},
},
expectedError: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := tc.check.Validate()
if tc.expectedError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}

View File

@@ -0,0 +1,64 @@
package posture
import (
"fmt"
"slices"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
)
type Process struct {
Path string
WindowsPath string
}
type ProcessCheck struct {
Processes []Process
}
var _ Check = (*ProcessCheck)(nil)
func (p *ProcessCheck) Check(peer nbpeer.Peer) (bool, error) {
peerActiveProcesses := make([]string, 0, len(peer.Meta.Files))
for _, file := range peer.Meta.Files {
if file.ProcessIsRunning {
peerActiveProcesses = append(peerActiveProcesses, file.Path)
}
}
switch peer.Meta.GoOS {
case "darwin", "linux":
for _, process := range p.Processes {
if process.Path == "" || !slices.Contains(peerActiveProcesses, process.Path) {
return false, nil
}
}
return true, nil
case "windows":
for _, process := range p.Processes {
if process.WindowsPath == "" || !slices.Contains(peerActiveProcesses, process.WindowsPath) {
return false, nil
}
}
return true, nil
default:
return false, fmt.Errorf("unsupported peer's operating system: %s", peer.Meta.GoOS)
}
}
func (p *ProcessCheck) Name() string {
return ProcessCheckName
}
func (p *ProcessCheck) Validate() error {
if len(p.Processes) == 0 {
return fmt.Errorf("%s processes shouldn't be empty", p.Name())
}
for _, process := range p.Processes {
if process.Path == "" && process.WindowsPath == "" {
return fmt.Errorf("%s path shouldn't be empty", p.Name())
}
}
return nil
}

View File

@@ -0,0 +1,305 @@
package posture
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server/peer"
)
func TestProcessCheck_Check(t *testing.T) {
tests := []struct {
name string
input peer.Peer
check ProcessCheck
wantErr bool
isValid bool
}{
{
name: "darwin with matching running processes",
input: peer.Peer{
Meta: peer.PeerSystemMeta{
GoOS: "darwin",
Files: []peer.File{
{Path: "/Applications/process1.app", ProcessIsRunning: true},
{Path: "/Applications/process2.app", ProcessIsRunning: true},
},
},
},
check: ProcessCheck{
Processes: []Process{
{Path: "/Applications/process1.app"},
{Path: "/Applications/process2.app"},
},
},
wantErr: false,
isValid: true,
},
{
name: "darwin with windows process paths",
input: peer.Peer{
Meta: peer.PeerSystemMeta{
GoOS: "darwin",
Files: []peer.File{
{Path: "/Applications/process1.app", ProcessIsRunning: true},
{Path: "/Applications/process2.app", ProcessIsRunning: true},
},
},
},
check: ProcessCheck{
Processes: []Process{
{WindowsPath: "C:\\Program Files\\process1.exe"},
{WindowsPath: "C:\\Program Files\\process2.exe"},
},
},
wantErr: false,
isValid: false,
},
{
name: "linux with matching running processes",
input: peer.Peer{
Meta: peer.PeerSystemMeta{
GoOS: "linux",
Files: []peer.File{
{Path: "/usr/bin/process1", ProcessIsRunning: true},
{Path: "/usr/bin/process2", ProcessIsRunning: true},
},
},
},
check: ProcessCheck{
Processes: []Process{
{Path: "/usr/bin/process1"},
{Path: "/usr/bin/process2"},
},
},
wantErr: false,
isValid: true,
},
{
name: "linux with matching no running processes",
input: peer.Peer{
Meta: peer.PeerSystemMeta{
GoOS: "linux",
Files: []peer.File{
{Path: "/usr/bin/process1", ProcessIsRunning: true},
{Path: "/usr/bin/process2", ProcessIsRunning: false},
},
},
},
check: ProcessCheck{
Processes: []Process{
{Path: "/usr/bin/process1"},
{Path: "/usr/bin/process2"},
},
},
wantErr: false,
isValid: false,
},
{
name: "linux with windows process paths",
input: peer.Peer{
Meta: peer.PeerSystemMeta{
GoOS: "linux",
Files: []peer.File{
{Path: "/usr/bin/process1", ProcessIsRunning: true},
{Path: "/usr/bin/process2"},
},
},
},
check: ProcessCheck{
Processes: []Process{
{WindowsPath: "C:\\Program Files\\process1.exe"},
{WindowsPath: "C:\\Program Files\\process2.exe"},
},
},
wantErr: false,
isValid: false,
},
{
name: "linux with non-matching processes",
input: peer.Peer{
Meta: peer.PeerSystemMeta{
GoOS: "linux",
Files: []peer.File{
{Path: "/usr/bin/process3"},
{Path: "/usr/bin/process4"},
},
},
},
check: ProcessCheck{
Processes: []Process{
{Path: "/usr/bin/process1"},
{Path: "/usr/bin/process2"},
},
},
wantErr: false,
isValid: false,
},
{
name: "windows with matching running processes",
input: peer.Peer{
Meta: peer.PeerSystemMeta{
GoOS: "windows",
Files: []peer.File{
{Path: "C:\\Program Files\\process1.exe", ProcessIsRunning: true},
{Path: "C:\\Program Files\\process1.exe", ProcessIsRunning: true},
},
},
},
check: ProcessCheck{
Processes: []Process{
{WindowsPath: "C:\\Program Files\\process1.exe"},
{WindowsPath: "C:\\Program Files\\process1.exe"},
},
},
wantErr: false,
isValid: true,
},
{
name: "windows with darwin process paths",
input: peer.Peer{
Meta: peer.PeerSystemMeta{
GoOS: "windows",
Files: []peer.File{
{Path: "C:\\Program Files\\process1.exe"},
{Path: "C:\\Program Files\\process1.exe"},
},
},
},
check: ProcessCheck{
Processes: []Process{
{Path: "/Applications/process1.app"},
{Path: "/Applications/process2.app"},
},
},
wantErr: false,
isValid: false,
},
{
name: "windows with non-matching processes",
input: peer.Peer{
Meta: peer.PeerSystemMeta{
GoOS: "windows",
Files: []peer.File{
{Path: "C:\\Program Files\\process3.exe"},
{Path: "C:\\Program Files\\process4.exe"},
},
},
},
check: ProcessCheck{
Processes: []Process{
{WindowsPath: "C:\\Program Files\\process1.exe"},
{WindowsPath: "C:\\Program Files\\process2.exe"},
},
},
wantErr: false,
isValid: false,
},
{
name: "unsupported ios operating system",
input: peer.Peer{
Meta: peer.PeerSystemMeta{
GoOS: "ios",
},
},
check: ProcessCheck{
Processes: []Process{
{Path: "C:\\Program Files\\process1.exe"},
{Path: "C:\\Program Files\\process2.exe"},
},
},
wantErr: true,
isValid: false,
},
{
name: "unsupported android operating system with matching processes",
input: peer.Peer{
Meta: peer.PeerSystemMeta{
GoOS: "android",
},
},
check: ProcessCheck{
Processes: []Process{
{Path: "/usr/bin/process1"},
{Path: "/usr/bin/process2"},
},
},
wantErr: true,
isValid: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
isValid, err := tt.check.Check(tt.input)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
assert.Equal(t, tt.isValid, isValid)
})
}
}
func TestProcessCheck_Validate(t *testing.T) {
testCases := []struct {
name string
check ProcessCheck
expectedError bool
}{
{
name: "Valid unix and windows processes",
check: ProcessCheck{
Processes: []Process{
{
Path: "/usr/local/bin/netbird",
WindowsPath: "C:\\ProgramData\\NetBird\\netbird.exe",
},
},
},
expectedError: false,
},
{
name: "Valid unix process",
check: ProcessCheck{
Processes: []Process{
{
Path: "/usr/local/bin/netbird",
},
},
},
expectedError: false,
},
{
name: "Valid windows process",
check: ProcessCheck{
Processes: []Process{
{
WindowsPath: "C:\\ProgramData\\NetBird\\netbird.exe",
},
},
},
expectedError: false,
},
{
name: "Invalid empty processes",
check: ProcessCheck{
Processes: []Process{},
},
expectedError: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := tc.check.Validate()
if tc.expectedError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}

View File

@@ -52,7 +52,7 @@ func (am *DefaultAccountManager) SavePostureChecks(accountID, userID string, pos
}
if err := postureChecks.Validate(); err != nil {
return status.Errorf(status.BadRequest, err.Error())
return status.Errorf(status.InvalidArgument, err.Error())
}
exists, uniqName := am.savePostureChecks(account, postureChecks)

View File

@@ -95,18 +95,18 @@ func (wm *DefaultScheduler) Schedule(in time.Duration, ID string, job func() (ne
case <-ticker.C:
select {
case <-cancel:
log.Debugf("scheduled job %s was canceled, stop timer", ID)
log.Tracef("scheduled job %s was canceled, stop timer", ID)
ticker.Stop()
return
default:
log.Debugf("time to do a scheduled job %s", ID)
log.Tracef("time to do a scheduled job %s", ID)
}
runIn, reschedule := job()
if !reschedule {
wm.mu.Lock()
defer wm.mu.Unlock()
delete(wm.jobs, ID)
log.Debugf("job %s is not scheduled to run again", ID)
log.Tracef("job %s is not scheduled to run again", ID)
ticker.Stop()
return
}
@@ -115,7 +115,7 @@ func (wm *DefaultScheduler) Schedule(in time.Duration, ID string, job func() (ne
ticker.Reset(runIn)
}
case <-cancel:
log.Debugf("job %s was canceled, stopping timer", ID)
log.Tracef("job %s was canceled, stopping timer", ID)
ticker.Stop()
return
}

View File

@@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"path/filepath"
"reflect"
"runtime"
"strings"
"sync"
@@ -134,72 +135,139 @@ func (s *SqliteStore) AcquireAccountLock(accountID string) (unlock func()) {
return unlock
}
func batchInsert(records interface{}, batchSize int, tx *gorm.DB) error {
// Get the reflect.Value of the records slice
v := reflect.ValueOf(records)
if v.Kind() != reflect.Slice {
return fmt.Errorf("provided input is not a slice")
}
// Insert records in batches
for i := 0; i < v.Len(); i += batchSize {
end := i + batchSize
if end > v.Len() {
end = v.Len()
}
// Use reflect.Slice to get a slice of the records for the current batch
batch := v.Slice(i, end).Interface()
if err := tx.CreateInBatches(batch, end-i).Debug().Error; err != nil {
return err
}
}
return nil
}
func (s *SqliteStore) SaveAccount(account *Account) error {
start := time.Now()
for _, key := range account.SetupKeys {
account.SetupKeysG = append(account.SetupKeysG, *key)
// operate over a fresh copy as we will modify its fields
accCopy := account.Copy()
accCopy.SetupKeysG = make([]SetupKey, 0, len(accCopy.SetupKeys))
for _, key := range accCopy.SetupKeys {
//we need an explicit reference to the account for gorm
key.AccountID = accCopy.Id
accCopy.SetupKeysG = append(accCopy.SetupKeysG, *key)
}
for id, peer := range account.Peers {
accCopy.PeersG = make([]nbpeer.Peer, 0, len(accCopy.Peers))
for id, peer := range accCopy.Peers {
peer.ID = id
account.PeersG = append(account.PeersG, *peer)
//we need an explicit reference to the account for gorm
peer.AccountID = accCopy.Id
accCopy.PeersG = append(accCopy.PeersG, *peer)
}
for id, user := range account.Users {
accCopy.UsersG = make([]User, 0, len(accCopy.Users))
for id, user := range accCopy.Users {
user.Id = id
//we need an explicit reference to the account for gorm
user.AccountID = accCopy.Id
user.PATsG = make([]PersonalAccessToken, 0, len(user.PATs))
for id, pat := range user.PATs {
pat.ID = id
user.PATsG = append(user.PATsG, *pat)
}
account.UsersG = append(account.UsersG, *user)
accCopy.UsersG = append(accCopy.UsersG, *user)
}
for id, group := range account.Groups {
accCopy.GroupsG = make([]nbgroup.Group, 0, len(accCopy.Groups))
for id, group := range accCopy.Groups {
group.ID = id
account.GroupsG = append(account.GroupsG, *group)
//we need an explicit reference to the account for gorm
group.AccountID = accCopy.Id
accCopy.GroupsG = append(accCopy.GroupsG, *group)
}
for id, route := range account.Routes {
accCopy.RoutesG = make([]route.Route, 0, len(accCopy.Routes))
for id, route := range accCopy.Routes {
route.ID = id
account.RoutesG = append(account.RoutesG, *route)
//we need an explicit reference to the account for gorm
route.AccountID = accCopy.Id
accCopy.RoutesG = append(accCopy.RoutesG, *route)
}
for id, ns := range account.NameServerGroups {
accCopy.NameServerGroupsG = make([]nbdns.NameServerGroup, 0, len(accCopy.NameServerGroups))
for id, ns := range accCopy.NameServerGroups {
ns.ID = id
account.NameServerGroupsG = append(account.NameServerGroupsG, *ns)
//we need an explicit reference to the account for gorm
ns.AccountID = accCopy.Id
accCopy.NameServerGroupsG = append(accCopy.NameServerGroupsG, *ns)
}
err := s.db.Transaction(func(tx *gorm.DB) error {
result := tx.Select(clause.Associations).Delete(account.Policies, "account_id = ?", account.Id)
result := tx.Select(clause.Associations).Delete(accCopy.Policies, "account_id = ?", accCopy.Id)
if result.Error != nil {
return result.Error
}
result = tx.Select(clause.Associations).Delete(account.UsersG, "account_id = ?", account.Id)
result = tx.Select(clause.Associations).Delete(accCopy.UsersG, "account_id = ?", accCopy.Id)
if result.Error != nil {
return result.Error
}
result = tx.Select(clause.Associations).Delete(account)
result = tx.Select(clause.Associations).Delete(accCopy)
if result.Error != nil {
return result.Error
}
result = tx.
Session(&gorm.Session{FullSaveAssociations: true}).
Clauses(clause.OnConflict{UpdateAll: true}).Create(account)
Clauses(clause.OnConflict{UpdateAll: true}).
Omit("PeersG", "GroupsG", "UsersG", "SetupKeysG", "RoutesG", "NameServerGroupsG").
Create(accCopy)
if result.Error != nil {
return result.Error
}
return nil
const batchSize = 500
err := batchInsert(accCopy.PeersG, batchSize, tx)
if err != nil {
return err
}
err = batchInsert(accCopy.UsersG, batchSize, tx)
if err != nil {
return err
}
err = batchInsert(accCopy.GroupsG, batchSize, tx)
if err != nil {
return err
}
err = batchInsert(accCopy.RoutesG, batchSize, tx)
if err != nil {
return err
}
err = batchInsert(accCopy.SetupKeysG, batchSize, tx)
if err != nil {
return err
}
return batchInsert(accCopy.NameServerGroupsG, batchSize, tx)
})
took := time.Since(start)
if s.metrics != nil {
s.metrics.StoreMetrics().CountPersistenceDuration(took)
}
log.Debugf("took %d ms to persist an account to the SQLite", took.Milliseconds())
log.Debugf("took %d ms to persist an account %s to the SQLite store", took.Milliseconds(), accCopy.Id)
return err
}
@@ -207,6 +275,19 @@ func (s *SqliteStore) SaveAccount(account *Account) error {
func (s *SqliteStore) DeleteAccount(account *Account) error {
start := time.Now()
account.UsersG = make([]User, 0, len(account.Users))
for id, user := range account.Users {
user.Id = id
//we need an explicit reference to an account as it is missing for some reason
user.AccountID = account.Id
user.PATsG = make([]PersonalAccessToken, 0, len(user.PATs))
for id, pat := range user.PATs {
pat.ID = id
user.PATsG = append(user.PATsG, *pat)
}
account.UsersG = append(account.UsersG, *user)
}
err := s.db.Transaction(func(tx *gorm.DB) error {
result := tx.Select(clause.Associations).Delete(account.Policies, "account_id = ?", account.Id)
if result.Error != nil {

View File

@@ -2,7 +2,12 @@ package server
import (
"fmt"
nbdns "github.com/netbirdio/netbird/dns"
nbgroup "github.com/netbirdio/netbird/management/server/group"
route2 "github.com/netbirdio/netbird/route"
"math/rand"
"net"
"net/netip"
"path/filepath"
"runtime"
"testing"
@@ -29,6 +34,141 @@ func TestSqlite_NewStore(t *testing.T) {
t.Errorf("expected to create a new empty Accounts map when creating a new FileStore")
}
}
func TestSqlite_SaveAccount_Large(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store := newSqliteStore(t)
account := newAccountWithId("account_id", "testuser", "")
groupALL, err := account.GetGroupAll()
if err != nil {
t.Fatal(err)
}
setupKey := GenerateDefaultSetupKey()
account.SetupKeys[setupKey.Key] = setupKey
const numPerAccount = 2000
for n := 0; n < numPerAccount; n++ {
netIP := randomIPv4()
peerID := fmt.Sprintf("%s-peer-%d", account.Id, n)
peer := &nbpeer.Peer{
ID: peerID,
Key: peerID,
SetupKey: "",
IP: netIP,
Name: peerID,
DNSLabel: peerID,
UserID: userID,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()},
SSHEnabled: false,
}
account.Peers[peerID] = peer
group, _ := account.GetGroupAll()
group.Peers = append(group.Peers, peerID)
user := &User{
Id: fmt.Sprintf("%s-user-%d", account.Id, n),
AccountID: account.Id,
}
account.Users[user.Id] = user
route := &route2.Route{
ID: fmt.Sprintf("network-id-%d", n),
Description: "base route",
NetID: fmt.Sprintf("network-id-%d", n),
Network: netip.MustParsePrefix(netIP.String() + "/24"),
NetworkType: route2.IPv4Network,
Metric: 9999,
Masquerade: false,
Enabled: true,
Groups: []string{groupALL.ID},
}
account.Routes[route.ID] = route
group = &nbgroup.Group{
ID: fmt.Sprintf("group-id-%d", n),
AccountID: account.Id,
Name: fmt.Sprintf("group-id-%d", n),
Issued: "api",
Peers: nil,
}
account.Groups[group.ID] = group
nameserver := &nbdns.NameServerGroup{
ID: fmt.Sprintf("nameserver-id-%d", n),
AccountID: account.Id,
Name: fmt.Sprintf("nameserver-id-%d", n),
Description: "",
NameServers: []nbdns.NameServer{{IP: netip.MustParseAddr(netIP.String()), NSType: nbdns.UDPNameServerType}},
Groups: []string{group.ID},
Primary: false,
Domains: nil,
Enabled: false,
SearchDomainsEnabled: false,
}
account.NameServerGroups[nameserver.ID] = nameserver
setupKey := GenerateDefaultSetupKey()
account.SetupKeys[setupKey.Key] = setupKey
}
err = store.SaveAccount(account)
require.NoError(t, err)
if len(store.GetAllAccounts()) != 1 {
t.Errorf("expecting 1 Accounts to be stored after SaveAccount()")
}
a, err := store.GetAccount(account.Id)
if a == nil {
t.Errorf("expecting Account to be stored after SaveAccount(): %v", err)
}
if a != nil && len(a.Policies) != 1 {
t.Errorf("expecting Account to have one policy stored after SaveAccount(), got %d", len(a.Policies))
}
if a != nil && len(a.Policies[0].Rules) != 1 {
t.Errorf("expecting Account to have one policy rule stored after SaveAccount(), got %d", len(a.Policies[0].Rules))
return
}
if a != nil && len(a.Peers) != numPerAccount {
t.Errorf("expecting Account to have %d peers stored after SaveAccount(), got %d",
numPerAccount, len(a.Peers))
return
}
if a != nil && len(a.Users) != numPerAccount+1 {
t.Errorf("expecting Account to have %d users stored after SaveAccount(), got %d",
numPerAccount+1, len(a.Users))
return
}
if a != nil && len(a.Routes) != numPerAccount {
t.Errorf("expecting Account to have %d routes stored after SaveAccount(), got %d",
numPerAccount, len(a.Routes))
return
}
if a != nil && len(a.NameServerGroups) != numPerAccount {
t.Errorf("expecting Account to have %d NameServerGroups stored after SaveAccount(), got %d",
numPerAccount, len(a.NameServerGroups))
return
}
if a != nil && len(a.NameServerGroups) != numPerAccount {
t.Errorf("expecting Account to have %d NameServerGroups stored after SaveAccount(), got %d",
numPerAccount, len(a.NameServerGroups))
return
}
if a != nil && len(a.SetupKeys) != numPerAccount+1 {
t.Errorf("expecting Account to have %d SetupKeys stored after SaveAccount(), got %d",
numPerAccount+1, len(a.SetupKeys))
return
}
}
func TestSqlite_SaveAccount(t *testing.T) {
if runtime.GOOS == "windows" {
@@ -48,6 +188,12 @@ func TestSqlite_SaveAccount(t *testing.T) {
Name: "peer name",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
}
admin := account.Users["testuser"]
admin.PATs = map[string]*PersonalAccessToken{"testtoken": {
ID: "testtoken",
Name: "test token",
HashedToken: "hashed token",
}}
err := store.SaveAccount(account)
require.NoError(t, err)
@@ -110,7 +256,7 @@ func TestSqlite_DeleteAccount(t *testing.T) {
store := newSqliteStore(t)
testUserID := "testuser"
user := NewAdminUser(testUserID)
user := NewAdminUser(testUserID, "account_id")
user.PATs = map[string]*PersonalAccessToken{"testtoken": {
ID: "testtoken",
Name: "test token",
@@ -393,3 +539,12 @@ func newAccount(store Store, id int) error {
return store.SaveAccount(account)
}
func randomIPv4() net.IP {
rand.New(rand.NewSource(time.Now().UnixNano()))
b := make([]byte, 4)
for i := range b {
b[i] = byte(rand.Intn(256))
}
return net.IP(b)
}

View File

@@ -180,9 +180,11 @@ func (u *User) Copy() *User {
}
// NewUser creates a new user
func NewUser(id string, role UserRole, isServiceUser bool, nonDeletable bool, serviceUserName string, autoGroups []string, issued string) *User {
func NewUser(ID string, role UserRole, isServiceUser bool, nonDeletable bool, serviceUserName string, autoGroups []string, issued string,
accountID string) *User {
return &User{
Id: id,
Id: ID,
AccountID: accountID,
Role: role,
IsServiceUser: isServiceUser,
NonDeletable: nonDeletable,
@@ -194,22 +196,26 @@ func NewUser(id string, role UserRole, isServiceUser bool, nonDeletable bool, se
}
// NewRegularUser creates a new user with role UserRoleUser
func NewRegularUser(id string) *User {
return NewUser(id, UserRoleUser, false, false, "", []string{}, UserIssuedAPI)
func NewRegularUser(ID, accountID string) *User {
return NewUser(ID, UserRoleUser, false, false, "", []string{}, UserIssuedAPI,
accountID)
}
// NewAdminUser creates a new user with role UserRoleAdmin
func NewAdminUser(id string) *User {
return NewUser(id, UserRoleAdmin, false, false, "", []string{}, UserIssuedAPI)
func NewAdminUser(ID, accountID string) *User {
return NewUser(ID, UserRoleAdmin, false, false, "", []string{}, UserIssuedAPI,
accountID)
}
// NewOwnerUser creates a new user with role UserRoleOwner
func NewOwnerUser(id string) *User {
return NewUser(id, UserRoleOwner, false, false, "", []string{}, UserIssuedAPI)
func NewOwnerUser(ID, accountID string) *User {
return NewUser(ID, UserRoleOwner, false, false, "", []string{}, UserIssuedAPI,
accountID)
}
// createServiceUser creates a new service user under the given account.
func (am *DefaultAccountManager) createServiceUser(accountID string, initiatorUserID string, role UserRole, serviceUserName string, nonDeletable bool, autoGroups []string) (*UserInfo, error) {
func (am *DefaultAccountManager) createServiceUser(accountID string, initiatorUserID string, role UserRole,
serviceUserName string, nonDeletable bool, autoGroups []string) (*UserInfo, error) {
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
@@ -231,7 +237,7 @@ func (am *DefaultAccountManager) createServiceUser(accountID string, initiatorUs
}
newUserID := uuid.New().String()
newUser := NewUser(newUserID, role, true, nonDeletable, serviceUserName, autoGroups, UserIssuedAPI)
newUser := NewUser(newUserID, role, true, nonDeletable, serviceUserName, autoGroups, UserIssuedAPI, accountID)
log.Debugf("New User: %v", newUser)
account.Users[newUserID] = newUser

View File

@@ -679,8 +679,8 @@ func TestDefaultAccountManager_GetUser(t *testing.T) {
func TestDefaultAccountManager_ListUsers(t *testing.T) {
store := newStore(t)
account := newAccountWithId(mockAccountID, mockUserID, "")
account.Users["normal_user1"] = NewRegularUser("normal_user1")
account.Users["normal_user2"] = NewRegularUser("normal_user2")
account.Users["normal_user1"] = NewRegularUser("normal_user1", mockAccountID)
account.Users["normal_user2"] = NewRegularUser("normal_user2", mockAccountID)
err := store.SaveAccount(account)
if err != nil {
@@ -760,7 +760,7 @@ func TestDefaultAccountManager_ListUsers_DashboardPermissions(t *testing.T) {
t.Run(testCase.name, func(t *testing.T) {
store := newStore(t)
account := newAccountWithId(mockAccountID, mockUserID, "")
account.Users["normal_user1"] = NewUser("normal_user1", testCase.role, false, false, "", []string{}, UserIssuedAPI)
account.Users["normal_user1"] = NewUser("normal_user1", testCase.role, false, false, "", []string{}, UserIssuedAPI, mockAccountID)
account.Settings.RegularUsersViewBlocked = testCase.limitedViewSettings
delete(account.Users, mockUserID)
@@ -844,10 +844,10 @@ func TestDefaultAccountManager_ExternalCache(t *testing.T) {
func TestUser_IsAdmin(t *testing.T) {
user := NewAdminUser(mockUserID)
user := NewAdminUser(mockUserID, mockAccountID)
assert.True(t, user.HasAdminPower())
user = NewRegularUser(mockUserID)
user = NewRegularUser(mockUserID, mockAccountID)
assert.False(t, user.HasAdminPower())
}
@@ -1055,8 +1055,8 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) {
}
// create other users
account.Users[regularUserID] = NewRegularUser(regularUserID)
account.Users[adminUserID] = NewAdminUser(adminUserID)
account.Users[regularUserID] = NewRegularUser(regularUserID, account.Id)
account.Users[adminUserID] = NewAdminUser(adminUserID, account.Id)
account.Users[serviceUserID] = &User{IsServiceUser: true, Id: serviceUserID, Role: UserRoleAdmin, ServiceUserName: "service"}
err = manager.Store.SaveAccount(account)
if err != nil {