Compare commits

...

35 Commits

Author SHA1 Message Date
bcmmbaga
72513d7522 Skip network map calculation when client serial matches current
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2026-01-27 23:03:43 +03:00
Zoltan Papp
a1f1bf1f19 Merge branch 'main' into feat/network-map-serial 2025-12-18 15:59:53 +01:00
Zoltan Papp
b5dec3df39 Track network serial in engine 2025-12-18 15:27:49 +01:00
Zoltan Papp
447cd287f5 [ci] Add local lint setup with pre-push hook to catch issues early (#4925)
* Add local lint setup with pre-push hook to catch issues early

Developers can now catch lint issues before pushing, reducing CI failures
and iteration time. The setup uses golangci-lint locally with the same
configuration as CI.

Setup:
- Run `make setup-hooks` once after cloning
- Pre-push hook automatically lints changed files (~90s)
- Use `make lint` to manually check changed files
- Use `make lint-all` to run full CI-equivalent lint

The Makefile auto-installs golangci-lint to ./bin/ using go install to
match the Go version in go.mod, avoiding version compatibility issues.

---------

Co-authored-by: mlsmaycon <mlsmaycon@gmail.com>
2025-12-15 10:34:48 +01:00
Zoltan Papp
5748bdd64e Add health-check agent recognition to avoid error logs (#4917)
Health-check connections now send a properly formatted auth message
with a well-known peer ID instead of immediately closing. The server
recognizes this peer ID and handles the connection gracefully with a
debug log instead of error logs.
2025-12-15 10:28:25 +01:00
Diego Romar
08f31fbcb3 [iOS] Add force relay connection on iOS (#4928)
* [ios] Add a bogus test to check iOS behavior when setting environment variables

* [ios] Revert "Add a bogus test to check iOS behavior when setting environment variables"

This reverts commit 90ca01105a6b0f4471aac07a63fc95e5d4eaef9b.

* [ios] Add EnvList struct to export and import environment variables

* [ios] Add envList parameter to the iOS Client Run method

* [ios] Add some debug logging to exportEnvVarList

* Add "//go:build ios" to client/ios/NetBirdSDK files
2025-12-12 14:29:58 -03:00
Bethuel Mmbaga
932c02eaab [management] Approve all pending peers when peer approval is disabled (#4806) 2025-12-12 18:49:57 +03:00
Pascal Fischer
abcbde26f9 [management] remove context from store methods (#4940) 2025-12-11 21:45:47 +01:00
Pascal Fischer
90e3b8009f [management] Fix sync metrics (#4939) 2025-12-11 20:11:12 +01:00
Pascal Fischer
94d34dc0c5 [management] monitoring updates (#4937) 2025-12-11 18:29:15 +01:00
Pascal Fischer
44851e06fb [management] cleanup logs (#4933) 2025-12-10 19:26:51 +01:00
Viktor Liu
3f4f825ec1 [client] Fix DNS forwarder returning broken records on 4 to 6 mapped IP addresses (#4887) 2025-12-05 17:42:49 +01:00
Viktor Liu
f538e6e9ae [client] Use setsid to avoid the parent process from being killed via HUP by login (#4900) 2025-12-05 03:29:27 +01:00
Maycon Santos
cb6b086164 [client] Reorder subsystem shutdown so peer removal goes first (#4914)
Remove peers before DNS and routes
2025-12-04 21:01:22 +01:00
Zoltan Papp
71b6855e09 [client] Fix engine shutdown deadlock and sync-signal message handling races (#4891)
* Fix engine shutdown deadlock and message handling races

- Release syncMsgMux before waiting for shutdownWg to prevent deadlock
- Check context inside lock in handleSync and receiveSignalEvents
- Prevents nil pointer access when messages arrive during engine stop
2025-12-04 19:51:50 +01:00
Viktor Liu
9bdc4908fb [client] Passthrough all non-NetBird chains to prevent them from dropping NetBird traffic (#4899) 2025-12-04 19:16:38 +01:00
Bethuel Mmbaga
031ab11178 [client] Remove select account prompt (#4912)
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-12-04 14:57:29 +01:00
Zoltan Papp
d2e48d4f5e [relay] Use instanceURL instead of Exposed address. (#4905)
Replaces string-based exposed address handling with URL-based InstanceURL() (type url.URL) across relay/server and relay/healthcheck; adds SchemeREL/SchemeRELS constants; updates getInstanceURL to return *url.URL with scheme and TLS validation; adjusts WS dialing and health-check logic to use URL fields.
2025-12-03 18:42:53 +01:00
Bethuel Mmbaga
27dd97c9c4 [management] Add support to disable geolocation service (#4901) 2025-12-03 14:45:59 +03:00
Maycon Santos
e87b4ace11 [client] Add sleep state tracking to handle wakeup/sleep events reliably (#4894)
Adds a new NotifyOSLifecycle RPC and server handler to centralize OS sleep/wake handling, introduces Server.sleepTriggeredDown for coordination, updates client UI to call the new RPC, and adjusts the internal sleep event enum zero-value semantics.
2025-12-03 11:53:39 +01:00
Pascal Fischer
a232cf614c [management] record pat usage metrics (#4888) 2025-12-02 18:31:59 +01:00
Maycon Santos
a293f760af [client] Add conditional peer removal logic during shutdown (#4897) 2025-12-02 16:30:15 +01:00
Pascal Fischer
10e9cf8c62 [management] update management integrations (#4895) 2025-12-02 14:13:01 +01:00
Pascal Fischer
7193bd2da7 [management] Refactor network map controller (#4789) 2025-12-02 12:34:28 +01:00
Bethuel Mmbaga
52948ccd61 [management] Add user created activity event (#4893) 2025-12-02 14:17:59 +03:00
Fahri Shihab
4b77359042 [management] Groups API with name query parameter (#4831) 2025-12-01 16:57:42 +01:00
Zoltan Papp
387d43bcc1 [client, management] Add OAuth select_account prompt support to PKCE flow (#4880)
* Add OAuth select_account prompt support to PKCE flow

Extends LoginFlag enum with select_account options to enable
multi-account selection during authentication. This allows users
to choose which account to use when multiple accounts have active
sessions with the identity provider.

The new flags are backward compatible - existing LoginFlag values
(0=prompt login, 1=max_age=0) retain their original behavior.
2025-12-01 14:25:52 +01:00
Zoltan Papp
e47d815dd2 Fix IsAnotherProcessRunning (#4858)
Compare the exact process name rather than searching for a substring of the full path
2025-12-01 14:16:03 +01:00
shuuri-labs
cb83b7c0d3 [relay] use exposed address for healthcheck TLS validation (#4872)
* fix(relay): use exposed address for healthcheck TLS validation

Healthcheck was using listen address (0.0.0.0) instead of exposed address
(domain name) for certificate validation, causing validation to always fail.

Now correctly uses the exposed address where the TLS certificate is valid,
matching real client connection behavior.

* - store exposedAddress directly in Relay struct instead of parsing on every call
- remove unused parseHostPort() function
- remove unused ListenAddress() method from ServiceChecker interface
- improve error logging with address context

* [relay/healthcheck] Remove QUIC health check logic, update WebSocket validation flow

Refactored health check logic by removing QUIC-specific connection validation and simplifying logic for WebSocket protocol. Adjusted certificate validation flow and improved handling of exposed addresses.

* [relay/healthcheck] Fix certificate validation status during health check

---------

Co-authored-by: Maycon Santos <mlsmaycon@gmail.com>
2025-11-28 21:53:53 +01:00
Zoltan Papp
ddcd182859 [client] Sleep detection on macOS (#4859)
A macOS-specific sleep detection mechanism using IOKit and CoreFoundation via cgo is introduced, with a fallback implementation for unsupported platforms. A public Service wrapper provides an event-driven API translating system sleep/wake events into gRPC calls. The UI client integrates sleep detection to manage connectivity state based on system sleep status.
2025-11-28 17:26:22 +01:00
Hakan Sariman
20f5f00635 [client] Add unit tests for engine synchronization and Info flag copying
- Introduced tests for the Engine's handleSync method to verify behavior when SkipNetworkMapUpdate is true and when NetworkMap is nil.
- Added a test for the Info struct to ensure correct copying of flag values from one instance to another, while preserving unrelated fields.
2025-10-17 10:03:07 +03:00
Hakan Sariman
fc141cf3a3 [client] Refactor lastNetworkMapSerial handling in GrpcClient
- Removed atomic operations for lastNetworkMapSerial and replaced them with mutex-based methods for thread-safe access.
2025-09-29 18:49:23 +07:00
Hakan Sariman
d0c65fa08e [client] Add skipNetworkMapUpdate field to SyncResponse for conditional updates 2025-09-29 18:28:14 +07:00
Hakan Sariman
f241bfa339 Refactor flag setting in Info struct to use CopyFlagsFrom method 2025-09-29 15:38:35 +07:00
Hakan Sariman
4b2cd97d5f [client] Enhance SyncRequest with NetworkMap serial tracking
- Added `networkMapSerial` field to `SyncRequest` for tracking the last known network map serial number.
- Updated `GrpcClient` to store and utilize the last network map serial during sync operations, optimizing synchronization processes.
- Improved handling of system info updates to ensure accurate metadata is sent with sync requests.
2025-09-25 19:28:35 +07:00
118 changed files with 3492 additions and 1373 deletions

11
.githooks/pre-push Executable file
View File

@@ -0,0 +1,11 @@
#!/bin/bash
echo "Running pre-push hook..."
if ! make lint; then
echo ""
echo "Hint: To push without verification, run:"
echo " git push --no-verify"
exit 1
fi
echo "All checks passed!"

View File

@@ -136,6 +136,14 @@ checked out and set up:
go mod tidy
```
6. Configure Git hooks for automatic linting:
```bash
make setup-hooks
```
This will configure Git to run linting automatically before each push, helping catch issues early.
### Dev Container Support
If you prefer using a dev container for development, NetBird now includes support for dev containers.

27
Makefile Normal file
View File

@@ -0,0 +1,27 @@
.PHONY: lint lint-all lint-install setup-hooks
GOLANGCI_LINT := $(shell pwd)/bin/golangci-lint
# Install golangci-lint locally if needed
$(GOLANGCI_LINT):
@echo "Installing golangci-lint..."
@mkdir -p ./bin
@GOBIN=$(shell pwd)/bin go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
# Lint only changed files (fast, for pre-push)
lint: $(GOLANGCI_LINT)
@echo "Running lint on changed files..."
@$(GOLANGCI_LINT) run --new-from-rev=origin/main --timeout=2m
# Lint entire codebase (slow, matches CI)
lint-all: $(GOLANGCI_LINT)
@echo "Running lint on all files..."
@$(GOLANGCI_LINT) run --timeout=12m
# Just install the linter
lint-install: $(GOLANGCI_LINT)
# Setup git hooks for all developers
setup-hooks:
@git config core.hooksPath .githooks
@chmod +x .githooks/pre-push
@echo "✅ Git hooks configured! Pre-push will now run 'make lint'"

View File

@@ -15,6 +15,8 @@ import (
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
clientProto "github.com/netbirdio/netbird/client/proto"
@@ -24,8 +26,6 @@ import (
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/peers"
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
@@ -116,15 +116,18 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), config)
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersmanager), config)
accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
if err != nil {
t.Fatal(err)
}
secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, updateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &mgmt.MockIntegratedValidator{}, networkMapController)
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
if err != nil {
t.Fatal(err)
}
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController)
if err != nil {
t.Fatal(err)
}

View File

@@ -27,7 +27,11 @@ import (
)
const (
tableNat = "nat"
tableNat = "nat"
tableMangle = "mangle"
tableRaw = "raw"
tableSecurity = "security"
chainNameNatPrerouting = "PREROUTING"
chainNameRoutingFw = "netbird-rt-fwd"
chainNameRoutingNat = "netbird-rt-postrouting"
@@ -91,7 +95,7 @@ func newRouter(workTable *nftables.Table, wgIface iFaceMapper, mtu uint16) (*rou
var err error
r.filterTable, err = r.loadFilterTable()
if err != nil {
log.Warnf("failed to load filter table, skipping accept rules: %v", err)
log.Debugf("ip filter table not found: %v", err)
}
return r, nil
@@ -183,6 +187,33 @@ func (r *router) loadFilterTable() (*nftables.Table, error) {
return nil, errFilterTableNotFound
}
func hookName(hook *nftables.ChainHook) string {
if hook == nil {
return "unknown"
}
switch *hook {
case *nftables.ChainHookForward:
return chainNameForward
case *nftables.ChainHookInput:
return chainNameInput
default:
return fmt.Sprintf("hook(%d)", *hook)
}
}
func familyName(family nftables.TableFamily) string {
switch family {
case nftables.TableFamilyIPv4:
return "ip"
case nftables.TableFamilyIPv6:
return "ip6"
case nftables.TableFamilyINet:
return "inet"
default:
return fmt.Sprintf("family(%d)", family)
}
}
func (r *router) createContainers() error {
r.chains[chainNameRoutingFw] = r.conn.AddChain(&nftables.Chain{
Name: chainNameRoutingFw,
@@ -930,8 +961,21 @@ func (r *router) RemoveAllLegacyRouteRules() error {
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
// This method also adds INPUT chain rules to allow traffic to the local interface.
func (r *router) acceptForwardRules() error {
var merr *multierror.Error
if err := r.acceptFilterTableRules(); err != nil {
merr = multierror.Append(merr, err)
}
if err := r.acceptExternalChainsRules(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add accept rules to external chains: %w", err))
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *router) acceptFilterTableRules() error {
if r.filterTable == nil {
log.Debugf("table 'filter' not found for forward rules, skipping accept rules")
return nil
}
@@ -944,11 +988,11 @@ func (r *router) acceptForwardRules() error {
// Try iptables first and fallback to nftables if iptables is not available
ipt, err := iptables.New()
if err != nil {
// filter table exists but iptables is not
// iptables is not available but the filter table exists
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
fw = "nftables"
return r.acceptFilterRulesNftables()
return r.acceptFilterRulesNftables(r.filterTable)
}
return r.acceptFilterRulesIptables(ipt)
@@ -959,7 +1003,7 @@ func (r *router) acceptFilterRulesIptables(ipt *iptables.IPTables) error {
for _, rule := range r.getAcceptForwardRules() {
if err := ipt.Insert("filter", chainNameForward, 1, rule...); err != nil {
merr = multierror.Append(err, fmt.Errorf("add iptables forward rule: %v", err))
merr = multierror.Append(merr, fmt.Errorf("add iptables forward rule: %v", err))
} else {
log.Debugf("added iptables forward rule: %v", rule)
}
@@ -967,7 +1011,7 @@ func (r *router) acceptFilterRulesIptables(ipt *iptables.IPTables) error {
inputRule := r.getAcceptInputRule()
if err := ipt.Insert("filter", chainNameInput, 1, inputRule...); err != nil {
merr = multierror.Append(err, fmt.Errorf("add iptables input rule: %v", err))
merr = multierror.Append(merr, fmt.Errorf("add iptables input rule: %v", err))
} else {
log.Debugf("added iptables input rule: %v", inputRule)
}
@@ -987,18 +1031,70 @@ func (r *router) getAcceptInputRule() []string {
return []string{"-i", r.wgIface.Name(), "-j", "ACCEPT"}
}
func (r *router) acceptFilterRulesNftables() error {
// acceptFilterRulesNftables adds accept rules to the ip filter table using nftables.
// This is used when iptables is not available.
func (r *router) acceptFilterRulesNftables(table *nftables.Table) error {
intf := ifname(r.wgIface.Name())
forwardChain := &nftables.Chain{
Name: chainNameForward,
Table: table,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityFilter,
}
r.insertForwardAcceptRules(forwardChain, intf)
inputChain := &nftables.Chain{
Name: chainNameInput,
Table: table,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookInput,
Priority: nftables.ChainPriorityFilter,
}
r.insertInputAcceptRule(inputChain, intf)
return r.conn.Flush()
}
// acceptExternalChainsRules adds accept rules to external chains (non-netbird, non-iptables tables).
// It dynamically finds chains at call time to handle chains that may have been created after startup.
func (r *router) acceptExternalChainsRules() error {
chains := r.findExternalChains()
if len(chains) == 0 {
return nil
}
intf := ifname(r.wgIface.Name())
for _, chain := range chains {
if chain.Hooknum == nil {
log.Debugf("skipping external chain %s/%s: hooknum is nil", chain.Table.Name, chain.Name)
continue
}
log.Debugf("adding accept rules to external %s chain: %s %s/%s",
hookName(chain.Hooknum), familyName(chain.Table.Family), chain.Table.Name, chain.Name)
switch *chain.Hooknum {
case *nftables.ChainHookForward:
r.insertForwardAcceptRules(chain, intf)
case *nftables.ChainHookInput:
r.insertInputAcceptRule(chain, intf)
}
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("flush external chain rules: %w", err)
}
return nil
}
func (r *router) insertForwardAcceptRules(chain *nftables.Chain, intf []byte) {
iifRule := &nftables.Rule{
Table: r.filterTable,
Chain: &nftables.Chain{
Name: chainNameForward,
Table: r.filterTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityFilter,
},
Table: chain.Table,
Chain: chain,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
@@ -1021,30 +1117,19 @@ func (r *router) acceptFilterRulesNftables() error {
Data: intf,
},
}
oifRule := &nftables.Rule{
Table: r.filterTable,
Chain: &nftables.Chain{
Name: chainNameForward,
Table: r.filterTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityFilter,
},
Table: chain.Table,
Chain: chain,
Exprs: append(oifExprs, getEstablishedExprs(2)...),
UserData: []byte(userDataAcceptForwardRuleOif),
}
r.conn.InsertRule(oifRule)
}
func (r *router) insertInputAcceptRule(chain *nftables.Chain, intf []byte) {
inputRule := &nftables.Rule{
Table: r.filterTable,
Chain: &nftables.Chain{
Name: chainNameInput,
Table: r.filterTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookInput,
Priority: nftables.ChainPriorityFilter,
},
Table: chain.Table,
Chain: chain,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
@@ -1058,32 +1143,44 @@ func (r *router) acceptFilterRulesNftables() error {
UserData: []byte(userDataAcceptInputRule),
}
r.conn.InsertRule(inputRule)
return r.conn.Flush()
}
func (r *router) removeAcceptFilterRules() error {
var merr *multierror.Error
if err := r.removeFilterTableRules(); err != nil {
merr = multierror.Append(merr, err)
}
if err := r.removeExternalChainsRules(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove external chain rules: %w", err))
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *router) removeFilterTableRules() error {
if r.filterTable == nil {
return nil
}
ipt, err := iptables.New()
if err != nil {
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
return r.removeAcceptFilterRulesNftables()
log.Debugf("iptables not available, using nftables to remove filter rules: %v", err)
return r.removeAcceptRulesFromTable(r.filterTable)
}
return r.removeAcceptFilterRulesIptables(ipt)
}
func (r *router) removeAcceptFilterRulesNftables() error {
chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
func (r *router) removeAcceptRulesFromTable(table *nftables.Table) error {
chains, err := r.conn.ListChainsOfTableFamily(table.Family)
if err != nil {
return fmt.Errorf("list chains: %v", err)
}
for _, chain := range chains {
if chain.Table.Name != r.filterTable.Name {
if chain.Table.Name != table.Name {
continue
}
@@ -1091,27 +1188,101 @@ func (r *router) removeAcceptFilterRulesNftables() error {
continue
}
rules, err := r.conn.GetRules(r.filterTable, chain)
if err := r.removeAcceptRulesFromChain(table, chain); err != nil {
return err
}
}
return r.conn.Flush()
}
func (r *router) removeAcceptRulesFromChain(table *nftables.Table, chain *nftables.Chain) error {
rules, err := r.conn.GetRules(table, chain)
if err != nil {
return fmt.Errorf("get rules from %s/%s: %v", table.Name, chain.Name, err)
}
for _, rule := range rules {
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) ||
bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) ||
bytes.Equal(rule.UserData, []byte(userDataAcceptInputRule)) {
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("delete rule from %s/%s: %v", table.Name, chain.Name, err)
}
}
}
return nil
}
// removeExternalChainsRules removes our accept rules from all external chains.
// This is deterministic - it scans for chains at removal time rather than relying on saved state,
// ensuring cleanup works even after a crash or if chains changed.
func (r *router) removeExternalChainsRules() error {
chains := r.findExternalChains()
if len(chains) == 0 {
return nil
}
for _, chain := range chains {
if err := r.removeAcceptRulesFromChain(chain.Table, chain); err != nil {
log.Warnf("remove rules from external chain %s/%s: %v", chain.Table.Name, chain.Name, err)
}
}
return r.conn.Flush()
}
// findExternalChains scans for chains from non-netbird tables that have FORWARD or INPUT hooks.
// This is used both at startup (to know where to add rules) and at cleanup (to ensure deterministic removal).
func (r *router) findExternalChains() []*nftables.Chain {
var chains []*nftables.Chain
families := []nftables.TableFamily{nftables.TableFamilyIPv4, nftables.TableFamilyINet}
for _, family := range families {
allChains, err := r.conn.ListChainsOfTableFamily(family)
if err != nil {
return fmt.Errorf("get rules: %v", err)
log.Debugf("list chains for family %d: %v", family, err)
continue
}
for _, rule := range rules {
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) ||
bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) ||
bytes.Equal(rule.UserData, []byte(userDataAcceptInputRule)) {
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("delete rule: %v", err)
}
for _, chain := range allChains {
if r.isExternalChain(chain) {
chains = append(chains, chain)
}
}
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf(flushError, err)
return chains
}
func (r *router) isExternalChain(chain *nftables.Chain) bool {
if r.workTable != nil && chain.Table.Name == r.workTable.Name {
return false
}
return nil
// Skip all iptables-managed tables in the ip family
if chain.Table.Family == nftables.TableFamilyIPv4 && isIptablesTable(chain.Table.Name) {
return false
}
if chain.Type != nftables.ChainTypeFilter {
return false
}
if chain.Hooknum == nil {
return false
}
return *chain.Hooknum == *nftables.ChainHookForward || *chain.Hooknum == *nftables.ChainHookInput
}
func isIptablesTable(name string) bool {
switch name {
case tableNameFilter, tableNat, tableMangle, tableRaw, tableSecurity:
return true
}
return false
}
func (r *router) removeAcceptFilterRulesIptables(ipt *iptables.IPTables) error {
@@ -1119,13 +1290,13 @@ func (r *router) removeAcceptFilterRulesIptables(ipt *iptables.IPTables) error {
for _, rule := range r.getAcceptForwardRules() {
if err := ipt.DeleteIfExists("filter", chainNameForward, rule...); err != nil {
merr = multierror.Append(err, fmt.Errorf("remove iptables forward rule: %v", err))
merr = multierror.Append(merr, fmt.Errorf("remove iptables forward rule: %v", err))
}
}
inputRule := r.getAcceptInputRule()
if err := ipt.DeleteIfExists("filter", chainNameInput, inputRule...); err != nil {
merr = multierror.Append(err, fmt.Errorf("remove iptables input rule: %v", err))
merr = multierror.Append(merr, fmt.Errorf("remove iptables input rule: %v", err))
}
return nberrors.FormatErrorOrNil(merr)

View File

@@ -22,6 +22,7 @@ import (
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/templates"
"github.com/netbirdio/netbird/shared/management/client/common"
)
var _ OAuthFlow = &PKCEAuthorizationFlow{}
@@ -104,10 +105,10 @@ func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowIn
oauth2.SetAuthURLParam("audience", p.providerConfig.Audience),
}
if !p.providerConfig.DisablePromptLogin {
if p.providerConfig.LoginFlag.IsPromptLogin() {
switch p.providerConfig.LoginFlag {
case common.LoginFlagPromptLogin:
params = append(params, oauth2.SetAuthURLParam("prompt", "login"))
}
if p.providerConfig.LoginFlag.IsMaxAge0Login() {
case common.LoginFlagMaxAge0:
params = append(params, oauth2.SetAuthURLParam("max_age", "0"))
}
}

View File

@@ -23,22 +23,28 @@ func TestPromptLogin(t *testing.T) {
name string
loginFlag mgm.LoginFlag
disablePromptLogin bool
expect string
expectContains []string
}{
{
name: "Prompt login",
loginFlag: mgm.LoginFlagPrompt,
expect: promptLogin,
name: "Prompt login",
loginFlag: mgm.LoginFlagPromptLogin,
expectContains: []string{promptLogin},
},
{
name: "Max age 0 login",
loginFlag: mgm.LoginFlagMaxAge0,
expect: maxAge0,
name: "Max age 0",
loginFlag: mgm.LoginFlagMaxAge0,
expectContains: []string{maxAge0},
},
{
name: "Disable prompt login",
loginFlag: mgm.LoginFlagPrompt,
loginFlag: mgm.LoginFlagPromptLogin,
disablePromptLogin: true,
expectContains: []string{},
},
{
name: "None flag should not add parameters",
loginFlag: mgm.LoginFlagNone,
expectContains: []string{},
},
}
@@ -53,6 +59,7 @@ func TestPromptLogin(t *testing.T) {
RedirectURLs: []string{"http://127.0.0.1:33992/"},
UseIDToken: true,
LoginFlag: tc.loginFlag,
DisablePromptLogin: tc.disablePromptLogin,
}
pkce, err := NewPKCEAuthorizationFlow(config)
if err != nil {
@@ -63,11 +70,8 @@ func TestPromptLogin(t *testing.T) {
t.Fatalf("Failed to request auth info: %v", err)
}
if !tc.disablePromptLogin {
require.Contains(t, authInfo.VerificationURIComplete, tc.expect)
} else {
require.Contains(t, authInfo.VerificationURIComplete, promptLogin)
require.NotContains(t, authInfo.VerificationURIComplete, maxAge0)
for _, expected := range tc.expectContains {
require.Contains(t, authInfo.VerificationURIComplete, expected)
}
})
}

View File

@@ -273,11 +273,12 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
checks := loginResp.GetChecks()
c.engineMutex.Lock()
c.engine = NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks)
c.engine.SetSyncResponsePersistence(c.persistSyncResponse)
engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks)
engine.SetSyncResponsePersistence(c.persistSyncResponse)
c.engine = engine
c.engineMutex.Unlock()
if err := c.engine.Start(loginResp.GetNetbirdConfig(), c.config.ManagementURL); err != nil {
if err := engine.Start(loginResp.GetNetbirdConfig(), c.config.ManagementURL); err != nil {
log.Errorf("error while starting Netbird Connection Engine: %s", err)
return wrapErr(err)
}
@@ -293,12 +294,14 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
<-engineCtx.Done()
c.engineMutex.Lock()
engine := c.engine
c.engine = nil
c.engineMutex.Unlock()
if engine != nil && engine.wgInterface != nil {
// todo: consider to remove this condition. Is not thread safe.
// We should always call Stop(), but we need to verify that it is idempotent
if engine.wgInterface != nil {
log.Infof("ensuring %s is removed, Netbird engine context cancelled", engine.wgInterface.Name())
if err := engine.Stop(); err != nil {
log.Errorf("Failed to stop engine: %v", err)
}

View File

@@ -234,6 +234,11 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns
return nil
}
// Unmap IPv4-mapped IPv6 addresses that some resolvers may return
for i, ip := range ips {
ips[i] = ip.Unmap()
}
f.updateInternalState(ips, mostSpecificResId, matchingEntries)
f.addIPsToResponse(resp, domain, ips)
f.cache.set(domain, question.Qtype, ips)

View File

@@ -280,7 +280,6 @@ func (e *Engine) Stop() error {
return nil
}
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
if e.connMgr != nil {
e.connMgr.Close()
@@ -298,9 +297,6 @@ func (e *Engine) Stop() error {
e.cleanupSSHConfig()
// stop/restore DNS first so dbus and friends don't complain because of a missing interface
e.stopDNSServer()
if e.ingressGatewayMgr != nil {
if err := e.ingressGatewayMgr.Close(); err != nil {
log.Warnf("failed to cleanup forward rules: %v", err)
@@ -308,24 +304,29 @@ func (e *Engine) Stop() error {
e.ingressGatewayMgr = nil
}
e.stopDNSForwarder()
if e.routeManager != nil {
e.routeManager.Stop(e.stateManager)
}
if e.srWatcher != nil {
e.srWatcher.Close()
}
log.Info("cleaning up status recorder states")
e.statusRecorder.ReplaceOfflinePeers([]peer.State{})
e.statusRecorder.UpdateDNSStates([]peer.NSGroupState{})
e.statusRecorder.UpdateRelayStates([]relay.ProbeResult{})
if err := e.removeAllPeers(); err != nil {
return fmt.Errorf("failed to remove all peers: %s", err)
log.Errorf("failed to remove all peers: %s", err)
}
if e.routeManager != nil {
e.routeManager.Stop(e.stateManager)
}
e.stopDNSForwarder()
// stop/restore DNS after peers are closed but before interface goes down
// so dbus and friends don't complain because of a missing interface
e.stopDNSServer()
if e.cancel != nil {
e.cancel()
}
@@ -337,16 +338,18 @@ func (e *Engine) Stop() error {
e.flowManager.Close()
}
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
stateCtx, stateCancel := context.WithTimeout(context.Background(), 3*time.Second)
defer stateCancel()
if err := e.stateManager.Stop(ctx); err != nil {
return fmt.Errorf("failed to stop state manager: %w", err)
if err := e.stateManager.Stop(stateCtx); err != nil {
log.Errorf("failed to stop state manager: %v", err)
}
if err := e.stateManager.PersistState(context.Background()); err != nil {
log.Errorf("failed to persist state: %v", err)
}
e.syncMsgMux.Unlock()
timeout := e.calculateShutdownTimeout()
log.Debugf("waiting for goroutines to finish with timeout: %v", timeout)
shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout)
@@ -432,8 +435,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
if err != nil {
return fmt.Errorf("create rosenpass manager: %w", err)
}
err := e.rpManager.Run()
if err != nil {
if err := e.rpManager.Run(); err != nil {
return fmt.Errorf("run rosenpass manager: %w", err)
}
}
@@ -485,6 +487,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
}
if err := e.createFirewall(); err != nil {
e.close()
return err
}
@@ -750,6 +753,11 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
// Check context INSIDE lock to ensure atomicity with shutdown
if e.ctx.Err() != nil {
return e.ctx.Err()
}
if update.GetNetbirdConfig() != nil {
wCfg := update.GetNetbirdConfig()
err := e.updateTURNs(wCfg.GetTurns())
@@ -789,7 +797,7 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
}
nm := update.GetNetworkMap()
if nm == nil {
if nm == nil || update.SkipNetworkMapUpdate {
return nil
}
@@ -955,7 +963,7 @@ func (e *Engine) receiveManagementEvents() {
e.config.DisableSSHAuth,
)
err = e.mgmClient.Sync(e.ctx, info, e.handleSync)
err = e.mgmClient.Sync(e.ctx, info, e.networkSerial, e.handleSync)
if err != nil {
// happens if management is unavailable for a long time.
// We want to cancel the operation of the whole client
@@ -1369,6 +1377,11 @@ func (e *Engine) receiveSignalEvents() {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
// Check context INSIDE lock to ensure atomicity with shutdown
if e.ctx.Err() != nil {
return e.ctx.Err()
}
conn, ok := e.peerStore.PeerConn(msg.Key)
if !ok {
return fmt.Errorf("wrongly addressed message %s", msg.Key)

View File

@@ -0,0 +1,79 @@
package internal
import (
"context"
"testing"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/shared/management/client"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
)
// Ensures handleSync exits early when SkipNetworkMapUpdate is true
func TestEngine_HandleSync_SkipNetworkMapUpdate(t *testing.T) {
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
t.Fatal(err)
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
engine := NewEngine(ctx, cancel, nil, &client.MockClient{}, nil, &EngineConfig{
WgIfaceName: "utun199",
WgAddr: "100.70.0.1/24",
WgPrivateKey: key,
WgPort: 33100,
MTU: iface.DefaultMTU,
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
engine.ctx = ctx
// Precondition
if engine.networkSerial != 0 {
t.Fatalf("unexpected initial serial: %d", engine.networkSerial)
}
resp := &mgmtProto.SyncResponse{
NetworkMap: &mgmtProto.NetworkMap{Serial: 42},
SkipNetworkMapUpdate: true,
}
if err := engine.handleSync(resp); err != nil {
t.Fatalf("handleSync returned error: %v", err)
}
if engine.networkSerial != 0 {
t.Fatalf("networkSerial changed despite SkipNetworkMapUpdate; got %d, want 0", engine.networkSerial)
}
}
// Ensures handleSync exits early when NetworkMap is nil
func TestEngine_HandleSync_NilNetworkMap(t *testing.T) {
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
t.Fatal(err)
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
engine := NewEngine(ctx, cancel, nil, &client.MockClient{}, nil, &EngineConfig{
WgIfaceName: "utun198",
WgAddr: "100.70.0.2/24",
WgPrivateKey: key,
WgPort: 33101,
MTU: iface.DefaultMTU,
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
engine.ctx = ctx
resp := &mgmtProto.SyncResponse{NetworkMap: nil}
if err := engine.handleSync(resp); err != nil {
t.Fatalf("handleSync returned error: %v", err)
}
}

View File

@@ -30,11 +30,12 @@ import (
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/configurer"
@@ -54,7 +55,6 @@ import (
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/peers"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
@@ -631,7 +631,7 @@ func TestEngine_Sync(t *testing.T) {
// feed updates to Engine via mocked Management client
updates := make(chan *mgmtProto.SyncResponse)
defer close(updates)
syncFunc := func(ctx context.Context, info *system.Info, msgHandler func(msg *mgmtProto.SyncResponse) error) error {
syncFunc := func(ctx context.Context, info *system.Info, networkSerial uint64, msgHandler func(msg *mgmtProto.SyncResponse) error) error {
for msg := range updates {
err := msgHandler(msg)
if err != nil {
@@ -1628,14 +1628,17 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), config)
networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
if err != nil {
return nil, "", err
}
secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, updateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{}, networkMapController)
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
if err != nil {
return nil, "", err
}
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController)
if err != nil {
return nil, "", err
}

View File

@@ -0,0 +1,218 @@
//go:build darwin && !ios
package sleep
/*
#cgo LDFLAGS: -framework IOKit -framework CoreFoundation
#include <IOKit/pwr_mgt/IOPMLib.h>
#include <IOKit/IOMessage.h>
#include <CoreFoundation/CoreFoundation.h>
extern void sleepCallbackBridge();
extern void poweredOnCallbackBridge();
extern void suspendedCallbackBridge();
extern void resumedCallbackBridge();
// C global variables for IOKit state
static IONotificationPortRef g_notifyPortRef = NULL;
static io_object_t g_notifierObject = 0;
static io_object_t g_generalInterestNotifier = 0;
static io_connect_t g_rootPort = 0;
static CFRunLoopRef g_runLoop = NULL;
static void sleepCallback(void* refCon, io_service_t service, natural_t messageType, void* messageArgument) {
switch (messageType) {
case kIOMessageSystemWillSleep:
sleepCallbackBridge();
IOAllowPowerChange(g_rootPort, (long)messageArgument);
break;
case kIOMessageSystemHasPoweredOn:
poweredOnCallbackBridge();
break;
case kIOMessageServiceIsSuspended:
suspendedCallbackBridge();
break;
case kIOMessageServiceIsResumed:
resumedCallbackBridge();
break;
default:
break;
}
}
static void registerNotifications() {
g_rootPort = IORegisterForSystemPower(
NULL,
&g_notifyPortRef,
(IOServiceInterestCallback)sleepCallback,
&g_notifierObject
);
if (g_rootPort == 0) {
return;
}
CFRunLoopAddSource(CFRunLoopGetCurrent(),
IONotificationPortGetRunLoopSource(g_notifyPortRef),
kCFRunLoopCommonModes);
g_runLoop = CFRunLoopGetCurrent();
CFRunLoopRun();
}
static void unregisterNotifications() {
CFRunLoopRemoveSource(g_runLoop,
IONotificationPortGetRunLoopSource(g_notifyPortRef),
kCFRunLoopCommonModes);
IODeregisterForSystemPower(&g_notifierObject);
IOServiceClose(g_rootPort);
IONotificationPortDestroy(g_notifyPortRef);
CFRunLoopStop(g_runLoop);
g_notifyPortRef = NULL;
g_notifierObject = 0;
g_rootPort = 0;
g_runLoop = NULL;
}
*/
import "C"
import (
"context"
"fmt"
"runtime"
"sync"
"time"
log "github.com/sirupsen/logrus"
)
var (
serviceRegistry = make(map[*Detector]struct{})
serviceRegistryMu sync.Mutex
)
//export sleepCallbackBridge
func sleepCallbackBridge() {
log.Info("sleepCallbackBridge event triggered")
serviceRegistryMu.Lock()
defer serviceRegistryMu.Unlock()
for svc := range serviceRegistry {
svc.triggerCallback(EventTypeSleep)
}
}
//export resumedCallbackBridge
func resumedCallbackBridge() {
log.Info("resumedCallbackBridge event triggered")
}
//export suspendedCallbackBridge
func suspendedCallbackBridge() {
log.Info("suspendedCallbackBridge event triggered")
}
//export poweredOnCallbackBridge
func poweredOnCallbackBridge() {
log.Info("poweredOnCallbackBridge event triggered")
serviceRegistryMu.Lock()
defer serviceRegistryMu.Unlock()
for svc := range serviceRegistry {
svc.triggerCallback(EventTypeWakeUp)
}
}
type Detector struct {
callback func(event EventType)
ctx context.Context
cancel context.CancelFunc
}
func NewDetector() (*Detector, error) {
return &Detector{}, nil
}
func (d *Detector) Register(callback func(event EventType)) error {
serviceRegistryMu.Lock()
defer serviceRegistryMu.Unlock()
if _, exists := serviceRegistry[d]; exists {
return fmt.Errorf("detector service already registered")
}
d.callback = callback
d.ctx, d.cancel = context.WithCancel(context.Background())
if len(serviceRegistry) > 0 {
serviceRegistry[d] = struct{}{}
return nil
}
serviceRegistry[d] = struct{}{}
// CFRunLoop must run on a single fixed OS thread
go func() {
runtime.LockOSThread()
defer runtime.UnlockOSThread()
C.registerNotifications()
}()
log.Info("sleep detection service started on macOS")
return nil
}
// Deregister removes the detector. When the last detector is removed, IOKit registration is torn down
// and the runloop is stopped and cleaned up.
func (d *Detector) Deregister() error {
serviceRegistryMu.Lock()
defer serviceRegistryMu.Unlock()
_, exists := serviceRegistry[d]
if !exists {
return nil
}
// cancel and remove this detector
d.cancel()
delete(serviceRegistry, d)
// If other Detectors still exist, leave IOKit running
if len(serviceRegistry) > 0 {
return nil
}
log.Info("sleep detection service stopping (deregister)")
// Deregister IOKit notifications, stop runloop, and free resources
C.unregisterNotifications()
return nil
}
func (d *Detector) triggerCallback(event EventType) {
doneChan := make(chan struct{})
timeout := time.NewTimer(500 * time.Millisecond)
defer timeout.Stop()
cb := d.callback
go func(callback func(event EventType)) {
log.Info("sleep detection event fired")
callback(event)
close(doneChan)
}(cb)
select {
case <-doneChan:
case <-d.ctx.Done():
case <-timeout.C:
log.Warnf("sleep callback timed out")
}
}

View File

@@ -0,0 +1,9 @@
//go:build !darwin || ios
package sleep
import "fmt"
func NewDetector() (detector, error) {
return nil, fmt.Errorf("sleep not supported on this platform")
}

View File

@@ -0,0 +1,37 @@
package sleep
var (
EventTypeUnknown EventType = 0
EventTypeSleep EventType = 1
EventTypeWakeUp EventType = 2
)
type EventType int
type detector interface {
Register(callback func(eventType EventType)) error
Deregister() error
}
type Service struct {
detector detector
}
func New() (*Service, error) {
d, err := NewDetector()
if err != nil {
return nil, err
}
return &Service{
detector: d,
}, nil
}
func (s *Service) Register(callback func(eventType EventType)) error {
return s.detector.Register(callback)
}
func (s *Service) Deregister() error {
return s.detector.Deregister()
}

View File

@@ -1,9 +1,12 @@
//go:build ios
package NetBirdSDK
import (
"context"
"fmt"
"net/netip"
"os"
"sort"
"strings"
"sync"
@@ -90,7 +93,8 @@ func NewClient(cfgFile, stateFile, deviceName string, osVersion string, osName s
}
// Run start the internal client. It is a blocker function
func (c *Client) Run(fd int32, interfaceName string) error {
func (c *Client) Run(fd int32, interfaceName string, envList *EnvList) error {
exportEnvList(envList)
log.Infof("Starting NetBird client")
log.Debugf("Tunnel uses interface: %s", interfaceName)
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
@@ -433,3 +437,19 @@ func toNetIDs(routes []string) []route.NetID {
}
return netIDs
}
func exportEnvList(list *EnvList) {
if list == nil {
return
}
for k, v := range list.AllItems() {
log.Debugf("Env variable %s's value is currently: %s", k, os.Getenv(k))
log.Debugf("Setting env variable %s: %s", k, v)
if err := os.Setenv(k, v); err != nil {
log.Errorf("could not set env variable %s: %v", k, err)
} else {
log.Debugf("Env variable %s was set successfully", k)
}
}
}

View File

@@ -0,0 +1,34 @@
//go:build ios
package NetBirdSDK
import "github.com/netbirdio/netbird/client/internal/peer"
// EnvList is an exported struct to be bound by gomobile
type EnvList struct {
data map[string]string
}
// NewEnvList creates a new EnvList
func NewEnvList() *EnvList {
return &EnvList{data: make(map[string]string)}
}
// Put adds a key-value pair
func (el *EnvList) Put(key, value string) {
el.data[key] = value
}
// Get retrieves a value by key
func (el *EnvList) Get(key string) string {
return el.data[key]
}
func (el *EnvList) AllItems() map[string]string {
return el.data
}
// GetEnvKeyNBForceRelay Exports the environment variable for the iOS client
func GetEnvKeyNBForceRelay() string {
return peer.EnvKeyNBForceRelay
}

View File

@@ -1,3 +1,5 @@
//go:build ios
package NetBirdSDK
import _ "golang.org/x/mobile/bind"

View File

@@ -1,3 +1,5 @@
//go:build ios
package NetBirdSDK
import (

View File

@@ -1,3 +1,5 @@
//go:build ios
package NetBirdSDK
import (

View File

@@ -1,3 +1,5 @@
//go:build ios
package NetBirdSDK
// PeerInfo describe information about the peers. It designed for the UI usage

View File

@@ -1,3 +1,5 @@
//go:build ios
package NetBirdSDK
import (

View File

@@ -1,3 +1,5 @@
//go:build ios
package NetBirdSDK
import (

View File

@@ -1,3 +1,5 @@
//go:build ios
package NetBirdSDK
// RoutesSelectionInfoCollection made for Java layer to get non default types as collection

File diff suppressed because it is too large Load Diff

View File

@@ -24,7 +24,7 @@ service DaemonService {
// Status of the service.
rpc Status(StatusRequest) returns (StatusResponse) {}
// Down engine work in the daemon.
// Down stops engine work in the daemon.
rpc Down(DownRequest) returns (DownResponse) {}
// GetConfig of the daemon.
@@ -93,9 +93,26 @@ service DaemonService {
// WaitJWTToken waits for JWT authentication completion
rpc WaitJWTToken(WaitJWTTokenRequest) returns (WaitJWTTokenResponse) {}
rpc NotifyOSLifecycle(OSLifecycleRequest) returns(OSLifecycleResponse) {}
}
message OSLifecycleRequest {
// avoid collision with loglevel enum
enum CycleType {
UNKNOWN = 0;
SLEEP = 1;
WAKEUP = 2;
}
CycleType type = 1;
}
message OSLifecycleResponse {}
message LoginRequest {
// setupKey netbird setup key.
string setupKey = 1;

View File

@@ -27,7 +27,7 @@ type DaemonServiceClient interface {
Up(ctx context.Context, in *UpRequest, opts ...grpc.CallOption) (*UpResponse, error)
// Status of the service.
Status(ctx context.Context, in *StatusRequest, opts ...grpc.CallOption) (*StatusResponse, error)
// Down engine work in the daemon.
// Down stops engine work in the daemon.
Down(ctx context.Context, in *DownRequest, opts ...grpc.CallOption) (*DownResponse, error)
// GetConfig of the daemon.
GetConfig(ctx context.Context, in *GetConfigRequest, opts ...grpc.CallOption) (*GetConfigResponse, error)
@@ -70,6 +70,7 @@ type DaemonServiceClient interface {
RequestJWTAuth(ctx context.Context, in *RequestJWTAuthRequest, opts ...grpc.CallOption) (*RequestJWTAuthResponse, error)
// WaitJWTToken waits for JWT authentication completion
WaitJWTToken(ctx context.Context, in *WaitJWTTokenRequest, opts ...grpc.CallOption) (*WaitJWTTokenResponse, error)
NotifyOSLifecycle(ctx context.Context, in *OSLifecycleRequest, opts ...grpc.CallOption) (*OSLifecycleResponse, error)
}
type daemonServiceClient struct {
@@ -382,6 +383,15 @@ func (c *daemonServiceClient) WaitJWTToken(ctx context.Context, in *WaitJWTToken
return out, nil
}
func (c *daemonServiceClient) NotifyOSLifecycle(ctx context.Context, in *OSLifecycleRequest, opts ...grpc.CallOption) (*OSLifecycleResponse, error) {
out := new(OSLifecycleResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/NotifyOSLifecycle", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// DaemonServiceServer is the server API for DaemonService service.
// All implementations must embed UnimplementedDaemonServiceServer
// for forward compatibility
@@ -395,7 +405,7 @@ type DaemonServiceServer interface {
Up(context.Context, *UpRequest) (*UpResponse, error)
// Status of the service.
Status(context.Context, *StatusRequest) (*StatusResponse, error)
// Down engine work in the daemon.
// Down stops engine work in the daemon.
Down(context.Context, *DownRequest) (*DownResponse, error)
// GetConfig of the daemon.
GetConfig(context.Context, *GetConfigRequest) (*GetConfigResponse, error)
@@ -438,6 +448,7 @@ type DaemonServiceServer interface {
RequestJWTAuth(context.Context, *RequestJWTAuthRequest) (*RequestJWTAuthResponse, error)
// WaitJWTToken waits for JWT authentication completion
WaitJWTToken(context.Context, *WaitJWTTokenRequest) (*WaitJWTTokenResponse, error)
NotifyOSLifecycle(context.Context, *OSLifecycleRequest) (*OSLifecycleResponse, error)
mustEmbedUnimplementedDaemonServiceServer()
}
@@ -538,6 +549,9 @@ func (UnimplementedDaemonServiceServer) RequestJWTAuth(context.Context, *Request
func (UnimplementedDaemonServiceServer) WaitJWTToken(context.Context, *WaitJWTTokenRequest) (*WaitJWTTokenResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method WaitJWTToken not implemented")
}
func (UnimplementedDaemonServiceServer) NotifyOSLifecycle(context.Context, *OSLifecycleRequest) (*OSLifecycleResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method NotifyOSLifecycle not implemented")
}
func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {}
// UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service.
@@ -1112,6 +1126,24 @@ func _DaemonService_WaitJWTToken_Handler(srv interface{}, ctx context.Context, d
return interceptor(ctx, in, info, handler)
}
func _DaemonService_NotifyOSLifecycle_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(OSLifecycleRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DaemonServiceServer).NotifyOSLifecycle(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/daemon.DaemonService/NotifyOSLifecycle",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).NotifyOSLifecycle(ctx, req.(*OSLifecycleRequest))
}
return interceptor(ctx, in, info, handler)
}
// DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
@@ -1239,6 +1271,10 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{
MethodName: "WaitJWTToken",
Handler: _DaemonService_WaitJWTToken_Handler,
},
{
MethodName: "NotifyOSLifecycle",
Handler: _DaemonService_NotifyOSLifecycle_Handler,
},
},
Streams: []grpc.StreamDesc{
{

View File

@@ -0,0 +1,77 @@
package server
import (
"context"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/proto"
)
// NotifyOSLifecycle handles operating system lifecycle events by executing appropriate logic based on the request type.
func (s *Server) NotifyOSLifecycle(callerCtx context.Context, req *proto.OSLifecycleRequest) (*proto.OSLifecycleResponse, error) {
switch req.GetType() {
case proto.OSLifecycleRequest_WAKEUP:
return s.handleWakeUp(callerCtx)
case proto.OSLifecycleRequest_SLEEP:
return s.handleSleep(callerCtx)
default:
log.Errorf("unknown OSLifecycleRequest type: %v", req.GetType())
}
return &proto.OSLifecycleResponse{}, nil
}
// handleWakeUp processes a wake-up event by triggering the Up command if the system was previously put to sleep.
// It resets the sleep state and logs the process. Returns a response or an error if the Up command fails.
func (s *Server) handleWakeUp(callerCtx context.Context) (*proto.OSLifecycleResponse, error) {
if !s.sleepTriggeredDown.Load() {
log.Info("skipping up because wasn't sleep down")
return &proto.OSLifecycleResponse{}, nil
}
// avoid other wakeup runs if sleep didn't make the computer sleep
s.sleepTriggeredDown.Store(false)
log.Info("running up after wake up")
_, err := s.Up(callerCtx, &proto.UpRequest{})
if err != nil {
log.Errorf("running up failed: %v", err)
return &proto.OSLifecycleResponse{}, err
}
log.Info("running up command executed successfully")
return &proto.OSLifecycleResponse{}, nil
}
// handleSleep handles the sleep event by initiating a "down" sequence if the system is in a connected or connecting state.
func (s *Server) handleSleep(callerCtx context.Context) (*proto.OSLifecycleResponse, error) {
s.mutex.Lock()
state := internal.CtxGetState(s.rootCtx)
status, err := state.Status()
if err != nil {
s.mutex.Unlock()
return &proto.OSLifecycleResponse{}, err
}
if status != internal.StatusConnecting && status != internal.StatusConnected {
log.Infof("skipping setting the agent down because status is %s", status)
s.mutex.Unlock()
return &proto.OSLifecycleResponse{}, nil
}
s.mutex.Unlock()
log.Info("running down after system started sleeping")
_, err = s.Down(callerCtx, &proto.DownRequest{})
if err != nil {
log.Errorf("running down failed: %v", err)
return &proto.OSLifecycleResponse{}, err
}
s.sleepTriggeredDown.Store(true)
log.Info("running down executed successfully")
return &proto.OSLifecycleResponse{}, nil
}

View File

@@ -0,0 +1,219 @@
package server
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/proto"
)
func newTestServer() *Server {
ctx := internal.CtxInitState(context.Background())
return &Server{
rootCtx: ctx,
statusRecorder: peer.NewRecorder(""),
}
}
func TestNotifyOSLifecycle_WakeUp_SkipsWhenNotSleepTriggered(t *testing.T) {
s := newTestServer()
// sleepTriggeredDown is false by default
assert.False(t, s.sleepTriggeredDown.Load())
resp, err := s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
Type: proto.OSLifecycleRequest_WAKEUP,
})
require.NoError(t, err)
require.NotNil(t, resp)
assert.False(t, s.sleepTriggeredDown.Load(), "flag should remain false")
}
func TestNotifyOSLifecycle_Sleep_SkipsWhenStatusIdle(t *testing.T) {
s := newTestServer()
state := internal.CtxGetState(s.rootCtx)
state.Set(internal.StatusIdle)
resp, err := s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
Type: proto.OSLifecycleRequest_SLEEP,
})
require.NoError(t, err)
require.NotNil(t, resp)
assert.False(t, s.sleepTriggeredDown.Load(), "flag should remain false when status is Idle")
}
func TestNotifyOSLifecycle_Sleep_SkipsWhenStatusNeedsLogin(t *testing.T) {
s := newTestServer()
state := internal.CtxGetState(s.rootCtx)
state.Set(internal.StatusNeedsLogin)
resp, err := s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
Type: proto.OSLifecycleRequest_SLEEP,
})
require.NoError(t, err)
require.NotNil(t, resp)
assert.False(t, s.sleepTriggeredDown.Load(), "flag should remain false when status is NeedsLogin")
}
func TestNotifyOSLifecycle_Sleep_SetsFlag_WhenConnecting(t *testing.T) {
s := newTestServer()
state := internal.CtxGetState(s.rootCtx)
state.Set(internal.StatusConnecting)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.actCancel = cancel
resp, err := s.NotifyOSLifecycle(ctx, &proto.OSLifecycleRequest{
Type: proto.OSLifecycleRequest_SLEEP,
})
require.NoError(t, err)
assert.NotNil(t, resp, "handleSleep returns not nil response on success")
assert.True(t, s.sleepTriggeredDown.Load(), "flag should be set after sleep when connecting")
}
func TestNotifyOSLifecycle_Sleep_SetsFlag_WhenConnected(t *testing.T) {
s := newTestServer()
state := internal.CtxGetState(s.rootCtx)
state.Set(internal.StatusConnected)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.actCancel = cancel
resp, err := s.NotifyOSLifecycle(ctx, &proto.OSLifecycleRequest{
Type: proto.OSLifecycleRequest_SLEEP,
})
require.NoError(t, err)
assert.NotNil(t, resp, "handleSleep returns not nil response on success")
assert.True(t, s.sleepTriggeredDown.Load(), "flag should be set after sleep when connected")
}
func TestNotifyOSLifecycle_WakeUp_ResetsFlag(t *testing.T) {
s := newTestServer()
// Manually set the flag to simulate prior sleep down
s.sleepTriggeredDown.Store(true)
// WakeUp will try to call Up which fails without proper setup, but flag should reset first
_, _ = s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
Type: proto.OSLifecycleRequest_WAKEUP,
})
assert.False(t, s.sleepTriggeredDown.Load(), "flag should be reset after WakeUp attempt")
}
func TestNotifyOSLifecycle_MultipleWakeUpCalls(t *testing.T) {
s := newTestServer()
// First wakeup without prior sleep - should be no-op
resp, err := s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
Type: proto.OSLifecycleRequest_WAKEUP,
})
require.NoError(t, err)
require.NotNil(t, resp)
assert.False(t, s.sleepTriggeredDown.Load())
// Simulate prior sleep
s.sleepTriggeredDown.Store(true)
// First wakeup after sleep - should reset flag
_, _ = s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
Type: proto.OSLifecycleRequest_WAKEUP,
})
assert.False(t, s.sleepTriggeredDown.Load())
// Second wakeup - should be no-op
resp, err = s.NotifyOSLifecycle(context.Background(), &proto.OSLifecycleRequest{
Type: proto.OSLifecycleRequest_WAKEUP,
})
require.NoError(t, err)
require.NotNil(t, resp)
assert.False(t, s.sleepTriggeredDown.Load())
}
func TestHandleWakeUp_SkipsWhenFlagFalse(t *testing.T) {
s := newTestServer()
resp, err := s.handleWakeUp(context.Background())
require.NoError(t, err)
require.NotNil(t, resp)
}
func TestHandleWakeUp_ResetsFlagBeforeUp(t *testing.T) {
s := newTestServer()
s.sleepTriggeredDown.Store(true)
// Even if Up fails, flag should be reset
_, _ = s.handleWakeUp(context.Background())
assert.False(t, s.sleepTriggeredDown.Load(), "flag must be reset before calling Up")
}
func TestHandleSleep_SkipsForNonActiveStates(t *testing.T) {
tests := []struct {
name string
status internal.StatusType
}{
{"Idle", internal.StatusIdle},
{"NeedsLogin", internal.StatusNeedsLogin},
{"LoginFailed", internal.StatusLoginFailed},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := newTestServer()
state := internal.CtxGetState(s.rootCtx)
state.Set(tt.status)
resp, err := s.handleSleep(context.Background())
require.NoError(t, err)
require.NotNil(t, resp)
assert.False(t, s.sleepTriggeredDown.Load())
})
}
}
func TestHandleSleep_ProceedsForActiveStates(t *testing.T) {
tests := []struct {
name string
status internal.StatusType
}{
{"Connecting", internal.StatusConnecting},
{"Connected", internal.StatusConnected},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := newTestServer()
state := internal.CtxGetState(s.rootCtx)
state.Set(tt.status)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.actCancel = cancel
resp, err := s.handleSleep(ctx)
require.NoError(t, err)
assert.NotNil(t, resp)
assert.True(t, s.sleepTriggeredDown.Load())
})
}
}

View File

@@ -85,6 +85,9 @@ type Server struct {
profilesDisabled bool
updateSettingsDisabled bool
// sleepTriggeredDown holds a state indicated if the sleep handler triggered the last client down
sleepTriggeredDown atomic.Bool
jwtCache *jwtCache
}
@@ -819,6 +822,7 @@ func (s *Server) Down(ctx context.Context, _ *proto.DownRequest) (*proto.DownRes
defer s.mutex.Unlock()
if err := s.cleanupConnection(); err != nil {
// todo review to update the status in case any type of error
log.Errorf("failed to shut down properly: %v", err)
return nil, err
}
@@ -911,6 +915,7 @@ func (s *Server) handleActiveProfileLogout(ctx context.Context) (*proto.LogoutRe
}
if err := s.cleanupConnection(); err != nil && !errors.Is(err, ErrServiceNotUp) {
// todo review to update the status in case any type of error
log.Errorf("failed to cleanup connection: %v", err)
return nil, err
}

View File

@@ -17,11 +17,12 @@ import (
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
@@ -35,7 +36,6 @@ import (
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/peers"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
@@ -316,14 +316,17 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
peersUpdateManager := update_channel.NewPeersUpdateManager(metrics)
networkMapController := controller.NewController(context.Background(), store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), config)
networkMapController := controller.NewController(context.Background(), store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
if err != nil {
return nil, "", err
}
secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{}, networkMapController)
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
if err != nil {
return nil, "", err
}
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController)
if err != nil {
return nil, "", err
}

View File

@@ -42,6 +42,11 @@ func (s *Server) detectSuPtySupport(context.Context) bool {
return false
}
// detectUtilLinuxLogin always returns false on JS/WASM
func (s *Server) detectUtilLinuxLogin(context.Context) bool {
return false
}
// executeCommandWithPty is not supported on JS/WASM
func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
logger.Errorf("PTY command execution not supported on JS/WASM")

View File

@@ -10,6 +10,7 @@ import (
"os"
"os/exec"
"os/user"
"runtime"
"strings"
"sync"
"syscall"
@@ -75,6 +76,29 @@ func (s *Server) detectSuPtySupport(ctx context.Context) bool {
return supported
}
// detectUtilLinuxLogin checks if login is from util-linux (vs shadow-utils).
// util-linux login uses vhangup() which requires setsid wrapper to avoid killing parent.
// See https://bugs.debian.org/1078023 for details.
func (s *Server) detectUtilLinuxLogin(ctx context.Context) bool {
if runtime.GOOS != "linux" {
return false
}
ctx, cancel := context.WithTimeout(ctx, 500*time.Millisecond)
defer cancel()
cmd := exec.CommandContext(ctx, "login", "--version")
output, err := cmd.CombinedOutput()
if err != nil {
log.Debugf("login --version failed (likely shadow-utils): %v", err)
return false
}
isUtilLinux := strings.Contains(string(output), "util-linux")
log.Debugf("util-linux login detected: %v", isUtilLinux)
return isUtilLinux
}
// createSuCommand creates a command using su -l -c for privilege switching
func (s *Server) createSuCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, error) {
suPath, err := exec.LookPath("su")
@@ -144,7 +168,7 @@ func (s *Server) handlePty(logger *log.Entry, session ssh.Session, privilegeResu
return false
}
logger.Infof("starting interactive shell: %s", execCmd.Path)
logger.Infof("starting interactive shell: %s", strings.Join(execCmd.Args, " "))
return s.runPtyCommand(logger, session, execCmd, ptyReq, winCh)
}

View File

@@ -383,6 +383,11 @@ func (s *Server) detectSuPtySupport(context.Context) bool {
return false
}
// detectUtilLinuxLogin always returns false on Windows
func (s *Server) detectUtilLinuxLogin(context.Context) bool {
return false
}
// executeCommandWithPty executes a command with PTY allocation on Windows using ConPty
func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
command := session.RawCommand()

View File

@@ -138,7 +138,8 @@ type Server struct {
jwtExtractor *jwt.ClaimsExtractor
jwtConfig *JWTConfig
suSupportsPty bool
suSupportsPty bool
loginIsUtilLinux bool
}
type JWTConfig struct {
@@ -193,6 +194,7 @@ func (s *Server) Start(ctx context.Context, addr netip.AddrPort) error {
}
s.suSupportsPty = s.detectSuPtySupport(ctx)
s.loginIsUtilLinux = s.detectUtilLinuxLogin(ctx)
ln, addrDesc, err := s.createListener(ctx, addr)
if err != nil {

View File

@@ -87,11 +87,8 @@ func (s *Server) getLoginCmd(username string, remoteAddr net.Addr) (string, []st
switch runtime.GOOS {
case "linux":
// Special handling for Arch Linux without /etc/pam.d/remote
if s.fileExists("/etc/arch-release") && !s.fileExists("/etc/pam.d/remote") {
return loginPath, []string{"-f", username, "-p"}, nil
}
return loginPath, []string{"-f", username, "-h", addrPort.Addr().String(), "-p"}, nil
p, a := s.getLinuxLoginCmd(loginPath, username, addrPort.Addr().String())
return p, a, nil
case "darwin", "freebsd", "openbsd", "netbsd", "dragonfly":
return loginPath, []string{"-fp", "-h", addrPort.Addr().String(), username}, nil
default:
@@ -99,7 +96,37 @@ func (s *Server) getLoginCmd(username string, remoteAddr net.Addr) (string, []st
}
}
// fileExists checks if a file exists (helper for login command logic)
// getLinuxLoginCmd returns the login command for Linux systems.
// Handles differences between util-linux and shadow-utils login implementations.
func (s *Server) getLinuxLoginCmd(loginPath, username, remoteIP string) (string, []string) {
// Special handling for Arch Linux without /etc/pam.d/remote
var loginArgs []string
if s.fileExists("/etc/arch-release") && !s.fileExists("/etc/pam.d/remote") {
loginArgs = []string{"-f", username, "-p"}
} else {
loginArgs = []string{"-f", username, "-h", remoteIP, "-p"}
}
// util-linux login requires setsid -c to create a new session and set the
// controlling terminal. Without this, vhangup() kills the parent process.
// See https://bugs.debian.org/1078023 for details.
// TODO: handle this via the executor using syscall.Setsid() + TIOCSCTTY + syscall.Exec()
// to avoid external setsid dependency.
if !s.loginIsUtilLinux {
return loginPath, loginArgs
}
setsidPath, err := exec.LookPath("setsid")
if err != nil {
log.Warnf("setsid not available but util-linux login detected, login may fail: %v", err)
return loginPath, loginArgs
}
args := append([]string{"-w", "-c", loginPath}, loginArgs...)
return setsidPath, args
}
// fileExists checks if a file exists
func (s *Server) fileExists(path string) bool {
_, err := os.Stat(path)
return err == nil

View File

@@ -120,6 +120,26 @@ func (i *Info) SetFlags(
}
}
func (i *Info) CopyFlagsFrom(other *Info) {
i.SetFlags(
other.RosenpassEnabled,
other.RosenpassPermissive,
&other.ServerSSHAllowed,
other.DisableClientRoutes,
other.DisableServerRoutes,
other.DisableDNS,
other.DisableFirewall,
other.BlockLANAccess,
other.BlockInbound,
other.LazyConnectionEnabled,
&other.EnableSSHRoot,
&other.EnableSSHSFTP,
&other.EnableSSHLocalPortForwarding,
&other.EnableSSHRemotePortForwarding,
&other.DisableSSHAuth,
)
}
// extractUserAgent extracts Netbird's agent (client) name and version from the outgoing context
func extractUserAgent(ctx context.Context) string {
md, hasMeta := metadata.FromOutgoingContext(ctx)

View File

@@ -8,6 +8,90 @@ import (
"google.golang.org/grpc/metadata"
)
func TestInfo_CopyFlagsFrom(t *testing.T) {
origin := &Info{}
serverSSHAllowed := true
enableSSHRoot := true
enableSSHSFTP := false
enableSSHLocalPortForwarding := true
enableSSHRemotePortForwarding := false
disableSSHAuth := true
origin.SetFlags(
true, // RosenpassEnabled
false, // RosenpassPermissive
&serverSSHAllowed,
true, // DisableClientRoutes
false, // DisableServerRoutes
true, // DisableDNS
false, // DisableFirewall
true, // BlockLANAccess
false, // BlockInbound
true, // LazyConnectionEnabled
&enableSSHRoot,
&enableSSHSFTP,
&enableSSHLocalPortForwarding,
&enableSSHRemotePortForwarding,
&disableSSHAuth,
)
got := &Info{}
got.CopyFlagsFrom(origin)
if got.RosenpassEnabled != true {
t.Fatalf("RosenpassEnabled not copied: got %v", got.RosenpassEnabled)
}
if got.RosenpassPermissive != false {
t.Fatalf("RosenpassPermissive not copied: got %v", got.RosenpassPermissive)
}
if got.ServerSSHAllowed != true {
t.Fatalf("ServerSSHAllowed not copied: got %v", got.ServerSSHAllowed)
}
if got.DisableClientRoutes != true {
t.Fatalf("DisableClientRoutes not copied: got %v", got.DisableClientRoutes)
}
if got.DisableServerRoutes != false {
t.Fatalf("DisableServerRoutes not copied: got %v", got.DisableServerRoutes)
}
if got.DisableDNS != true {
t.Fatalf("DisableDNS not copied: got %v", got.DisableDNS)
}
if got.DisableFirewall != false {
t.Fatalf("DisableFirewall not copied: got %v", got.DisableFirewall)
}
if got.BlockLANAccess != true {
t.Fatalf("BlockLANAccess not copied: got %v", got.BlockLANAccess)
}
if got.BlockInbound != false {
t.Fatalf("BlockInbound not copied: got %v", got.BlockInbound)
}
if got.LazyConnectionEnabled != true {
t.Fatalf("LazyConnectionEnabled not copied: got %v", got.LazyConnectionEnabled)
}
if got.EnableSSHRoot != true {
t.Fatalf("EnableSSHRoot not copied: got %v", got.EnableSSHRoot)
}
if got.EnableSSHSFTP != false {
t.Fatalf("EnableSSHSFTP not copied: got %v", got.EnableSSHSFTP)
}
if got.EnableSSHLocalPortForwarding != true {
t.Fatalf("EnableSSHLocalPortForwarding not copied: got %v", got.EnableSSHLocalPortForwarding)
}
if got.EnableSSHRemotePortForwarding != false {
t.Fatalf("EnableSSHRemotePortForwarding not copied: got %v", got.EnableSSHRemotePortForwarding)
}
if got.DisableSSHAuth != true {
t.Fatalf("DisableSSHAuth not copied: got %v", got.DisableSSHAuth)
}
// ensure CopyFlagsFrom does not touch unrelated fields
origin.Hostname = "host-a"
got.Hostname = "host-b"
got.CopyFlagsFrom(origin)
if got.Hostname != "host-b" {
t.Fatalf("CopyFlagsFrom should not overwrite non-flag fields, got Hostname=%q", got.Hostname)
}
}
func Test_LocalWTVersion(t *testing.T) {
got := GetInfo(context.TODO())
want := "development"

View File

@@ -38,6 +38,7 @@ import (
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/sleep"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/ui/desktop"
"github.com/netbirdio/netbird/client/ui/event"
@@ -209,10 +210,11 @@ var iconConnectedDot []byte
var iconDisconnectedDot []byte
type serviceClient struct {
ctx context.Context
cancel context.CancelFunc
addr string
conn proto.DaemonServiceClient
ctx context.Context
cancel context.CancelFunc
addr string
conn proto.DaemonServiceClient
connLock sync.Mutex
eventHandler *eventHandler
@@ -1098,6 +1100,9 @@ func (s *serviceClient) onTrayReady() {
go s.eventManager.Start(s.ctx)
go s.eventHandler.listen(s.ctx)
// Start sleep detection listener
go s.startSleepListener()
}
func (s *serviceClient) attachOutput(cmd *exec.Cmd) *os.File {
@@ -1134,6 +1139,8 @@ func (s *serviceClient) onTrayExit() {
// getSrvClient connection to the service.
func (s *serviceClient) getSrvClient(timeout time.Duration) (proto.DaemonServiceClient, error) {
s.connLock.Lock()
defer s.connLock.Unlock()
if s.conn != nil {
return s.conn, nil
}
@@ -1156,6 +1163,62 @@ func (s *serviceClient) getSrvClient(timeout time.Duration) (proto.DaemonService
return s.conn, nil
}
// startSleepListener initializes the sleep detection service and listens for sleep events
func (s *serviceClient) startSleepListener() {
sleepService, err := sleep.New()
if err != nil {
log.Warnf("%v", err)
return
}
if err := sleepService.Register(s.handleSleepEvents); err != nil {
log.Errorf("failed to start sleep detection: %v", err)
return
}
log.Info("sleep detection service initialized")
// Cleanup on context cancellation
go func() {
<-s.ctx.Done()
log.Info("stopping sleep event listener")
if err := sleepService.Deregister(); err != nil {
log.Errorf("failed to deregister sleep detection: %v", err)
}
}()
}
// handleSleepEvents sends a sleep notification to the daemon via gRPC
func (s *serviceClient) handleSleepEvents(event sleep.EventType) {
conn, err := s.getSrvClient(0)
if err != nil {
log.Errorf("failed to get daemon client for sleep notification: %v", err)
return
}
req := &proto.OSLifecycleRequest{}
switch event {
case sleep.EventTypeWakeUp:
log.Infof("handle wakeup event: %v", event)
req.Type = proto.OSLifecycleRequest_WAKEUP
case sleep.EventTypeSleep:
log.Infof("handle sleep event: %v", event)
req.Type = proto.OSLifecycleRequest_SLEEP
default:
log.Infof("unknown event: %v", event)
return
}
_, err = conn.NotifyOSLifecycle(s.ctx, req)
if err != nil {
log.Errorf("failed to notify daemon about os lifecycle notification: %v", err)
return
}
log.Info("successfully notified daemon about os lifecycle")
}
// setSettingsEnabled enables or disables the settings menu based on the provided state
func (s *serviceClient) setSettingsEnabled(enabled bool) {
if s.mSettings != nil {

View File

@@ -28,7 +28,8 @@ func IsAnotherProcessRunning() (int32, bool, error) {
continue
}
if strings.Contains(strings.ToLower(runningProcessPath), processName) && isProcessOwnedByCurrentUser(p) {
runningProcessName := strings.ToLower(filepath.Base(runningProcessPath))
if runningProcessName == processName && isProcessOwnedByCurrentUser(p) {
return p.Pid, true, nil
}
}

2
go.mod
View File

@@ -64,7 +64,7 @@ require (
github.com/mdlayher/socket v0.5.1
github.com/miekg/dns v1.1.59
github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/netbirdio/management-integrations/integrations v0.0.0-20251027212525-d751b79f5d48
github.com/netbirdio/management-integrations/integrations v0.0.0-20251203183432-d5400f030847
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45
github.com/okta/okta-sdk-golang/v2 v2.18.0
github.com/oschwald/maxminddb-golang v1.12.0

4
go.sum
View File

@@ -368,8 +368,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI=
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8=
github.com/netbirdio/management-integrations/integrations v0.0.0-20251027212525-d751b79f5d48 h1:moJbL1uuaWR35yUgHZ6suijjqqW8/qGCuPPBXu5MeWQ=
github.com/netbirdio/management-integrations/integrations v0.0.0-20251027212525-d751b79f5d48/go.mod h1:ifKa2jGPsOzZhJFo72v2AE5nMP3GYvlhoZ9JV6lHlJ8=
github.com/netbirdio/management-integrations/integrations v0.0.0-20251203183432-d5400f030847 h1:V0zsYYMU5d2UN1m9zOLPEZCGWpnhtkYcxQVi9Rrx3bY=
github.com/netbirdio/management-integrations/integrations v0.0.0-20251203183432-d5400f030847/go.mod h1:qzLCKeR253jtsWhfZTt4fyegI5zei32jKZykV+oSQOo=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=

View File

@@ -19,6 +19,7 @@ import (
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/account"
@@ -42,6 +43,7 @@ type Controller struct {
accountManagerMetrics *telemetry.AccountManagerMetrics
peersUpdateManager network_map.PeersUpdateManager
settingsManager settings.Manager
EphemeralPeersManager ephemeral.Manager
accountUpdateLocks sync.Map
sendAccountUpdateLocks sync.Map
@@ -70,7 +72,7 @@ type bufferUpdate struct {
var _ network_map.Controller = (*Controller)(nil)
func NewController(ctx context.Context, store store.Store, metrics telemetry.AppMetrics, peersUpdateManager network_map.PeersUpdateManager, requestBuffer account.RequestBuffer, integratedPeerValidator integrated_validator.IntegratedValidator, settingsManager settings.Manager, dnsDomain string, proxyController port_forwarding.Controller, config *config.Config) *Controller {
func NewController(ctx context.Context, store store.Store, metrics telemetry.AppMetrics, peersUpdateManager network_map.PeersUpdateManager, requestBuffer account.RequestBuffer, integratedPeerValidator integrated_validator.IntegratedValidator, settingsManager settings.Manager, dnsDomain string, proxyController port_forwarding.Controller, ephemeralPeersManager ephemeral.Manager, config *config.Config) *Controller {
nMetrics, err := newMetrics(metrics.UpdateChannelMetrics())
if err != nil {
log.Fatal(fmt.Errorf("error creating metrics: %w", err))
@@ -99,7 +101,8 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App
dnsDomain: dnsDomain,
config: config,
proxyController: proxyController,
proxyController: proxyController,
EphemeralPeersManager: ephemeralPeersManager,
holder: types.NewHolder(),
expNewNetworkMap: newNetworkMapBuilder,
@@ -107,6 +110,31 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App
}
}
func (c *Controller) OnPeerConnected(ctx context.Context, accountID string, peerID string) (chan *network_map.UpdateMessage, error) {
peer, err := c.repo.GetPeerByID(ctx, accountID, peerID)
if err != nil {
return nil, fmt.Errorf("failed to get peer %s: %v", peerID, err)
}
c.EphemeralPeersManager.OnPeerConnected(ctx, peer)
return c.peersUpdateManager.CreateChannel(ctx, peerID), nil
}
func (c *Controller) OnPeerDisconnected(ctx context.Context, accountID string, peerID string) {
c.peersUpdateManager.CloseChannel(ctx, peerID)
peer, err := c.repo.GetPeerByID(ctx, accountID, peerID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get peer %s: %v", peerID, err)
return
}
c.EphemeralPeersManager.OnPeerDisconnected(ctx, peer)
}
func (c *Controller) CountStreams() int {
return c.peersUpdateManager.CountStreams()
}
func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID string) error {
log.WithContext(ctx).Tracef("updating peers for account %s from %s", accountID, util.GetCallerName())
var (
@@ -366,55 +394,26 @@ func (c *Controller) BufferUpdateAccountPeers(ctx context.Context, accountID str
return nil
}
func (c *Controller) DeletePeer(ctx context.Context, accountId string, peerId string) error {
network, err := c.repo.GetAccountNetwork(ctx, accountId)
func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer, clientSerial uint64) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
network, err := c.repo.GetAccountNetwork(ctx, accountID)
if err != nil {
return err
return nil, nil, nil, 0, err
}
peers, err := c.repo.GetAccountPeers(ctx, accountId)
if err != nil {
return err
}
dnsFwdPort := computeForwarderPort(peers, network_map.DnsForwarderPortMinVersion)
c.peersUpdateManager.SendUpdate(ctx, peerId, &network_map.UpdateMessage{
Update: &proto.SyncResponse{
RemotePeers: []*proto.RemotePeerConfig{},
RemotePeersIsEmpty: true,
NetworkMap: &proto.NetworkMap{
Serial: network.CurrentSerial(),
RemotePeers: []*proto.RemotePeerConfig{},
RemotePeersIsEmpty: true,
FirewallRules: []*proto.FirewallRule{},
FirewallRulesIsEmpty: true,
DNSConfig: &proto.DNSConfig{
ForwarderPort: dnsFwdPort,
},
},
},
})
c.peersUpdateManager.CloseChannel(ctx, peerId)
return nil
}
func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
if isRequiresApproval {
network, err := c.repo.GetAccountNetwork(ctx, accountID)
if err != nil {
return nil, nil, nil, 0, err
}
emptyMap := &types.NetworkMap{
Network: network.Copy(),
}
return peer, emptyMap, nil, 0, nil
}
var (
account *types.Account
err error
)
if clientSerial > 0 && clientSerial == network.CurrentSerial() {
log.WithContext(ctx).Debugf("client serial %d matches current serial, skipping network map calculation", clientSerial)
return peer, nil, nil, 0, nil
}
var account *types.Account
if c.experimentalNetworkMap(accountID) {
account = c.getAccountFromHolderOrInit(accountID)
} else {
@@ -698,35 +697,83 @@ func isPeerInPolicySourceGroups(account *types.Account, peerID string, policy *t
return false, nil
}
func (c *Controller) OnPeerUpdated(accountId string, peer *nbpeer.Peer) {
c.UpdatePeerInNetworkMapCache(accountId, peer)
_ = c.bufferSendUpdateAccountPeers(context.Background(), accountId)
func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerIDs []string) error {
peers, err := c.repo.GetPeersByIDs(ctx, accountID, peerIDs)
if err != nil {
return fmt.Errorf("failed to get peers by ids: %w", err)
}
for _, peer := range peers {
c.UpdatePeerInNetworkMapCache(accountID, peer)
}
err = c.bufferSendUpdateAccountPeers(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to buffer update account peers for peer update in account %s: %v", accountID, err)
}
return nil
}
func (c *Controller) OnPeerAdded(ctx context.Context, accountID string, peerID string) error {
if c.experimentalNetworkMap(accountID) {
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return err
}
func (c *Controller) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error {
for _, peerID := range peerIDs {
if c.experimentalNetworkMap(accountID) {
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return err
}
err = c.onPeerAddedUpdNetworkMapCache(account, peerID)
if err != nil {
return err
err = c.onPeerAddedUpdNetworkMapCache(account, peerID)
if err != nil {
return err
}
}
}
return c.bufferSendUpdateAccountPeers(ctx, accountID)
}
func (c *Controller) OnPeerDeleted(ctx context.Context, accountID string, peerID string) error {
if c.experimentalNetworkMap(accountID) {
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return err
}
err = c.onPeerDeletedUpdNetworkMapCache(account, peerID)
if err != nil {
return err
func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string) error {
network, err := c.repo.GetAccountNetwork(ctx, accountID)
if err != nil {
return err
}
peers, err := c.repo.GetAccountPeers(ctx, accountID)
if err != nil {
return err
}
dnsFwdPort := computeForwarderPort(peers, network_map.DnsForwarderPortMinVersion)
for _, peerID := range peerIDs {
c.peersUpdateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{
Update: &proto.SyncResponse{
RemotePeers: []*proto.RemotePeerConfig{},
RemotePeersIsEmpty: true,
NetworkMap: &proto.NetworkMap{
Serial: network.CurrentSerial(),
RemotePeers: []*proto.RemotePeerConfig{},
RemotePeersIsEmpty: true,
FirewallRules: []*proto.FirewallRule{},
FirewallRulesIsEmpty: true,
DNSConfig: &proto.DNSConfig{
ForwarderPort: dnsFwdPort,
},
},
},
})
c.peersUpdateManager.CloseChannel(ctx, peerID)
if c.experimentalNetworkMap(accountID) {
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get account %s: %v", accountID, err)
continue
}
err = c.onPeerDeletedUpdNetworkMapCache(account, peerID)
if err != nil {
log.WithContext(ctx).Errorf("failed to update network map cache for deleted peer %s in account %s: %v", peerID, accountID, err)
continue
}
}
}
@@ -778,10 +825,6 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N
return networkMap, nil
}
func (c *Controller) DisconnectPeers(ctx context.Context, peerIDs []string) {
func (c *Controller) DisconnectPeers(ctx context.Context, accountId string, peerIDs []string) {
c.peersUpdateManager.CloseChannels(ctx, peerIDs)
}
func (c *Controller) IsConnected(peerID string) bool {
return c.peersUpdateManager.HasChannel(peerID)
}

View File

@@ -12,6 +12,8 @@ type Repository interface {
GetAccountNetwork(ctx context.Context, accountID string) (*types.Network, error)
GetAccountPeers(ctx context.Context, accountID string) ([]*peer.Peer, error)
GetAccountByPeerID(ctx context.Context, peerID string) (*types.Account, error)
GetPeersByIDs(ctx context.Context, accountID string, peerIDs []string) (map[string]*peer.Peer, error)
GetPeerByID(ctx context.Context, accountID string, peerID string) (*peer.Peer, error)
}
type repository struct {
@@ -37,3 +39,11 @@ func (r *repository) GetAccountPeers(ctx context.Context, accountID string) ([]*
func (r *repository) GetAccountByPeerID(ctx context.Context, peerID string) (*types.Account, error) {
return r.store.GetAccountByPeerID(ctx, peerID)
}
func (r *repository) GetPeersByIDs(ctx context.Context, accountID string, peerIDs []string) (map[string]*peer.Peer, error) {
return r.store.GetPeersByIDs(ctx, store.LockingStrengthNone, accountID, peerIDs)
}
func (r *repository) GetPeerByID(ctx context.Context, accountID string, peerID string) (*peer.Peer, error) {
return r.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
}

View File

@@ -24,16 +24,16 @@ type Controller interface {
UpdateAccountPeers(ctx context.Context, accountID string) error
UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error
BufferUpdateAccountPeers(ctx context.Context, accountID string) error
GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *nbpeer.Peer, clientSerial uint64) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
GetDNSDomain(settings *types.Settings) string
StartWarmup(context.Context)
GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error)
CountStreams() int
DeletePeer(ctx context.Context, accountId string, peerId string) error
OnPeerUpdated(accountId string, peer *nbpeer.Peer)
OnPeerAdded(ctx context.Context, accountID string, peerID string) error
OnPeerDeleted(ctx context.Context, accountID string, peerID string) error
DisconnectPeers(ctx context.Context, peerIDs []string)
IsConnected(peerID string) bool
OnPeersUpdated(ctx context.Context, accountId string, peerIDs []string) error
OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error
OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string) error
DisconnectPeers(ctx context.Context, accountId string, peerIDs []string)
OnPeerConnected(ctx context.Context, accountID string, peerID string) (chan *UpdateMessage, error)
OnPeerDisconnected(ctx context.Context, accountID string, peerID string)
}

View File

@@ -57,30 +57,30 @@ func (mr *MockControllerMockRecorder) BufferUpdateAccountPeers(ctx, accountID an
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferUpdateAccountPeers", reflect.TypeOf((*MockController)(nil).BufferUpdateAccountPeers), ctx, accountID)
}
// DeletePeer mocks base method.
func (m *MockController) DeletePeer(ctx context.Context, accountId, peerId string) error {
// CountStreams mocks base method.
func (m *MockController) CountStreams() int {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeletePeer", ctx, accountId, peerId)
ret0, _ := ret[0].(error)
ret := m.ctrl.Call(m, "CountStreams")
ret0, _ := ret[0].(int)
return ret0
}
// DeletePeer indicates an expected call of DeletePeer.
func (mr *MockControllerMockRecorder) DeletePeer(ctx, accountId, peerId any) *gomock.Call {
// CountStreams indicates an expected call of CountStreams.
func (mr *MockControllerMockRecorder) CountStreams() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeletePeer", reflect.TypeOf((*MockController)(nil).DeletePeer), ctx, accountId, peerId)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountStreams", reflect.TypeOf((*MockController)(nil).CountStreams))
}
// DisconnectPeers mocks base method.
func (m *MockController) DisconnectPeers(ctx context.Context, peerIDs []string) {
func (m *MockController) DisconnectPeers(ctx context.Context, accountId string, peerIDs []string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "DisconnectPeers", ctx, peerIDs)
m.ctrl.Call(m, "DisconnectPeers", ctx, accountId, peerIDs)
}
// DisconnectPeers indicates an expected call of DisconnectPeers.
func (mr *MockControllerMockRecorder) DisconnectPeers(ctx, peerIDs any) *gomock.Call {
func (mr *MockControllerMockRecorder) DisconnectPeers(ctx, accountId, peerIDs any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisconnectPeers", reflect.TypeOf((*MockController)(nil).DisconnectPeers), ctx, peerIDs)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisconnectPeers", reflect.TypeOf((*MockController)(nil).DisconnectPeers), ctx, accountId, peerIDs)
}
// GetDNSDomain mocks base method.
@@ -113,9 +113,9 @@ func (mr *MockControllerMockRecorder) GetNetworkMap(ctx, peerID any) *gomock.Cal
}
// GetValidatedPeerWithMap mocks base method.
func (m *MockController) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *peer.Peer) (*peer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
func (m *MockController) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *peer.Peer, clientSerial uint64) (*peer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetValidatedPeerWithMap", ctx, isRequiresApproval, accountID, p)
ret := m.ctrl.Call(m, "GetValidatedPeerWithMap", ctx, isRequiresApproval, accountID, p, clientSerial)
ret0, _ := ret[0].(*peer.Peer)
ret1, _ := ret[1].(*types.NetworkMap)
ret2, _ := ret[2].([]*posture.Checks)
@@ -125,63 +125,78 @@ func (m *MockController) GetValidatedPeerWithMap(ctx context.Context, isRequires
}
// GetValidatedPeerWithMap indicates an expected call of GetValidatedPeerWithMap.
func (mr *MockControllerMockRecorder) GetValidatedPeerWithMap(ctx, isRequiresApproval, accountID, p any) *gomock.Call {
func (mr *MockControllerMockRecorder) GetValidatedPeerWithMap(ctx, isRequiresApproval, accountID, p, clientSerial any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValidatedPeerWithMap", reflect.TypeOf((*MockController)(nil).GetValidatedPeerWithMap), ctx, isRequiresApproval, accountID, p)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValidatedPeerWithMap", reflect.TypeOf((*MockController)(nil).GetValidatedPeerWithMap), ctx, isRequiresApproval, accountID, p, clientSerial)
}
// IsConnected mocks base method.
func (m *MockController) IsConnected(peerID string) bool {
// OnPeerConnected mocks base method.
func (m *MockController) OnPeerConnected(ctx context.Context, accountID, peerID string) (chan *UpdateMessage, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "IsConnected", peerID)
ret0, _ := ret[0].(bool)
return ret0
ret := m.ctrl.Call(m, "OnPeerConnected", ctx, accountID, peerID)
ret0, _ := ret[0].(chan *UpdateMessage)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// IsConnected indicates an expected call of IsConnected.
func (mr *MockControllerMockRecorder) IsConnected(peerID any) *gomock.Call {
// OnPeerConnected indicates an expected call of OnPeerConnected.
func (mr *MockControllerMockRecorder) OnPeerConnected(ctx, accountID, peerID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsConnected", reflect.TypeOf((*MockController)(nil).IsConnected), peerID)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeerConnected", reflect.TypeOf((*MockController)(nil).OnPeerConnected), ctx, accountID, peerID)
}
// OnPeerAdded mocks base method.
func (m *MockController) OnPeerAdded(ctx context.Context, accountID, peerID string) error {
// OnPeerDisconnected mocks base method.
func (m *MockController) OnPeerDisconnected(ctx context.Context, accountID, peerID string) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "OnPeerAdded", ctx, accountID, peerID)
m.ctrl.Call(m, "OnPeerDisconnected", ctx, accountID, peerID)
}
// OnPeerDisconnected indicates an expected call of OnPeerDisconnected.
func (mr *MockControllerMockRecorder) OnPeerDisconnected(ctx, accountID, peerID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeerDisconnected", reflect.TypeOf((*MockController)(nil).OnPeerDisconnected), ctx, accountID, peerID)
}
// OnPeersAdded mocks base method.
func (m *MockController) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "OnPeersAdded", ctx, accountID, peerIDs)
ret0, _ := ret[0].(error)
return ret0
}
// OnPeerAdded indicates an expected call of OnPeerAdded.
func (mr *MockControllerMockRecorder) OnPeerAdded(ctx, accountID, peerID any) *gomock.Call {
// OnPeersAdded indicates an expected call of OnPeersAdded.
func (mr *MockControllerMockRecorder) OnPeersAdded(ctx, accountID, peerIDs any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeerAdded", reflect.TypeOf((*MockController)(nil).OnPeerAdded), ctx, accountID, peerID)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersAdded", reflect.TypeOf((*MockController)(nil).OnPeersAdded), ctx, accountID, peerIDs)
}
// OnPeerDeleted mocks base method.
func (m *MockController) OnPeerDeleted(ctx context.Context, accountID, peerID string) error {
// OnPeersDeleted mocks base method.
func (m *MockController) OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "OnPeerDeleted", ctx, accountID, peerID)
ret := m.ctrl.Call(m, "OnPeersDeleted", ctx, accountID, peerIDs)
ret0, _ := ret[0].(error)
return ret0
}
// OnPeerDeleted indicates an expected call of OnPeerDeleted.
func (mr *MockControllerMockRecorder) OnPeerDeleted(ctx, accountID, peerID any) *gomock.Call {
// OnPeersDeleted indicates an expected call of OnPeersDeleted.
func (mr *MockControllerMockRecorder) OnPeersDeleted(ctx, accountID, peerIDs any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeerDeleted", reflect.TypeOf((*MockController)(nil).OnPeerDeleted), ctx, accountID, peerID)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersDeleted", reflect.TypeOf((*MockController)(nil).OnPeersDeleted), ctx, accountID, peerIDs)
}
// OnPeerUpdated mocks base method.
func (m *MockController) OnPeerUpdated(accountId string, peer *peer.Peer) {
// OnPeersUpdated mocks base method.
func (m *MockController) OnPeersUpdated(ctx context.Context, accountId string, peerIDs []string) error {
m.ctrl.T.Helper()
m.ctrl.Call(m, "OnPeerUpdated", accountId, peer)
ret := m.ctrl.Call(m, "OnPeersUpdated", ctx, accountId, peerIDs)
ret0, _ := ret[0].(error)
return ret0
}
// OnPeerUpdated indicates an expected call of OnPeerUpdated.
func (mr *MockControllerMockRecorder) OnPeerUpdated(accountId, peer any) *gomock.Call {
// OnPeersUpdated indicates an expected call of OnPeersUpdated.
func (mr *MockControllerMockRecorder) OnPeersUpdated(ctx, accountId, peerIDs any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeerUpdated", reflect.TypeOf((*MockController)(nil).OnPeerUpdated), accountId, peer)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersUpdated", reflect.TypeOf((*MockController)(nil).OnPeersUpdated), ctx, accountId, peerIDs)
}
// StartWarmup mocks base method.

View File

@@ -2,10 +2,15 @@ package ephemeral
import (
"context"
"time"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
)
const (
EphemeralLifeTime = 10 * time.Minute
)
type Manager interface {
LoadInitialPeers(ctx context.Context)
Stop()

View File

@@ -7,14 +7,15 @@ import (
log "github.com/sirupsen/logrus"
nbAccount "github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral"
"github.com/netbirdio/netbird/management/server/activity"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/store"
)
const (
ephemeralLifeTime = 10 * time.Minute
// cleanupWindow is the time window to wait after nearest peer deadline to start the cleanup procedure.
cleanupWindow = 1 * time.Minute
)
@@ -33,11 +34,11 @@ type ephemeralPeer struct {
// todo: consider to remove peer from ephemeral list when the peer has been deleted via API. If we do not do it
// in worst case we will get invalid error message in this manager.
// EphemeralManager keep a list of ephemeral peers. After ephemeralLifeTime inactivity the peer will be deleted
// EphemeralManager keep a list of ephemeral peers. After EphemeralLifeTime inactivity the peer will be deleted
// automatically. Inactivity means the peer disconnected from the Management server.
type EphemeralManager struct {
store store.Store
accountManager nbAccount.Manager
store store.Store
peersManager peers.Manager
headPeer *ephemeralPeer
tailPeer *ephemeralPeer
@@ -49,12 +50,12 @@ type EphemeralManager struct {
}
// NewEphemeralManager instantiate new EphemeralManager
func NewEphemeralManager(store store.Store, accountManager nbAccount.Manager) *EphemeralManager {
func NewEphemeralManager(store store.Store, peersManager peers.Manager) *EphemeralManager {
return &EphemeralManager{
store: store,
accountManager: accountManager,
store: store,
peersManager: peersManager,
lifeTime: ephemeralLifeTime,
lifeTime: ephemeral.EphemeralLifeTime,
cleanupWindow: cleanupWindow,
}
}
@@ -106,7 +107,7 @@ func (e *EphemeralManager) OnPeerConnected(ctx context.Context, peer *nbpeer.Pee
}
// OnPeerDisconnected add the peer to the linked list of ephemeral peers. Because of the peer
// is inactive it will be deleted after the ephemeralLifeTime period.
// is inactive it will be deleted after the EphemeralLifeTime period.
func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer.Peer) {
if !peer.Ephemeral {
return
@@ -180,20 +181,18 @@ func (e *EphemeralManager) cleanup(ctx context.Context) {
e.peersLock.Unlock()
bufferAccountCall := make(map[string]struct{})
peerIDsPerAccount := make(map[string][]string)
for id, p := range deletePeers {
log.WithContext(ctx).Debugf("delete ephemeral peer: %s", id)
err := e.accountManager.DeletePeer(ctx, p.accountID, id, activity.SystemInitiator)
peerIDsPerAccount[p.accountID] = append(peerIDsPerAccount[p.accountID], id)
}
for accountID, peerIDs := range peerIDsPerAccount {
log.WithContext(ctx).Debugf("delete ephemeral peers for account: %s", accountID)
err := e.peersManager.DeletePeers(ctx, accountID, peerIDs, activity.SystemInitiator, true)
if err != nil {
log.WithContext(ctx).Errorf("failed to delete ephemeral peer: %s", err)
} else {
bufferAccountCall[p.accountID] = struct{}{}
}
}
for accountID := range bufferAccountCall {
e.accountManager.BufferUpdateAccountPeers(ctx, accountID)
}
}
func (e *EphemeralManager) addPeer(accountID string, peerID string, deadline time.Time) {

View File

@@ -7,10 +7,13 @@ import (
"testing"
"time"
"github.com/golang/mock/gomock"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral"
nbAccount "github.com/netbirdio/netbird/management/server/account"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/store"
@@ -91,17 +94,27 @@ func TestNewManager(t *testing.T) {
}
store := &MockStore{}
am := MockAccountManager{
store: store,
}
ctrl := gomock.NewController(t)
peersManager := peers.NewMockManager(ctrl)
numberOfPeers := 5
numberOfEphemeralPeers := 3
seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
mgr := NewEphemeralManager(store, &am)
// Expect DeletePeers to be called for ephemeral peers
peersManager.EXPECT().
DeletePeers(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), true).
DoAndReturn(func(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error {
for _, peerID := range peerIDs {
delete(store.account.Peers, peerID)
}
return nil
}).
AnyTimes()
mgr := NewEphemeralManager(store, peersManager)
mgr.loadEphemeralPeers(context.Background())
startTime = startTime.Add(ephemeralLifeTime + 1)
startTime = startTime.Add(ephemeral.EphemeralLifeTime + 1)
mgr.cleanup(context.Background())
if len(store.account.Peers) != numberOfPeers {
@@ -119,19 +132,29 @@ func TestNewManagerPeerConnected(t *testing.T) {
}
store := &MockStore{}
am := MockAccountManager{
store: store,
}
ctrl := gomock.NewController(t)
peersManager := peers.NewMockManager(ctrl)
numberOfPeers := 5
numberOfEphemeralPeers := 3
seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
mgr := NewEphemeralManager(store, &am)
// Expect DeletePeers to be called for ephemeral peers (except the connected one)
peersManager.EXPECT().
DeletePeers(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), true).
DoAndReturn(func(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error {
for _, peerID := range peerIDs {
delete(store.account.Peers, peerID)
}
return nil
}).
AnyTimes()
mgr := NewEphemeralManager(store, peersManager)
mgr.loadEphemeralPeers(context.Background())
mgr.OnPeerConnected(context.Background(), store.account.Peers["ephemeral_peer_0"])
startTime = startTime.Add(ephemeralLifeTime + 1)
startTime = startTime.Add(ephemeral.EphemeralLifeTime + 1)
mgr.cleanup(context.Background())
expected := numberOfPeers + 1
@@ -150,15 +173,25 @@ func TestNewManagerPeerDisconnected(t *testing.T) {
}
store := &MockStore{}
am := MockAccountManager{
store: store,
}
ctrl := gomock.NewController(t)
peersManager := peers.NewMockManager(ctrl)
numberOfPeers := 5
numberOfEphemeralPeers := 3
seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
mgr := NewEphemeralManager(store, &am)
// Expect DeletePeers to be called for the one disconnected peer
peersManager.EXPECT().
DeletePeers(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), true).
DoAndReturn(func(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error {
for _, peerID := range peerIDs {
delete(store.account.Peers, peerID)
}
return nil
}).
AnyTimes()
mgr := NewEphemeralManager(store, peersManager)
mgr.loadEphemeralPeers(context.Background())
for _, v := range store.account.Peers {
mgr.OnPeerConnected(context.Background(), v)
@@ -166,7 +199,7 @@ func TestNewManagerPeerDisconnected(t *testing.T) {
}
mgr.OnPeerDisconnected(context.Background(), store.account.Peers["ephemeral_peer_0"])
startTime = startTime.Add(ephemeralLifeTime + 1)
startTime = startTime.Add(ephemeral.EphemeralLifeTime + 1)
mgr.cleanup(context.Background())
expected := numberOfPeers + numberOfEphemeralPeers - 1
@@ -181,25 +214,63 @@ func TestCleanupSchedulingBehaviorIsBatched(t *testing.T) {
testLifeTime = 1 * time.Second
testCleanupWindow = 100 * time.Millisecond
)
t.Cleanup(func() {
timeNow = time.Now
})
startTime := time.Now()
timeNow = func() time.Time {
return startTime
}
mockStore := &MockStore{}
account := newAccountWithId(context.Background(), "account", "", "", false)
mockStore.account = account
wg := &sync.WaitGroup{}
wg.Add(ephemeralPeers)
mockAM := &MockAccountManager{
store: mockStore,
wg: wg,
}
mockAM.wg = &sync.WaitGroup{}
mockAM.wg.Add(ephemeralPeers)
mgr := NewEphemeralManager(mockStore, mockAM)
ctrl := gomock.NewController(t)
peersManager := peers.NewMockManager(ctrl)
// Set up expectation that DeletePeers will be called once with all peer IDs
peersManager.EXPECT().
DeletePeers(gomock.Any(), account.Id, gomock.Any(), gomock.Any(), true).
DoAndReturn(func(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error {
// Simulate the actual deletion behavior
for _, peerID := range peerIDs {
err := mockAM.DeletePeer(ctx, accountID, peerID, userID)
if err != nil {
return err
}
}
mockAM.BufferUpdateAccountPeers(ctx, accountID)
return nil
}).
Times(1)
mgr := NewEphemeralManager(mockStore, peersManager)
mgr.lifeTime = testLifeTime
mgr.cleanupWindow = testCleanupWindow
account := newAccountWithId(context.Background(), "account", "", "", false)
mockStore.account = account
// Add peers and disconnect them at slightly different times (within cleanup window)
for i := range ephemeralPeers {
p := &nbpeer.Peer{ID: fmt.Sprintf("peer-%d", i), AccountID: account.Id, Ephemeral: true}
mockStore.account.Peers[p.ID] = p
time.Sleep(testCleanupWindow / ephemeralPeers)
mgr.OnPeerDisconnected(context.Background(), p)
startTime = startTime.Add(testCleanupWindow / (ephemeralPeers * 2))
}
mockAM.wg.Wait()
// Advance time past the lifetime to trigger cleanup
startTime = startTime.Add(testLifeTime + testCleanupWindow)
// Wait for all deletions to complete
wg.Wait()
assert.Len(t, mockStore.account.Peers, 0, "all ephemeral peers should be cleaned up after the lifetime")
assert.Equal(t, 1, mockAM.GetBufferUpdateCalls(account.Id), "buffer update should be called once")
assert.Equal(t, ephemeralPeers, mockAM.GetDeletePeerCalls(), "should have deleted all peers")

View File

@@ -0,0 +1,162 @@
package peers
//go:generate go run github.com/golang/mock/mockgen -package peers -destination=manager_mock.go -source=./manager.go -build_flags=-mod=mod
import (
"context"
"fmt"
"time"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
"github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/shared/management/status"
)
type Manager interface {
GetPeer(ctx context.Context, accountID, userID, peerID string) (*peer.Peer, error)
GetPeerAccountID(ctx context.Context, peerID string) (string, error)
GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error)
GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) ([]*peer.Peer, error)
DeletePeers(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error
SetNetworkMapController(networkMapController network_map.Controller)
SetIntegratedPeerValidator(integratedPeerValidator integrated_validator.IntegratedValidator)
SetAccountManager(accountManager account.Manager)
}
type managerImpl struct {
store store.Store
permissionsManager permissions.Manager
integratedPeerValidator integrated_validator.IntegratedValidator
accountManager account.Manager
networkMapController network_map.Controller
}
func NewManager(store store.Store, permissionsManager permissions.Manager) Manager {
return &managerImpl{
store: store,
permissionsManager: permissionsManager,
}
}
func (m *managerImpl) SetNetworkMapController(networkMapController network_map.Controller) {
m.networkMapController = networkMapController
}
func (m *managerImpl) SetIntegratedPeerValidator(integratedPeerValidator integrated_validator.IntegratedValidator) {
m.integratedPeerValidator = integratedPeerValidator
}
func (m *managerImpl) SetAccountManager(accountManager account.Manager) {
m.accountManager = accountManager
}
func (m *managerImpl) GetPeer(ctx context.Context, accountID, userID, peerID string) (*peer.Peer, error) {
allowed, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Read)
if err != nil {
return nil, fmt.Errorf("failed to validate user permissions: %w", err)
}
if !allowed {
return nil, status.NewPermissionDeniedError()
}
return m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
}
func (m *managerImpl) GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error) {
allowed, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Read)
if err != nil {
return nil, fmt.Errorf("failed to validate user permissions: %w", err)
}
if !allowed {
return m.store.GetUserPeers(ctx, store.LockingStrengthNone, accountID, userID)
}
return m.store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
}
func (m *managerImpl) GetPeerAccountID(ctx context.Context, peerID string) (string, error) {
return m.store.GetAccountIDByPeerID(ctx, store.LockingStrengthNone, peerID)
}
func (m *managerImpl) GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) ([]*peer.Peer, error) {
return m.store.GetPeersByGroupIDs(ctx, accountID, groupsIDs)
}
func (m *managerImpl) DeletePeers(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error {
settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return err
}
dnsDomain := m.networkMapController.GetDNSDomain(settings)
for _, peerID := range peerIDs {
var eventsToStore []func()
err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
peer, err := transaction.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
if err != nil {
return err
}
if checkConnected && (peer.Status.Connected || peer.Status.LastSeen.After(time.Now().Add(-(ephemeral.EphemeralLifeTime - 10*time.Second)))) {
return nil
}
if err := transaction.RemovePeerFromAllGroups(ctx, peerID); err != nil {
return fmt.Errorf("failed to remove peer %s from groups", peerID)
}
if err := m.integratedPeerValidator.PeerDeleted(ctx, accountID, peerID, settings.Extra); err != nil {
return err
}
peerPolicyRules, err := transaction.GetPolicyRulesByResourceID(ctx, store.LockingStrengthNone, accountID, peerID)
if err != nil {
return err
}
for _, rule := range peerPolicyRules {
policy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, rule.PolicyID)
if err != nil {
return err
}
err = transaction.DeletePolicy(ctx, accountID, rule.PolicyID)
if err != nil {
return err
}
eventsToStore = append(eventsToStore, func() {
m.accountManager.StoreEvent(ctx, userID, peer.ID, accountID, activity.PolicyRemoved, policy.EventMeta())
})
}
if err = transaction.DeletePeer(ctx, accountID, peerID); err != nil {
return err
}
eventsToStore = append(eventsToStore, func() {
m.accountManager.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRemovedByUser, peer.EventMeta(dnsDomain))
})
return nil
})
if err != nil {
return err
}
for _, event := range eventsToStore {
event()
}
}
return nil
}

View File

@@ -9,6 +9,9 @@ import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
network_map "github.com/netbirdio/netbird/management/internals/controllers/network_map"
account "github.com/netbirdio/netbird/management/server/account"
integrated_validator "github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
peer "github.com/netbirdio/netbird/management/server/peer"
)
@@ -35,6 +38,20 @@ func (m *MockManager) EXPECT() *MockManagerMockRecorder {
return m.recorder
}
// DeletePeers mocks base method.
func (m *MockManager) DeletePeers(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeletePeers", ctx, accountID, peerIDs, userID, checkConnected)
ret0, _ := ret[0].(error)
return ret0
}
// DeletePeers indicates an expected call of DeletePeers.
func (mr *MockManagerMockRecorder) DeletePeers(ctx, accountID, peerIDs, userID, checkConnected interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeletePeers", reflect.TypeOf((*MockManager)(nil).DeletePeers), ctx, accountID, peerIDs, userID, checkConnected)
}
// GetAllPeers mocks base method.
func (m *MockManager) GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error) {
m.ctrl.T.Helper()
@@ -94,3 +111,39 @@ func (mr *MockManagerMockRecorder) GetPeersByGroupIDs(ctx, accountID, groupsIDs
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeersByGroupIDs", reflect.TypeOf((*MockManager)(nil).GetPeersByGroupIDs), ctx, accountID, groupsIDs)
}
// SetAccountManager mocks base method.
func (m *MockManager) SetAccountManager(accountManager account.Manager) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetAccountManager", accountManager)
}
// SetAccountManager indicates an expected call of SetAccountManager.
func (mr *MockManagerMockRecorder) SetAccountManager(accountManager interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetAccountManager", reflect.TypeOf((*MockManager)(nil).SetAccountManager), accountManager)
}
// SetIntegratedPeerValidator mocks base method.
func (m *MockManager) SetIntegratedPeerValidator(integratedPeerValidator integrated_validator.IntegratedValidator) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetIntegratedPeerValidator", integratedPeerValidator)
}
// SetIntegratedPeerValidator indicates an expected call of SetIntegratedPeerValidator.
func (mr *MockManagerMockRecorder) SetIntegratedPeerValidator(integratedPeerValidator interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetIntegratedPeerValidator", reflect.TypeOf((*MockManager)(nil).SetIntegratedPeerValidator), integratedPeerValidator)
}
// SetNetworkMapController mocks base method.
func (m *MockManager) SetNetworkMapController(networkMapController network_map.Controller) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetNetworkMapController", networkMapController)
}
// SetNetworkMapController indicates an expected call of SetNetworkMapController.
func (mr *MockManagerMockRecorder) SetNetworkMapController(networkMapController interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetworkMapController", reflect.TypeOf((*MockManager)(nil).SetNetworkMapController), networkMapController)
}

View File

@@ -57,7 +57,7 @@ func (s *BaseServer) Metrics() telemetry.AppMetrics {
func (s *BaseServer) Store() store.Store {
return Create(s, func() store.Store {
store, err := store.NewStore(context.Background(), s.config.StoreConfig.Engine, s.config.Datadir, s.Metrics(), false)
store, err := store.NewStore(context.Background(), s.Config.StoreConfig.Engine, s.Config.Datadir, s.Metrics(), false)
if err != nil {
log.Fatalf("failed to create store: %v", err)
}
@@ -73,17 +73,17 @@ func (s *BaseServer) EventStore() activity.Store {
log.Fatalf("failed to initialize integration metrics: %v", err)
}
eventStore, key, err := integrations.InitEventStore(context.Background(), s.config.Datadir, s.config.DataStoreEncryptionKey, integrationMetrics)
eventStore, key, err := integrations.InitEventStore(context.Background(), s.Config.Datadir, s.Config.DataStoreEncryptionKey, integrationMetrics)
if err != nil {
log.Fatalf("failed to initialize event store: %v", err)
}
if s.config.DataStoreEncryptionKey != key {
log.WithContext(context.Background()).Infof("update config with activity store key")
s.config.DataStoreEncryptionKey = key
err := updateMgmtConfig(context.Background(), nbconfig.MgmtConfigPath, s.config)
if s.Config.DataStoreEncryptionKey != key {
log.WithContext(context.Background()).Infof("update Config with activity store key")
s.Config.DataStoreEncryptionKey = key
err := updateMgmtConfig(context.Background(), nbconfig.MgmtConfigPath, s.Config)
if err != nil {
log.Fatalf("failed to update config with activity store: %v", err)
log.Fatalf("failed to update Config with activity store: %v", err)
}
}
@@ -103,14 +103,14 @@ func (s *BaseServer) APIHandler() http.Handler {
func (s *BaseServer) GRPCServer() *grpc.Server {
return Create(s, func() *grpc.Server {
trustedPeers := s.config.ReverseProxy.TrustedPeers
trustedPeers := s.Config.ReverseProxy.TrustedPeers
defaultTrustedPeers := []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0"), netip.MustParsePrefix("::/0")}
if len(trustedPeers) == 0 || slices.Equal[[]netip.Prefix](trustedPeers, defaultTrustedPeers) {
log.WithContext(context.Background()).Warn("TrustedPeers are configured to default value '0.0.0.0/0', '::/0'. This allows connection IP spoofing.")
trustedPeers = defaultTrustedPeers
}
trustedHTTPProxies := s.config.ReverseProxy.TrustedHTTPProxies
trustedProxiesCount := s.config.ReverseProxy.TrustedHTTPProxiesCount
trustedHTTPProxies := s.Config.ReverseProxy.TrustedHTTPProxies
trustedProxiesCount := s.Config.ReverseProxy.TrustedHTTPProxiesCount
if len(trustedHTTPProxies) > 0 && trustedProxiesCount > 0 {
log.WithContext(context.Background()).Warn("TrustedHTTPProxies and TrustedHTTPProxiesCount both are configured. " +
"This is not recommended way to extract X-Forwarded-For. Consider using one of these options.")
@@ -128,15 +128,15 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
grpc.ChainStreamInterceptor(realip.StreamServerInterceptorOpts(realipOpts...), streamInterceptor),
}
if s.config.HttpConfig.LetsEncryptDomain != "" {
certManager, err := encryption.CreateCertManager(s.config.Datadir, s.config.HttpConfig.LetsEncryptDomain)
if s.Config.HttpConfig.LetsEncryptDomain != "" {
certManager, err := encryption.CreateCertManager(s.Config.Datadir, s.Config.HttpConfig.LetsEncryptDomain)
if err != nil {
log.Fatalf("failed to create certificate manager: %v", err)
}
transportCredentials := credentials.NewTLS(certManager.TLSConfig())
gRPCOpts = append(gRPCOpts, grpc.Creds(transportCredentials))
} else if s.config.HttpConfig.CertFile != "" && s.config.HttpConfig.CertKey != "" {
tlsConfig, err := loadTLSConfig(s.config.HttpConfig.CertFile, s.config.HttpConfig.CertKey)
} else if s.Config.HttpConfig.CertFile != "" && s.Config.HttpConfig.CertKey != "" {
tlsConfig, err := loadTLSConfig(s.Config.HttpConfig.CertFile, s.Config.HttpConfig.CertKey)
if err != nil {
log.Fatalf("cannot load TLS credentials: %v", err)
}
@@ -145,7 +145,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
}
gRPCAPIHandler := grpc.NewServer(gRPCOpts...)
srv, err := nbgrpc.NewServer(s.config, s.AccountManager(), s.SettingsManager(), s.PeersUpdateManager(), s.SecretsManager(), s.Metrics(), s.EphemeralManager(), s.AuthManager(), s.IntegratedValidator(), s.NetworkMapController())
srv, err := nbgrpc.NewServer(s.Config, s.AccountManager(), s.SettingsManager(), s.SecretsManager(), s.Metrics(), s.AuthManager(), s.IntegratedValidator(), s.NetworkMapController())
if err != nil {
log.Fatalf("failed to create management server: %v", err)
}

View File

@@ -9,17 +9,17 @@ import (
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
nmapcontroller "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral"
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/auth"
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/peers/ephemeral"
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
)
func (s *BaseServer) PeersUpdateManager() network_map.PeersUpdateManager {
return Create(s, func() *update_channel.PeersUpdateManager {
return Create(s, func() network_map.PeersUpdateManager {
return update_channel.NewPeersUpdateManager(s.Metrics())
})
}
@@ -44,33 +44,37 @@ func (s *BaseServer) ProxyController() port_forwarding.Controller {
})
}
func (s *BaseServer) SecretsManager() *grpc.TimeBasedAuthSecretsManager {
return Create(s, func() *grpc.TimeBasedAuthSecretsManager {
return grpc.NewTimeBasedAuthSecretsManager(s.PeersUpdateManager(), s.config.TURNConfig, s.config.Relay, s.SettingsManager(), s.GroupsManager())
func (s *BaseServer) SecretsManager() grpc.SecretsManager {
return Create(s, func() grpc.SecretsManager {
secretsManager, err := grpc.NewTimeBasedAuthSecretsManager(s.PeersUpdateManager(), s.Config.TURNConfig, s.Config.Relay, s.SettingsManager(), s.GroupsManager())
if err != nil {
log.Fatalf("failed to create secrets manager: %v", err)
}
return secretsManager
})
}
func (s *BaseServer) AuthManager() auth.Manager {
return Create(s, func() auth.Manager {
return auth.NewManager(s.Store(),
s.config.HttpConfig.AuthIssuer,
s.config.HttpConfig.AuthAudience,
s.config.HttpConfig.AuthKeysLocation,
s.config.HttpConfig.AuthUserIDClaim,
s.config.GetAuthAudiences(),
s.config.HttpConfig.IdpSignKeyRefreshEnabled)
s.Config.HttpConfig.AuthIssuer,
s.Config.HttpConfig.AuthAudience,
s.Config.HttpConfig.AuthKeysLocation,
s.Config.HttpConfig.AuthUserIDClaim,
s.Config.GetAuthAudiences(),
s.Config.HttpConfig.IdpSignKeyRefreshEnabled)
})
}
func (s *BaseServer) EphemeralManager() ephemeral.Manager {
return Create(s, func() ephemeral.Manager {
return manager.NewEphemeralManager(s.Store(), s.AccountManager())
return manager.NewEphemeralManager(s.Store(), s.PeersManager())
})
}
func (s *BaseServer) NetworkMapController() network_map.Controller {
return Create(s, func() *nmapcontroller.Controller {
return nmapcontroller.NewController(context.Background(), s.Store(), s.Metrics(), s.PeersUpdateManager(), s.AccountRequestBuffer(), s.IntegratedValidator(), s.SettingsManager(), s.dnsDomain, s.ProxyController(), s.config)
return Create(s, func() network_map.Controller {
return nmapcontroller.NewController(context.Background(), s.Store(), s.Metrics(), s.PeersUpdateManager(), s.AccountRequestBuffer(), s.IntegratedValidator(), s.SettingsManager(), s.DNSDomain(), s.ProxyController(), s.EphemeralManager(), s.Config)
})
}
@@ -79,3 +83,7 @@ func (s *BaseServer) AccountRequestBuffer() *server.AccountRequestBuffer {
return server.NewAccountRequestBuffer(context.Background(), s.Store())
})
}
func (s *BaseServer) DNSDomain() string {
return s.dnsDomain
}

View File

@@ -2,10 +2,12 @@ package server
import (
"context"
"os"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/geolocation"
@@ -14,20 +16,29 @@ import (
"github.com/netbirdio/netbird/management/server/networks"
"github.com/netbirdio/netbird/management/server/networks/resources"
"github.com/netbirdio/netbird/management/server/networks/routers"
"github.com/netbirdio/netbird/management/server/peers"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/users"
)
const (
geolocationDisabledKey = "NB_DISABLE_GEOLOCATION"
)
func (s *BaseServer) GeoLocationManager() geolocation.Geolocation {
if os.Getenv(geolocationDisabledKey) == "true" {
log.Info("geolocation service is disabled, skipping initialization")
return nil
}
return Create(s, func() geolocation.Geolocation {
geo, err := geolocation.NewGeolocation(context.Background(), s.config.Datadir, !s.disableGeoliteUpdate)
geo, err := geolocation.NewGeolocation(context.Background(), s.Config.Datadir, !s.disableGeoliteUpdate)
if err != nil {
log.Fatalf("could not initialize geolocation service: %v", err)
}
log.Infof("geolocation service has been initialized from %s", s.config.Datadir)
log.Infof("geolocation service has been initialized from %s", s.Config.Datadir)
return geo
})
@@ -60,20 +71,22 @@ func (s *BaseServer) SettingsManager() settings.Manager {
func (s *BaseServer) PeersManager() peers.Manager {
return Create(s, func() peers.Manager {
return peers.NewManager(s.Store(), s.PermissionsManager())
manager := peers.NewManager(s.Store(), s.PermissionsManager())
s.AfterInit(func(s *BaseServer) {
manager.SetNetworkMapController(s.NetworkMapController())
manager.SetIntegratedPeerValidator(s.IntegratedValidator())
manager.SetAccountManager(s.AccountManager())
})
return manager
})
}
func (s *BaseServer) AccountManager() account.Manager {
return Create(s, func() account.Manager {
accountManager, err := server.BuildManager(context.Background(), s.config, s.Store(), s.NetworkMapController(), s.IdpManager(), s.mgmtSingleAccModeDomain, s.EventStore(), s.GeoLocationManager(), s.userDeleteFromIDPEnabled, s.IntegratedValidator(), s.Metrics(), s.ProxyController(), s.SettingsManager(), s.PermissionsManager(), s.config.DisableDefaultPolicy)
accountManager, err := server.BuildManager(context.Background(), s.Config, s.Store(), s.NetworkMapController(), s.IdpManager(), s.mgmtSingleAccModeDomain, s.EventStore(), s.GeoLocationManager(), s.userDeleteFromIDPEnabled, s.IntegratedValidator(), s.Metrics(), s.ProxyController(), s.SettingsManager(), s.PermissionsManager(), s.Config.DisableDefaultPolicy)
if err != nil {
log.Fatalf("failed to create account manager: %v", err)
}
s.AfterInit(func(s *BaseServer) {
accountManager.SetEphemeralManager(s.EphemeralManager())
})
return accountManager
})
}
@@ -82,8 +95,8 @@ func (s *BaseServer) IdpManager() idp.Manager {
return Create(s, func() idp.Manager {
var idpManager idp.Manager
var err error
if s.config.IdpManagerConfig != nil {
idpManager, err = idp.NewManager(context.Background(), *s.config.IdpManagerConfig, s.Metrics())
if s.Config.IdpManagerConfig != nil {
idpManager, err = idp.NewManager(context.Background(), *s.Config.IdpManagerConfig, s.Metrics())
if err != nil {
log.Fatalf("failed to create IDP manager: %v", err)
}

View File

@@ -41,10 +41,10 @@ type Server interface {
}
// Server holds the HTTP BaseServer instance.
// Add any additional fields you need, such as database connections, config, etc.
// Add any additional fields you need, such as database connections, Config, etc.
type BaseServer struct {
// config holds the server configuration
config *nbconfig.Config
// Config holds the server configuration
Config *nbconfig.Config
// container of dependencies, each dependency is identified by a unique string.
container map[string]any
// AfterInit is a function that will be called after the server is initialized
@@ -70,7 +70,7 @@ type BaseServer struct {
// NewServer initializes and configures a new Server instance
func NewServer(config *nbconfig.Config, dnsDomain, mgmtSingleAccModeDomain string, mgmtPort, mgmtMetricsPort int, disableMetrics, disableGeoliteUpdate, userDeleteFromIDPEnabled bool) *BaseServer {
return &BaseServer{
config: config,
Config: config,
container: make(map[string]any),
dnsDomain: dnsDomain,
mgmtSingleAccModeDomain: mgmtSingleAccModeDomain,
@@ -103,14 +103,14 @@ func (s *BaseServer) Start(ctx context.Context) error {
var tlsConfig *tls.Config
tlsEnabled := false
if s.config.HttpConfig.LetsEncryptDomain != "" {
s.certManager, err = encryption.CreateCertManager(s.config.Datadir, s.config.HttpConfig.LetsEncryptDomain)
if s.Config.HttpConfig.LetsEncryptDomain != "" {
s.certManager, err = encryption.CreateCertManager(s.Config.Datadir, s.Config.HttpConfig.LetsEncryptDomain)
if err != nil {
return fmt.Errorf("failed creating LetsEncrypt cert manager: %v", err)
}
tlsEnabled = true
} else if s.config.HttpConfig.CertFile != "" && s.config.HttpConfig.CertKey != "" {
tlsConfig, err = loadTLSConfig(s.config.HttpConfig.CertFile, s.config.HttpConfig.CertKey)
} else if s.Config.HttpConfig.CertFile != "" && s.Config.HttpConfig.CertKey != "" {
tlsConfig, err = loadTLSConfig(s.Config.HttpConfig.CertFile, s.Config.HttpConfig.CertKey)
if err != nil {
log.WithContext(srvCtx).Errorf("cannot load TLS credentials: %v", err)
return err
@@ -126,8 +126,8 @@ func (s *BaseServer) Start(ctx context.Context) error {
if !s.disableMetrics {
idpManager := "disabled"
if s.config.IdpManagerConfig != nil && s.config.IdpManagerConfig.ManagerType != "" {
idpManager = s.config.IdpManagerConfig.ManagerType
if s.Config.IdpManagerConfig != nil && s.Config.IdpManagerConfig.ManagerType != "" {
idpManager = s.Config.IdpManagerConfig.ManagerType
}
metricsWorker := metrics.NewWorker(srvCtx, installationID, s.Store(), s.PeersUpdateManager(), idpManager)
go metricsWorker.Run(srvCtx)

View File

@@ -104,6 +104,20 @@ func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, set
}
}
// ToSkipSyncResponse creates a minimal SyncResponse when the client already has the latest network map.
func ToSkipSyncResponse(ctx context.Context, config *nbconfig.Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, checks []*posture.Checks, extraSettings *types.ExtraSettings, peerGroups []string) *proto.SyncResponse {
response := &proto.SyncResponse{
SkipNetworkMapUpdate: true,
Checks: toProtocolChecks(ctx, checks),
}
nbConfig := toNetbirdConfig(config, turnCredentials, relayCredentials, extraSettings)
extendedConfig := integrationsConfig.ExtendNetBirdConfig(peer.ID, peerGroups, nbConfig, extraSettings)
response.NetbirdConfig = extendedConfig
return response
}
func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfig.DeviceAuthorizationFlow, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *cache.DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings, peerGroups []string, dnsFwdPort int64) *proto.SyncResponse {
response := &proto.SyncResponse{
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig),

View File

@@ -24,7 +24,6 @@ import (
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/peers/ephemeral"
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
"github.com/netbirdio/netbird/management/server/store"
@@ -55,15 +54,12 @@ const (
type Server struct {
accountManager account.Manager
settingsManager settings.Manager
wgKey wgtypes.Key
proto.UnimplementedManagementServiceServer
peersUpdateManager network_map.PeersUpdateManager
config *nbconfig.Config
secretsManager SecretsManager
appMetrics telemetry.AppMetrics
ephemeralManager ephemeral.Manager
peerLocks sync.Map
authManager auth.Manager
config *nbconfig.Config
secretsManager SecretsManager
appMetrics telemetry.AppMetrics
peerLocks sync.Map
authManager auth.Manager
logBlockedPeers bool
blockPeersWithSameConfig bool
@@ -82,23 +78,16 @@ func NewServer(
config *nbconfig.Config,
accountManager account.Manager,
settingsManager settings.Manager,
peersUpdateManager network_map.PeersUpdateManager,
secretsManager SecretsManager,
appMetrics telemetry.AppMetrics,
ephemeralManager ephemeral.Manager,
authManager auth.Manager,
integratedPeerValidator integrated_validator.IntegratedValidator,
networkMapController network_map.Controller,
) (*Server, error) {
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
return nil, err
}
if appMetrics != nil {
// update gauge based on number of connected peers which is equal to open gRPC streams
err = appMetrics.GRPCMetrics().RegisterConnectedStreams(func() int64 {
return int64(peersUpdateManager.CountStreams())
err := appMetrics.GRPCMetrics().RegisterConnectedStreams(func() int64 {
return int64(networkMapController.CountStreams())
})
if err != nil {
return nil, err
@@ -120,16 +109,12 @@ func NewServer(
}
return &Server{
wgKey: key,
// peerKey -> event channel
peersUpdateManager: peersUpdateManager,
accountManager: accountManager,
settingsManager: settingsManager,
config: config,
secretsManager: secretsManager,
authManager: authManager,
appMetrics: appMetrics,
ephemeralManager: ephemeralManager,
logBlockedPeers: logBlockedPeers,
blockPeersWithSameConfig: blockPeersWithSameConfig,
integratedPeerValidator: integratedPeerValidator,
@@ -149,10 +134,6 @@ func (s *Server) GetServerKey(ctx context.Context, req *proto.Empty) (*proto.Ser
}
log.WithContext(ctx).Tracef("GetServerKey request from %s", ip)
start := time.Now()
defer func() {
log.WithContext(ctx).Tracef("GetServerKey from %s took %v", ip, time.Since(start))
}()
// todo introduce something more meaningful with the key expiration/rotation
if s.appMetrics != nil {
@@ -163,8 +144,14 @@ func (s *Server) GetServerKey(ctx context.Context, req *proto.Empty) (*proto.Ser
nanos := int32(now.Nanosecond())
expiresAt := &timestamp.Timestamp{Seconds: secs, Nanos: nanos}
key, err := s.secretsManager.GetWGKey()
if err != nil {
log.WithContext(ctx).Errorf("failed to get wireguard key: %v", err)
return nil, errors.New("failed to get wireguard key")
}
return &proto.ServerKeyResponse{
Key: s.wgKey.PublicKey().String(),
Key: key.PublicKey().String(),
ExpiresAt: expiresAt,
}, nil
}
@@ -203,7 +190,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
s.appMetrics.GRPCMetrics().CountSyncRequestBlocked()
}
if s.logBlockedPeers {
log.WithContext(ctx).Warnf("peer %s with meta hash %d is blocked from syncing", peerKey.String(), metahashed)
log.WithContext(ctx).Tracef("peer %s with meta hash %d is blocked from syncing", peerKey.String(), metahashed)
}
if s.blockPeersWithSameConfig {
s.syncSem.Add(-1)
@@ -231,8 +218,6 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
return err
}
log.WithContext(ctx).Debugf("Sync: GetAccountIDForPeerKey since start %v", time.Since(reqStart))
// nolint:staticcheck
ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
@@ -244,7 +229,6 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
}
}()
log.WithContext(ctx).Tracef("acquired peer lock for peer %s took %v", peerKey.String(), time.Since(start))
log.WithContext(ctx).Debugf("Sync: acquirePeerLockByUID since start %v", time.Since(reqStart))
log.WithContext(ctx).Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, sRealIP)
@@ -255,7 +239,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
metahash := metaHash(peerMeta, realIP.String())
s.loginFilter.addLogin(peerKey.String(), metahash)
peer, netMap, postureChecks, dnsFwdPort, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP)
peer, netMap, postureChecks, dnsFwdPort, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP, syncReq.GetNetworkMapSerial())
if err != nil {
log.WithContext(ctx).Debugf("error while syncing peer %s: %v", peerKey.String(), err)
s.syncSem.Add(-1)
@@ -269,9 +253,13 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
return err
}
updates := s.peersUpdateManager.CreateChannel(ctx, peer.ID)
s.ephemeralManager.OnPeerConnected(ctx, peer)
updates, err := s.networkMapController.OnPeerConnected(ctx, accountID, peer.ID)
if err != nil {
log.WithContext(ctx).Debugf("error while notify peer connected for %s: %v", peerKey.String(), err)
s.syncSem.Add(-1)
s.cancelPeerRoutines(ctx, accountID, peer)
return err
}
s.secretsManager.SetupRefresh(ctx, accountID, peer.ID)
@@ -323,13 +311,19 @@ func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wg
// sendUpdate encrypts the update message using the peer key and the server's wireguard key,
// then sends the encrypted message to the connected peer via the sync server.
func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, update *network_map.UpdateMessage, srv proto.ManagementService_SyncServer) error {
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, update.Update)
key, err := s.secretsManager.GetWGKey()
if err != nil {
s.cancelPeerRoutines(ctx, accountID, peer)
return status.Errorf(codes.Internal, "failed processing update message")
}
encryptedResp, err := encryption.EncryptMessage(peerKey, key, update.Update)
if err != nil {
s.cancelPeerRoutines(ctx, accountID, peer)
return status.Errorf(codes.Internal, "failed processing update message")
}
err = srv.SendMsg(&proto.EncryptedMessage{
WgPubKey: s.wgKey.PublicKey().String(),
WgPubKey: key.PublicKey().String(),
Body: encryptedResp,
})
if err != nil {
@@ -348,11 +342,10 @@ func (s *Server) cancelPeerRoutines(ctx context.Context, accountID string, peer
if err != nil {
log.WithContext(ctx).Errorf("failed to disconnect peer %s properly: %v", peer.Key, err)
}
s.peersUpdateManager.CloseChannel(ctx, peer.ID)
s.networkMapController.OnPeerDisconnected(ctx, accountID, peer.ID)
s.secretsManager.CancelRefresh(peer.ID)
s.ephemeralManager.OnPeerDisconnected(ctx, peer)
log.WithContext(ctx).Tracef("peer %s has been disconnected", peer.Key)
log.WithContext(ctx).Debugf("peer %s has been disconnected", peer.Key)
}
func (s *Server) validateToken(ctx context.Context, jwtToken string) (string, error) {
@@ -504,7 +497,12 @@ func (s *Server) parseRequest(ctx context.Context, req *proto.EncryptedMessage,
return wgtypes.Key{}, status.Errorf(codes.InvalidArgument, "provided wgPubKey %s is invalid", req.WgPubKey)
}
err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, parsed)
key, err := s.secretsManager.GetWGKey()
if err != nil {
return wgtypes.Key{}, status.Errorf(codes.Internal, "failed processing request")
}
err = encryption.DecryptMessage(peerKey, key, req.Body, parsed)
if err != nil {
return wgtypes.Key{}, status.Errorf(codes.InvalidArgument, "invalid request message")
}
@@ -520,7 +518,6 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
reqStart := time.Now()
realIP := getRealIP(ctx)
sRealIP := realIP.String()
log.WithContext(ctx).Debugf("Login request from peer [%s] [%s]", req.WgPubKey, sRealIP)
loginReq := &proto.LoginRequest{}
peerKey, err := s.parseRequest(ctx, req, loginReq)
@@ -532,7 +529,7 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
metahashed := metaHash(peerMeta, sRealIP)
if !s.loginFilter.allowLogin(peerKey.String(), metahashed) {
if s.logBlockedPeers {
log.WithContext(ctx).Warnf("peer %s with meta hash %d is blocked from login", peerKey.String(), metahashed)
log.WithContext(ctx).Tracef("peer %s with meta hash %d is blocked from login", peerKey.String(), metahashed)
}
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountLoginRequestBlocked()
@@ -556,16 +553,12 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
//nolint
ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
log.WithContext(ctx).Debugf("Login: GetAccountIDForPeerKey since start %v", time.Since(reqStart))
log.WithContext(ctx).Debugf("Login request from peer [%s] [%s]", req.WgPubKey, sRealIP)
defer func() {
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountLoginRequestDuration(time.Since(reqStart), accountID)
}
took := time.Since(reqStart)
if took > 7*time.Second {
log.WithContext(ctx).Debugf("Login: took %v", time.Since(reqStart))
}
}()
if loginReq.GetMeta() == nil {
@@ -599,30 +592,26 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
return nil, mapError(ctx, err)
}
log.WithContext(ctx).Debugf("Login: LoginPeer since start %v", time.Since(reqStart))
// if the login request contains setup key then it is a registration request
if loginReq.GetSetupKey() != "" {
s.ephemeralManager.OnPeerDisconnected(ctx, peer)
log.WithContext(ctx).Debugf("Login: OnPeerDisconnected since start %v", time.Since(reqStart))
}
loginResp, err := s.prepareLoginResponse(ctx, peer, netMap, postureChecks)
if err != nil {
log.WithContext(ctx).Warnf("failed preparing login response for peer %s: %s", peerKey, err)
return nil, status.Errorf(codes.Internal, "failed logging in peer")
}
log.WithContext(ctx).Debugf("Login: prepareLoginResponse since start %v", time.Since(reqStart))
key, err := s.secretsManager.GetWGKey()
if err != nil {
log.WithContext(ctx).Warnf("failed getting server's WireGuard private key: %s", err)
return nil, status.Errorf(codes.Internal, "failed logging in peer")
}
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, loginResp)
encryptedResp, err := encryption.EncryptMessage(peerKey, key, loginResp)
if err != nil {
log.WithContext(ctx).Warnf("failed encrypting peer %s message", peer.ID)
return nil, status.Errorf(codes.Internal, "failed logging in peer")
}
return &proto.EncryptedMessage{
WgPubKey: s.wgKey.PublicKey().String(),
WgPubKey: key.PublicKey().String(),
Body: encryptedResp,
}, nil
}
@@ -713,19 +702,27 @@ func (s *Server) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer
return status.Errorf(codes.Internal, "failed to get peer groups %s", err)
}
plainResp := ToSyncResponse(ctx, s.config, s.config.HttpConfig, s.config.DeviceAuthorizationFlow, peer, turnToken, relayToken, networkMap, s.networkMapController.GetDNSDomain(settings), postureChecks, nil, settings, settings.Extra, peerGroups, dnsFwdPort)
var plainResp *proto.SyncResponse
if networkMap == nil {
plainResp = ToSkipSyncResponse(ctx, s.config, peer, turnToken, relayToken, postureChecks, settings.Extra, peerGroups)
} else {
plainResp = ToSyncResponse(ctx, s.config, s.config.HttpConfig, s.config.DeviceAuthorizationFlow, peer, turnToken, relayToken, networkMap, s.networkMapController.GetDNSDomain(settings), postureChecks, nil, settings, settings.Extra, peerGroups, dnsFwdPort)
}
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp)
key, err := s.secretsManager.GetWGKey()
if err != nil {
return status.Errorf(codes.Internal, "failed getting server key")
}
encryptedResp, err := encryption.EncryptMessage(peerKey, key, plainResp)
if err != nil {
return status.Errorf(codes.Internal, "error handling request")
}
sendStart := time.Now()
err = srv.Send(&proto.EncryptedMessage{
WgPubKey: s.wgKey.PublicKey().String(),
WgPubKey: key.PublicKey().String(),
Body: encryptedResp,
})
log.WithContext(ctx).Debugf("sendInitialSync: sending response took %s", time.Since(sendStart))
if err != nil {
log.WithContext(ctx).Errorf("failed sending SyncResponse %v", err)
@@ -740,10 +737,6 @@ func (s *Server) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer
// which will be used by our clients to Login
func (s *Server) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
log.WithContext(ctx).Tracef("GetDeviceAuthorizationFlow request for pubKey: %s", req.WgPubKey)
start := time.Now()
defer func() {
log.WithContext(ctx).Tracef("GetDeviceAuthorizationFlow for pubKey: %s took %v", req.WgPubKey, time.Since(start))
}()
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
if err != nil {
@@ -752,7 +745,12 @@ func (s *Server) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.Encr
return nil, status.Error(codes.InvalidArgument, errMSG)
}
err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, &proto.DeviceAuthorizationFlowRequest{})
key, err := s.secretsManager.GetWGKey()
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get server key")
}
err = encryption.DecryptMessage(peerKey, key, req.Body, &proto.DeviceAuthorizationFlowRequest{})
if err != nil {
errMSG := fmt.Sprintf("error while decrypting peer's message with Wireguard public key %s.", req.WgPubKey)
log.WithContext(ctx).Warn(errMSG)
@@ -782,13 +780,13 @@ func (s *Server) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.Encr
},
}
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, flowInfoResp)
encryptedResp, err := encryption.EncryptMessage(peerKey, key, flowInfoResp)
if err != nil {
return nil, status.Error(codes.Internal, "failed to encrypt no device authorization flow information")
}
return &proto.EncryptedMessage{
WgPubKey: s.wgKey.PublicKey().String(),
WgPubKey: key.PublicKey().String(),
Body: encryptedResp,
}, nil
}
@@ -798,10 +796,6 @@ func (s *Server) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.Encr
// which will be used by our clients to Login
func (s *Server) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
log.WithContext(ctx).Tracef("GetPKCEAuthorizationFlow request for pubKey: %s", req.WgPubKey)
start := time.Now()
defer func() {
log.WithContext(ctx).Tracef("GetPKCEAuthorizationFlow for pubKey %s took %v", req.WgPubKey, time.Since(start))
}()
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
if err != nil {
@@ -810,7 +804,12 @@ func (s *Server) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.Encryp
return nil, status.Error(codes.InvalidArgument, errMSG)
}
err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, &proto.PKCEAuthorizationFlowRequest{})
key, err := s.secretsManager.GetWGKey()
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get server key")
}
err = encryption.DecryptMessage(peerKey, key, req.Body, &proto.PKCEAuthorizationFlowRequest{})
if err != nil {
errMSG := fmt.Sprintf("error while decrypting peer's message with Wireguard public key %s.", req.WgPubKey)
log.WithContext(ctx).Warn(errMSG)
@@ -838,13 +837,13 @@ func (s *Server) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.Encryp
flowInfoResp := s.integratedPeerValidator.ValidateFlowResponse(ctx, peerKey.String(), initInfoFlow)
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, flowInfoResp)
encryptedResp, err := encryption.EncryptMessage(peerKey, key, flowInfoResp)
if err != nil {
return nil, status.Error(codes.Internal, "failed to encrypt no pkce authorization flow information")
}
return &proto.EncryptedMessage{
WgPubKey: s.wgKey.PublicKey().String(),
WgPubKey: key.PublicKey().String(),
Body: encryptedResp,
}, nil
}

View File

@@ -73,15 +73,17 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
mgmtServer := &Server{
wgKey: testingServerKey,
secretsManager: &TimeBasedAuthSecretsManager{wgKey: testingServerKey},
config: &config.Config{
DeviceAuthorizationFlow: testCase.inputFlow,
},
}
message := &mgmtProto.DeviceAuthorizationFlowRequest{}
key, err := mgmtServer.secretsManager.GetWGKey()
require.NoError(t, err, "should be able to get server key")
encryptedMSG, err := encryption.EncryptMessage(testingClientKey.PublicKey(), mgmtServer.wgKey, message)
encryptedMSG, err := encryption.EncryptMessage(testingClientKey.PublicKey(), key, message)
require.NoError(t, err, "should be able to encrypt message")
resp, err := mgmtServer.GetDeviceAuthorizationFlow(
@@ -95,7 +97,7 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
if testCase.expectedComparisonFunc != nil {
flowInfoResp := &mgmtProto.DeviceAuthorizationFlow{}
err = encryption.DecryptMessage(mgmtServer.wgKey.PublicKey(), testingClientKey, resp.Body, flowInfoResp)
err = encryption.DecryptMessage(key.PublicKey(), testingClientKey, resp.Body, flowInfoResp)
require.NoError(t, err, "should be able to decrypt")
testCase.expectedComparisonFunc(t, testCase.expectedFlow.Provider, flowInfoResp.Provider, testCase.expectedComparisonMSG)

View File

@@ -10,6 +10,7 @@ import (
"time"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
@@ -29,6 +30,7 @@ type SecretsManager interface {
GenerateRelayToken() (*Token, error)
SetupRefresh(ctx context.Context, accountID, peerKey string)
CancelRefresh(peerKey string)
GetWGKey() (wgtypes.Key, error)
}
// TimeBasedAuthSecretsManager generates credentials with TTL and using pre-shared secret known to TURN server
@@ -43,11 +45,17 @@ type TimeBasedAuthSecretsManager struct {
groupsManager groups.Manager
turnCancelMap map[string]chan struct{}
relayCancelMap map[string]chan struct{}
wgKey wgtypes.Key
}
type Token auth.Token
func NewTimeBasedAuthSecretsManager(updateManager network_map.PeersUpdateManager, turnCfg *nbconfig.TURNConfig, relayCfg *nbconfig.Relay, settingsManager settings.Manager, groupsManager groups.Manager) *TimeBasedAuthSecretsManager {
func NewTimeBasedAuthSecretsManager(updateManager network_map.PeersUpdateManager, turnCfg *nbconfig.TURNConfig, relayCfg *nbconfig.Relay, settingsManager settings.Manager, groupsManager groups.Manager) (*TimeBasedAuthSecretsManager, error) {
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
return nil, err
}
mgr := &TimeBasedAuthSecretsManager{
updateManager: updateManager,
turnCfg: turnCfg,
@@ -56,6 +64,7 @@ func NewTimeBasedAuthSecretsManager(updateManager network_map.PeersUpdateManager
relayCancelMap: make(map[string]chan struct{}),
settingsManager: settingsManager,
groupsManager: groupsManager,
wgKey: key,
}
if turnCfg != nil {
@@ -81,7 +90,12 @@ func NewTimeBasedAuthSecretsManager(updateManager network_map.PeersUpdateManager
}
}
return mgr
return mgr, nil
}
// GetWGKey returns WireGuard private key used to generate peer keys
func (m *TimeBasedAuthSecretsManager) GetWGKey() (wgtypes.Key, error) {
return m.wgKey, nil
}
// GenerateTurnToken generates new time-based secret credentials for TURN
@@ -153,7 +167,7 @@ func (m *TimeBasedAuthSecretsManager) SetupRefresh(ctx context.Context, accountI
relayCancel := make(chan struct{}, 1)
m.relayCancelMap[peerID] = relayCancel
go m.refreshRelayTokens(ctx, accountID, peerID, relayCancel)
log.WithContext(ctx).Debugf("starting relay refresh for %s", peerID)
log.WithContext(ctx).Tracef("starting relay refresh for %s", peerID)
}
}
@@ -164,7 +178,7 @@ func (m *TimeBasedAuthSecretsManager) refreshTURNTokens(ctx context.Context, acc
for {
select {
case <-cancel:
log.WithContext(ctx).Debugf("stopping TURN refresh for %s", peerID)
log.WithContext(ctx).Tracef("stopping TURN refresh for %s", peerID)
return
case <-ticker.C:
m.pushNewTURNAndRelayTokens(ctx, accountID, peerID)
@@ -179,7 +193,7 @@ func (m *TimeBasedAuthSecretsManager) refreshRelayTokens(ctx context.Context, ac
for {
select {
case <-cancel:
log.WithContext(ctx).Debugf("stopping relay refresh for %s", peerID)
log.WithContext(ctx).Tracef("stopping relay refresh for %s", peerID)
return
case <-ticker.C:
m.pushNewRelayTokens(ctx, accountID, peerID)

View File

@@ -46,12 +46,13 @@ func TestTimeBasedAuthSecretsManager_GenerateCredentials(t *testing.T) {
settingsMockManager := settings.NewMockManager(ctrl)
groupsManager := groups.NewManagerMock()
tested := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{
tested, err := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{
CredentialsTTL: ttl,
Secret: secret,
Turns: []*config.Host{TurnTestHost},
TimeBasedCredentials: true,
}, rc, settingsMockManager, groupsManager)
require.NoError(t, err)
turnCredentials, err := tested.GenerateTurnToken()
require.NoError(t, err)
@@ -98,12 +99,13 @@ func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) {
settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), "someAccountID").Return(&types.ExtraSettings{}, nil).AnyTimes()
groupsManager := groups.NewManagerMock()
tested := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{
tested, err := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{
CredentialsTTL: ttl,
Secret: secret,
Turns: []*config.Host{TurnTestHost},
TimeBasedCredentials: true,
}, rc, settingsMockManager, groupsManager)
require.NoError(t, err)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
@@ -201,12 +203,13 @@ func TestTimeBasedAuthSecretsManager_CancelRefresh(t *testing.T) {
settingsMockManager := settings.NewMockManager(ctrl)
groupsManager := groups.NewManagerMock()
tested := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{
tested, err := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{
CredentialsTTL: ttl,
Secret: secret,
Turns: []*config.Host{TurnTestHost},
TimeBasedCredentials: true,
}, rc, settingsMockManager, groupsManager)
require.NoError(t, err)
tested.SetupRefresh(context.Background(), "someAccountID", peer)
if _, ok := tested.turnCancelMap[peer]; !ok {

View File

@@ -37,7 +37,6 @@ import (
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/peers/ephemeral"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
@@ -77,7 +76,6 @@ type DefaultAccountManager struct {
ctx context.Context
eventStore activity.Store
geo geolocation.Geolocation
ephemeralManager ephemeral.Manager
requestBuffer *AccountRequestBuffer
@@ -238,7 +236,7 @@ func BuildManager(
log.WithContext(ctx).Infof("single account mode disabled, accounts number %d", accountsCounter)
}
cacheStore, err := nbcache.NewStore(ctx, nbcache.DefaultIDPCacheExpirationMax, nbcache.DefaultIDPCacheCleanupInterval)
cacheStore, err := nbcache.NewStore(ctx, nbcache.DefaultIDPCacheExpirationMax, nbcache.DefaultIDPCacheCleanupInterval, nbcache.DefaultIDPCacheOpenConn)
if err != nil {
return nil, fmt.Errorf("getting cache store: %s", err)
}
@@ -263,10 +261,6 @@ func BuildManager(
return am, nil
}
func (am *DefaultAccountManager) SetEphemeralManager(em ephemeral.Manager) {
am.ephemeralManager = em
}
func (am *DefaultAccountManager) GetExternalCacheManager() account.ExternalCacheManager {
return am.externalCacheManager
}
@@ -301,10 +295,23 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
return err
}
if err = am.validateSettingsUpdate(ctx, transaction, newSettings, oldSettings, userID, accountID); err != nil {
if err = am.validateSettingsUpdate(ctx, newSettings, oldSettings, userID, accountID); err != nil {
return err
}
if oldSettings.Extra != nil && newSettings.Extra != nil &&
oldSettings.Extra.PeerApprovalEnabled && !newSettings.Extra.PeerApprovalEnabled {
approvedCount, err := transaction.ApproveAccountPeers(ctx, accountID)
if err != nil {
return fmt.Errorf("failed to approve pending peers: %w", err)
}
if approvedCount > 0 {
log.WithContext(ctx).Debugf("approved %d pending peers in account %s", approvedCount, accountID)
updateAccountPeers = true
}
}
if oldSettings.NetworkRange != newSettings.NetworkRange {
if err = am.reallocateAccountPeerIPs(ctx, transaction, accountID, newSettings.NetworkRange); err != nil {
return err
@@ -378,7 +385,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
return newSettings, nil
}
func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, transaction store.Store, newSettings, oldSettings *types.Settings, userID, accountID string) error {
func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, newSettings, oldSettings *types.Settings, userID, accountID string) error {
halfYearLimit := 180 * 24 * time.Hour
if newSettings.PeerLoginExpiration > halfYearLimit {
return status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days")
@@ -392,17 +399,7 @@ func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, tra
return status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for DNS domain", newSettings.DNSDomain)
}
peers, err := transaction.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
if err != nil {
return err
}
peersMap := make(map[string]*nbpeer.Peer, len(peers))
for _, peer := range peers {
peersMap[peer.ID] = peer
}
return am.integratedPeerValidator.ValidateExtraSettings(ctx, newSettings.Extra, oldSettings.Extra, peersMap, userID, accountID)
return am.integratedPeerValidator.ValidateExtraSettings(ctx, newSettings.Extra, oldSettings.Extra, userID, accountID)
}
func (am *DefaultAccountManager) handleRoutingPeerDNSResolutionSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) {
@@ -793,6 +790,13 @@ func (am *DefaultAccountManager) loadAccount(ctx context.Context, accountID any)
log.WithContext(ctx).Debugf("account %s not found in cache, reloading", accountID)
accountIDString := fmt.Sprintf("%v", accountID)
if ctx == nil {
ctx = context.Background()
}
// nolint:staticcheck
ctx = context.WithValue(ctx, nbcontext.AccountIDKey, accountID)
accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountIDString)
if err != nil {
return nil, nil, err
@@ -1613,8 +1617,8 @@ func domainIsUpToDate(domain string, domainCategory string, userAuth auth.UserAu
return domainCategory == types.PrivateCategory || userAuth.DomainCategory != types.PrivateCategory || domain != userAuth.Domain
}
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
peer, netMap, postureChecks, dnsfwdPort, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, accountID)
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, clientSerial uint64) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
peer, netMap, postureChecks, dnsfwdPort, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, NetworkMapSerial: clientSerial}, accountID)
if err != nil {
return nil, nil, nil, 0, fmt.Errorf("error syncing peer: %w", err)
}
@@ -2076,7 +2080,10 @@ func (am *DefaultAccountManager) UpdatePeerIP(ctx context.Context, accountID, us
if err != nil {
return err
}
am.networkMapController.OnPeerUpdated(peer.AccountID, peer)
err = am.networkMapController.OnPeersUpdated(ctx, peer.AccountID, []string{peerID})
if err != nil {
return fmt.Errorf("notify network map controller of peer update: %w", err)
}
}
return nil
}

View File

@@ -13,7 +13,6 @@ import (
nbcache "github.com/netbirdio/netbird/management/server/cache"
"github.com/netbirdio/netbird/management/server/idp"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/peers/ephemeral"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
@@ -108,7 +107,7 @@ type Manager interface {
UpdateIntegratedValidator(ctx context.Context, accountID, userID, validator string, groups []string) error
GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error)
GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error)
SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, clientSerial uint64) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error
SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
@@ -124,5 +123,4 @@ type Manager interface {
UpdateToPrimaryAccount(ctx context.Context, accountId string) error
GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error)
GetCurrentUserInfo(ctx context.Context, userAuth auth.UserAuth) (*users.UserInfoWithPermissions, error)
SetEphemeralManager(em ephemeral.Manager)
}

View File

@@ -25,6 +25,8 @@ import (
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers"
ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/internals/server/config"
nbAccount "github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
@@ -2056,6 +2058,43 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
require.Error(t, err, "expecting to fail when providing PeerLoginExpiration more than 180 days")
}
func TestDefaultAccountManager_UpdateAccountSettings_PeerApproval(t *testing.T) {
manager, _, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
accountID := account.Id
userID := account.Users[account.CreatedBy].Id
ctx := context.Background()
newSettings := account.Settings.Copy()
newSettings.Extra = &types.ExtraSettings{
PeerApprovalEnabled: true,
}
_, err := manager.UpdateAccountSettings(ctx, accountID, userID, newSettings)
require.NoError(t, err)
peer1.Status.RequiresApproval = true
peer2.Status.RequiresApproval = true
peer3.Status.RequiresApproval = false
require.NoError(t, manager.Store.SavePeer(ctx, accountID, peer1))
require.NoError(t, manager.Store.SavePeer(ctx, accountID, peer2))
require.NoError(t, manager.Store.SavePeer(ctx, accountID, peer3))
newSettings = account.Settings.Copy()
newSettings.Extra = &types.ExtraSettings{
PeerApprovalEnabled: false,
}
_, err = manager.UpdateAccountSettings(ctx, accountID, userID, newSettings)
require.NoError(t, err)
accountPeers, err := manager.Store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
require.NoError(t, err)
for _, peer := range accountPeers {
assert.False(t, peer.Status.RequiresApproval, "peer %s should not require approval after disabling peer approval", peer.ID)
}
}
func TestAccount_GetExpiredPeers(t *testing.T) {
type test struct {
name string
@@ -2959,8 +2998,8 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), &config.Config{})
manager, err := BuildManager(ctx, nil, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{})
manager, err := BuildManager(ctx, &config.Config{}, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
if err != nil {
return nil, nil, err
}
@@ -3105,7 +3144,7 @@ func BenchmarkSyncAndMarkPeer(b *testing.B) {
b.ResetTimer()
start := time.Now()
for i := 0; i < b.N; i++ {
_, _, _, _, err := manager.SyncAndMarkPeer(context.Background(), account.Id, account.Peers["peer-1"].Key, nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)}, net.IP{1, 1, 1, 1})
_, _, _, _, err := manager.SyncAndMarkPeer(context.Background(), account.Id, account.Peers["peer-1"].Key, nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)}, net.IP{1, 1, 1, 1}, 0)
assert.NoError(b, err)
}
@@ -3371,7 +3410,7 @@ func TestDefaultAccountManager_IsCacheCold(t *testing.T) {
t.Run("memory cache", func(t *testing.T) {
t.Run("should always return true", func(t *testing.T) {
cacheStore, err := cache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond)
cacheStore, err := cache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)
require.NoError(t, err)
cold, err := manager.isCacheCold(context.Background(), cacheStore)
@@ -3386,7 +3425,7 @@ func TestDefaultAccountManager_IsCacheCold(t *testing.T) {
t.Cleanup(cleanup)
t.Setenv(cache.RedisStoreEnvVar, redisURL)
cacheStore, err := cache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond)
cacheStore, err := cache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)
require.NoError(t, err)
t.Run("should return true when no account exists", func(t *testing.T) {

View File

@@ -179,6 +179,7 @@ const (
PeerIPUpdated Activity = 88
UserApproved Activity = 89
UserRejected Activity = 90
UserCreated Activity = 91
AccountDeleted Activity = 99999
)
@@ -288,6 +289,7 @@ var activityMap = map[Activity]Code{
PeerIPUpdated: {"Peer IP updated", "peer.ip.update"},
UserApproved: {"User approved", "user.approve"},
UserRejected: {"User rejected", "user.reject"},
UserCreated: {"User created", "user.create"},
}
// StringCode returns a string code of the activity

View File

@@ -18,6 +18,7 @@ const (
DefaultIDPCacheExpirationMax = 7 * 24 * time.Hour // 7 days
DefaultIDPCacheExpirationMin = 3 * 24 * time.Hour // 3 days
DefaultIDPCacheCleanupInterval = 30 * time.Minute
DefaultIDPCacheOpenConn = 100
)
// UserDataCache is an interface that wraps the basic Get, Set and Delete methods for idp.UserData objects.

View File

@@ -33,7 +33,7 @@ func TestNewIDPCacheManagers(t *testing.T) {
t.Cleanup(cleanup)
t.Setenv(cache.RedisStoreEnvVar, redisURL)
}
cacheStore, err := cache.NewStore(context.Background(), cache.DefaultIDPCacheExpirationMax, cache.DefaultIDPCacheCleanupInterval)
cacheStore, err := cache.NewStore(context.Background(), cache.DefaultIDPCacheExpirationMax, cache.DefaultIDPCacheCleanupInterval, cache.DefaultIDPCacheOpenConn)
if err != nil {
t.Fatalf("couldn't create cache store: %s", err)
}

View File

@@ -3,6 +3,7 @@ package cache
import (
"context"
"fmt"
"math"
"os"
"time"
@@ -20,24 +21,27 @@ const RedisStoreEnvVar = "NB_IDP_CACHE_REDIS_ADDRESS"
// NewStore creates a new cache store with the given max timeout and cleanup interval. It checks for the environment Variable RedisStoreEnvVar
// to determine if a redis store should be used. If the environment variable is set, it will attempt to connect to the redis store.
func NewStore(ctx context.Context, maxTimeout, cleanupInterval time.Duration) (store.StoreInterface, error) {
func NewStore(ctx context.Context, maxTimeout, cleanupInterval time.Duration, maxConn int) (store.StoreInterface, error) {
redisAddr := os.Getenv(RedisStoreEnvVar)
if redisAddr != "" {
return getRedisStore(ctx, redisAddr)
return getRedisStore(ctx, redisAddr, maxConn)
}
goc := gocache.New(maxTimeout, cleanupInterval)
return gocache_store.NewGoCache(goc), nil
}
func getRedisStore(ctx context.Context, redisEnvAddr string) (store.StoreInterface, error) {
func getRedisStore(ctx context.Context, redisEnvAddr string, maxConn int) (store.StoreInterface, error) {
options, err := redis.ParseURL(redisEnvAddr)
if err != nil {
return nil, fmt.Errorf("parsing redis cache url: %s", err)
}
options.MaxIdleConns = 6
options.MinIdleConns = 3
options.MaxActiveConns = 100
options.MaxIdleConns = int(math.Ceil(float64(maxConn) * 0.5)) // 50% of max conns
options.MinIdleConns = int(math.Ceil(float64(maxConn) * 0.1)) // 10% of max conns
options.MaxActiveConns = maxConn
options.ConnMaxIdleTime = 30 * time.Minute
options.ConnMaxLifetime = 0
options.PoolTimeout = 10 * time.Second
redisClient := redis.NewClient(options)
subCtx, cancel := context.WithTimeout(ctx, 2*time.Second)
defer cancel()

View File

@@ -15,7 +15,7 @@ import (
)
func TestMemoryStore(t *testing.T) {
memStore, err := cache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond)
memStore, err := cache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)
if err != nil {
t.Fatalf("couldn't create memory store: %s", err)
}
@@ -42,7 +42,7 @@ func TestMemoryStore(t *testing.T) {
func TestRedisStoreConnectionFailure(t *testing.T) {
t.Setenv(cache.RedisStoreEnvVar, "redis://127.0.0.1:6379")
_, err := cache.NewStore(context.Background(), 10*time.Millisecond, 30*time.Millisecond)
_, err := cache.NewStore(context.Background(), 10*time.Millisecond, 30*time.Millisecond, 100)
if err == nil {
t.Fatal("getting redis cache store should return error")
}
@@ -65,7 +65,7 @@ func TestRedisStoreConnectionSuccess(t *testing.T) {
}
t.Setenv(cache.RedisStoreEnvVar, redisURL)
redisStore, err := cache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond)
redisStore, err := cache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)
if err != nil {
t.Fatalf("couldn't create redis store: %s", err)
}

View File

@@ -12,6 +12,8 @@ import (
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers"
ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions"
@@ -223,7 +225,7 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.test", port_forwarding.NewControllerMock(), &config.Config{})
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.test", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{})
return BuildManager(context.Background(), nil, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
}

View File

@@ -21,6 +21,7 @@ import (
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions"
nbpeers "github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/server/auth"
"github.com/netbirdio/netbird/management/server/geolocation"
nbgroups "github.com/netbirdio/netbird/management/server/groups"
@@ -39,7 +40,6 @@ import (
nbnetworks "github.com/netbirdio/netbird/management/server/networks"
"github.com/netbirdio/netbird/management/server/networks/resources"
"github.com/netbirdio/netbird/management/server/networks/routers"
nbpeers "github.com/netbirdio/netbird/management/server/peers"
"github.com/netbirdio/netbird/management/server/telemetry"
)
@@ -105,6 +105,7 @@ func NewAPIHandler(
accountManager.SyncUserJWTGroups,
accountManager.GetUserFromUserAuth,
rateLimitingConfig,
appMetrics.GetMeter(),
)
corsMiddleware := cors.AllowAll()

View File

@@ -48,6 +48,29 @@ func (h *handler) getAllGroups(w http.ResponseWriter, r *http.Request) {
}
accountID, userID := userAuth.AccountId, userAuth.UserId
// Check if filtering by name
groupName := r.URL.Query().Get("name")
if groupName != "" {
// Get single group by name
group, err := h.accountManager.GetGroupByName(r.Context(), groupName, accountID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID, "", "")
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
// Return as array with single element to maintain API consistency
groupsResponse := []*api.Group{toGroupResponse(accountPeers, group)}
util.WriteJSONObject(r.Context(), w, groupsResponse)
return
}
// Get all groups
groups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)

View File

@@ -60,12 +60,23 @@ func initGroupTestData(initGroups ...*types.Group) *handler {
return group, nil
},
GetAllGroupsFunc: func(ctx context.Context, accountID, userID string) ([]*types.Group, error) {
groups := []*types.Group{
{ID: "id-jwt-group", Name: "From JWT", Issued: types.GroupIssuedJWT},
{ID: "id-existed", Name: "Existed", Peers: []string{"A", "B"}, Issued: types.GroupIssuedAPI},
{ID: "id-all", Name: "All", Issued: types.GroupIssuedAPI},
}
groups = append(groups, initGroups...)
return groups, nil
},
GetGroupByNameFunc: func(ctx context.Context, groupName, _ string) (*types.Group, error) {
if groupName == "All" {
return &types.Group{ID: "id-all", Name: "All", Issued: types.GroupIssuedAPI}, nil
}
return nil, fmt.Errorf("unknown group name")
return nil, status.Errorf(status.NotFound, "unknown group name")
},
GetPeersFunc: func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) {
return maps.Values(TestPeers), nil
@@ -287,6 +298,84 @@ func TestWriteGroup(t *testing.T) {
}
}
func TestGetAllGroups(t *testing.T) {
tt := []struct {
name string
expectedStatus int
expectedBody bool
requestType string
requestPath string
expectedCount int
}{
{
name: "Get All Groups",
expectedBody: true,
requestType: http.MethodGet,
requestPath: "/api/groups",
expectedStatus: http.StatusOK,
expectedCount: 3, // id-jwt-group, id-existed, id-all
},
{
name: "Get Group By Name - Existing",
expectedBody: true,
requestType: http.MethodGet,
requestPath: "/api/groups?name=All",
expectedStatus: http.StatusOK,
expectedCount: 1,
},
{
name: "Get Group By Name - Not Found",
expectedBody: false,
requestType: http.MethodGet,
requestPath: "/api/groups?name=NonExistent",
expectedStatus: http.StatusNotFound,
},
}
p := initGroupTestData()
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: "test_user",
Domain: "hotmail.com",
AccountId: "test_id",
})
router := mux.NewRouter()
router.HandleFunc("/api/groups", p.getAllGroups).Methods("GET")
router.ServeHTTP(recorder, req)
res := recorder.Result()
defer res.Body.Close()
if status := recorder.Code; status != tc.expectedStatus {
t.Errorf("handler returned wrong status code: got %v want %v",
status, tc.expectedStatus)
return
}
if !tc.expectedBody {
return
}
content, err := io.ReadAll(res.Body)
if err != nil {
t.Fatalf("Failed to read response body: %v", err)
}
var groups []api.Group
if err = json.Unmarshal(content, &groups); err != nil {
t.Fatalf("Response is not in correct json format; %v", err)
}
assert.Equal(t, tc.expectedCount, len(groups))
})
}
}
func TestDeleteGroup(t *testing.T) {
tt := []struct {
name string

View File

@@ -45,19 +45,6 @@ func NewHandler(accountManager account.Manager, networkMapController network_map
}
}
func (h *Handler) checkPeerStatus(peer *nbpeer.Peer) (*nbpeer.Peer, error) {
peerToReturn := peer.Copy()
if peer.Status.Connected {
// Although we have online status in store we do not yet have an updated channel so have to show it as disconnected
// This may happen after server restart when not all peers are yet connected
if !h.networkMapController.IsConnected(peer.ID) {
peerToReturn.Status.Connected = false
}
}
return peerToReturn, nil
}
func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string, w http.ResponseWriter) {
peer, err := h.accountManager.GetPeer(ctx, accountID, peerID, userID)
if err != nil {
@@ -65,11 +52,6 @@ func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string,
return
}
peerToReturn, err := h.checkPeerStatus(peer)
if err != nil {
util.WriteError(ctx, err, w)
return
}
settings, err := h.accountManager.GetAccountSettings(ctx, accountID, activity.SystemInitiator)
if err != nil {
util.WriteError(ctx, err, w)
@@ -91,7 +73,7 @@ func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string,
_, valid := validPeers[peer.ID]
reason := invalidPeers[peer.ID]
util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, grpsInfoMap[peerID], dnsDomain, valid, reason))
util.WriteJSONObject(ctx, w, toSinglePeerResponse(peer, grpsInfoMap[peerID], dnsDomain, valid, reason))
}
func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID string, w http.ResponseWriter, r *http.Request) {
@@ -237,13 +219,7 @@ func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
grpsInfoMap := groups.ToGroupsInfoMap(grps, len(peers))
respBody := make([]*api.PeerBatch, 0, len(peers))
for _, peer := range peers {
peerToReturn, err := h.checkPeerStatus(peer)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
respBody = append(respBody, toPeerListItemResponse(peerToReturn, grpsInfoMap[peer.ID], dnsDomain, 0))
respBody = append(respBody, toPeerListItemResponse(peer, grpsInfoMap[peer.ID], dnsDomain, 0))
}
validPeersMap, invalidPeersMap, err := h.accountManager.GetValidatedPeers(r.Context(), accountID)

View File

@@ -109,14 +109,6 @@ func initTestMetaData(t *testing.T, peers ...*nbpeer.Peer) *Handler {
GetDNSDomain(gomock.Any()).
Return("domain").
AnyTimes()
networkMapController.EXPECT().
IsConnected(noUpdateChannelTestPeerID).
Return(false).
AnyTimes()
networkMapController.EXPECT().
IsConnected(gomock.Any()).
Return(true).
AnyTimes()
return &Handler{
accountManager: &mock_server.MockAccountManager{
@@ -269,14 +261,6 @@ func TestGetPeers(t *testing.T) {
expectedArray: false,
expectedPeer: peer,
},
{
name: "GetPeer with no update channel",
requestType: http.MethodGet,
requestPath: "/api/peers/" + peer1.ID,
expectedStatus: http.StatusOK,
expectedArray: false,
expectedPeer: expectedPeer1,
},
{
name: "PutPeer",
requestType: http.MethodPut,
@@ -336,8 +320,6 @@ func TestGetPeers(t *testing.T) {
for _, peer := range respBody {
if peer.Id == testPeerID {
got = peer
} else {
assert.Equal(t, peer.Connected, false)
}
}
@@ -351,14 +333,14 @@ func TestGetPeers(t *testing.T) {
t.Log(got)
assert.Equal(t, got.Name, tc.expectedPeer.Name)
assert.Equal(t, got.Version, tc.expectedPeer.Meta.WtVersion)
assert.Equal(t, got.Ip, tc.expectedPeer.IP.String())
assert.Equal(t, got.Os, "OS core")
assert.Equal(t, got.LoginExpirationEnabled, tc.expectedPeer.LoginExpirationEnabled)
assert.Equal(t, got.SshEnabled, tc.expectedPeer.SSHEnabled)
assert.Equal(t, got.Connected, tc.expectedPeer.Status.Connected)
assert.Equal(t, got.SerialNumber, tc.expectedPeer.Meta.SystemSerialNumber)
assert.Equal(t, tc.expectedPeer.Name, got.Name)
assert.Equal(t, tc.expectedPeer.Meta.WtVersion, got.Version)
assert.Equal(t, tc.expectedPeer.IP.String(), got.Ip)
assert.Equal(t, "OS core", got.Os)
assert.Equal(t, tc.expectedPeer.LoginExpirationEnabled, got.LoginExpirationEnabled)
assert.Equal(t, tc.expectedPeer.SSHEnabled, got.SshEnabled)
assert.Equal(t, tc.expectedPeer.Status.Connected, got.Connected)
assert.Equal(t, tc.expectedPeer.Meta.SystemSerialNumber, got.SerialNumber)
})
}
}

View File

@@ -9,6 +9,7 @@ import (
"time"
log "github.com/sirupsen/logrus"
"go.opentelemetry.io/otel/metric"
serverauth "github.com/netbirdio/netbird/management/server/auth"
nbcontext "github.com/netbirdio/netbird/management/server/context"
@@ -31,6 +32,7 @@ type AuthMiddleware struct {
getUserFromUserAuth GetUserFromUserAuthFunc
syncUserJWTGroups SyncUserJWTGroupsFunc
rateLimiter *APIRateLimiter
patUsageTracker *PATUsageTracker
}
// NewAuthMiddleware instance constructor
@@ -40,18 +42,29 @@ func NewAuthMiddleware(
syncUserJWTGroups SyncUserJWTGroupsFunc,
getUserFromUserAuth GetUserFromUserAuthFunc,
rateLimiterConfig *RateLimiterConfig,
meter metric.Meter,
) *AuthMiddleware {
var rateLimiter *APIRateLimiter
if rateLimiterConfig != nil {
rateLimiter = NewAPIRateLimiter(rateLimiterConfig)
}
var patUsageTracker *PATUsageTracker
if meter != nil {
var err error
patUsageTracker, err = NewPATUsageTracker(context.Background(), meter)
if err != nil {
log.Errorf("Failed to create PAT usage tracker: %s", err)
}
}
return &AuthMiddleware{
authManager: authManager,
ensureAccount: ensureAccount,
syncUserJWTGroups: syncUserJWTGroups,
getUserFromUserAuth: getUserFromUserAuth,
rateLimiter: rateLimiter,
patUsageTracker: patUsageTracker,
}
}
@@ -128,7 +141,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []
}
if userAuth.AccountId != accountId {
log.WithContext(ctx).Debugf("Auth middleware sets accountId from ensure, before %s, now %s", userAuth.AccountId, accountId)
log.WithContext(ctx).Tracef("Auth middleware sets accountId from ensure, before %s, now %s", userAuth.AccountId, accountId)
userAuth.AccountId = accountId
}
@@ -158,6 +171,10 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []
return r, fmt.Errorf("error extracting token: %w", err)
}
if m.patUsageTracker != nil {
m.patUsageTracker.IncrementUsage(token)
}
if m.rateLimiter != nil {
if !m.rateLimiter.Allow(token) {
return r, status.Errorf(status.TooManyRequests, "too many requests")

View File

@@ -208,6 +208,7 @@ func TestAuthMiddleware_Handler(t *testing.T) {
return &types.User{}, nil
},
nil,
nil,
)
handlerToTest := authMiddleware.Handler(nextHandler)
@@ -266,6 +267,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
return &types.User{}, nil
},
rateLimitConfig,
nil,
)
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -317,6 +319,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
return &types.User{}, nil
},
rateLimitConfig,
nil,
)
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -359,6 +362,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
return &types.User{}, nil
},
rateLimitConfig,
nil,
)
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -402,6 +406,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
return &types.User{}, nil
},
rateLimitConfig,
nil,
)
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -465,6 +470,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
return &types.User{}, nil
},
rateLimitConfig,
nil,
)
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -581,6 +587,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
return &types.User{}, nil
},
nil,
nil,
)
for _, tc := range tt {

View File

@@ -0,0 +1,87 @@
package middleware
import (
"context"
"maps"
"sync"
"time"
log "github.com/sirupsen/logrus"
"go.opentelemetry.io/otel/metric"
)
// PATUsageTracker tracks PAT usage metrics
type PATUsageTracker struct {
usageCounters map[string]int64
mu sync.Mutex
stopChan chan struct{}
ctx context.Context
histogram metric.Int64Histogram
}
// NewPATUsageTracker creates a new PAT usage tracker with metrics
func NewPATUsageTracker(ctx context.Context, meter metric.Meter) (*PATUsageTracker, error) {
histogram, err := meter.Int64Histogram(
"management.pat.usage_distribution",
metric.WithUnit("1"),
metric.WithDescription("Distribution of PAT token usage counts per minute"),
)
if err != nil {
return nil, err
}
tracker := &PATUsageTracker{
usageCounters: make(map[string]int64),
stopChan: make(chan struct{}),
ctx: ctx,
histogram: histogram,
}
go tracker.reportLoop()
return tracker, nil
}
// IncrementUsage increments the usage counter for a given token
func (t *PATUsageTracker) IncrementUsage(token string) {
t.mu.Lock()
defer t.mu.Unlock()
t.usageCounters[token]++
}
// reportLoop reports the usage buckets every minute
func (t *PATUsageTracker) reportLoop() {
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ticker.C:
t.reportUsageBuckets()
case <-t.stopChan:
return
}
}
}
// reportUsageBuckets reports all token usage counts and resets counters
func (t *PATUsageTracker) reportUsageBuckets() {
t.mu.Lock()
snapshot := maps.Clone(t.usageCounters)
clear(t.usageCounters)
t.mu.Unlock()
totalTokens := len(snapshot)
if totalTokens > 0 {
for _, count := range snapshot {
t.histogram.Record(t.ctx, count)
}
log.Debugf("PAT usage in last minute: %d unique tokens used", totalTokens)
}
}
// Stop stops the reporting goroutine
func (t *PATUsageTracker) Stop() {
close(t.stopChan)
}

View File

@@ -15,6 +15,8 @@ import (
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers"
ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server"
@@ -28,7 +30,6 @@ import (
"github.com/netbirdio/netbird/management/server/networks"
"github.com/netbirdio/netbird/management/server/networks/resources"
"github.com/netbirdio/netbird/management/server/networks/routers"
"github.com/netbirdio/netbird/management/server/peers"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
@@ -72,7 +73,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
ctx := context.Background()
requestBuffer := server.NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsManager, "", port_forwarding.NewControllerMock(), &config.Config{})
networkMapController := controller.NewController(ctx, store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsManager, "", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{})
am, err := server.BuildManager(ctx, nil, store, networkMapController, nil, "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false)
if err != nil {
t.Fatalf("Failed to create manager: %v", err)

View File

@@ -127,7 +127,7 @@ type MockIntegratedValidator struct {
ValidatePeerFunc func(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *types.ExtraSettings) (*nbpeer.Peer, bool, error)
}
func (a MockIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *types.ExtraSettings, oldExtraSettings *types.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error {
func (a MockIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *types.ExtraSettings, oldExtraSettings *types.ExtraSettings, userID string, accountID string) error {
return nil
}

View File

@@ -10,7 +10,7 @@ import (
// IntegratedValidator interface exists to avoid the circle dependencies
type IntegratedValidator interface {
ValidateExtraSettings(ctx context.Context, newExtraSettings *types.ExtraSettings, oldExtraSettings *types.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error
ValidateExtraSettings(ctx context.Context, newExtraSettings *types.ExtraSettings, oldExtraSettings *types.ExtraSettings, userID string, accountID string) error
ValidatePeer(ctx context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *types.ExtraSettings) (*nbpeer.Peer, bool, error)
PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings, temporary bool) *nbpeer.Peer
IsNotValidPeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) (bool, bool, error)

View File

@@ -24,13 +24,14 @@ import (
"github.com/netbirdio/netbird/formatter/hook"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/internals/server/config"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
@@ -363,7 +364,9 @@ func startManagementForTest(t *testing.T, testFile string, config *config.Config
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), config)
ephemeralMgr := manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager))
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeralMgr, config)
accountManager, err := BuildManager(ctx, nil, store, networkMapController, nil, "",
eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
@@ -372,10 +375,13 @@ func startManagementForTest(t *testing.T, testFile string, config *config.Config
return nil, nil, "", cleanup, err
}
secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
if err != nil {
cleanup()
return nil, nil, "", cleanup, err
}
ephemeralMgr := manager.NewEphemeralManager(store, accountManager)
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, updateManager, secretsManager, nil, ephemeralMgr, nil, MockIntegratedValidator{}, networkMapController)
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, MockIntegratedValidator{}, networkMapController)
if err != nil {
return nil, nil, "", cleanup, err
}

View File

@@ -22,13 +22,14 @@ import (
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers"
ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/internals/server/config"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
@@ -205,7 +206,7 @@ func startServer(
ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := server.NewAccountRequestBuffer(ctx, str)
networkMapController := controller.NewController(ctx, str, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), config)
networkMapController := controller.NewController(ctx, str, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(str, peers.NewManager(str, permissionsManager)), config)
accountManager, err := server.BuildManager(
context.Background(),
@@ -228,15 +229,16 @@ func startServer(
}
groupsManager := groups.NewManager(str, permissionsManager, accountManager)
secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
if err != nil {
t.Fatalf("failed creating secrets manager: %v", err)
}
mgmtServer, err := nbgrpc.NewServer(
config,
accountManager,
settingsMockManager,
updateManager,
secretsManager,
nil,
&manager.EphemeralManager{},
nil,
server.MockIntegratedValidator{},
networkMapController,

View File

@@ -15,7 +15,6 @@ import (
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/idp"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/peers/ephemeral"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
@@ -38,7 +37,7 @@ type MockAccountManager struct {
ListUsersFunc func(ctx context.Context, accountID string) ([]*types.User, error)
GetPeersFunc func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP) error
SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, clientSerial uint64) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error
GetNetworkMapFunc func(ctx context.Context, peerKey string) (*types.NetworkMap, error)
GetPeerNetworkFunc func(ctx context.Context, peerKey string) (*types.Network, error)
@@ -178,9 +177,9 @@ func (am *MockAccountManager) DeleteSetupKey(ctx context.Context, accountID, use
return status.Errorf(codes.Unimplemented, "method DeleteSetupKey is not implemented")
}
func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, clientSerial uint64) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
if am.SyncAndMarkPeerFunc != nil {
return am.SyncAndMarkPeerFunc(ctx, accountID, peerPubKey, meta, realIP)
return am.SyncAndMarkPeerFunc(ctx, accountID, peerPubKey, meta, realIP, clientSerial)
}
return nil, nil, nil, 0, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
}
@@ -976,11 +975,6 @@ func (am *MockAccountManager) GetCurrentUserInfo(ctx context.Context, userAuth a
return nil, status.Errorf(codes.Unimplemented, "method GetCurrentUserInfo is not implemented")
}
// SetEphemeralManager mocks SetEphemeralManager of the AccountManager interface
func (am *MockAccountManager) SetEphemeralManager(em ephemeral.Manager) {
// Mock implementation - does nothing
}
func (am *MockAccountManager) AllowSync(key string, hash uint64) bool {
if am.AllowSyncFunc != nil {
return am.AllowSyncFunc(key, hash)

View File

@@ -13,6 +13,8 @@ import (
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers"
ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
@@ -792,7 +794,7 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) {
ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), &config.Config{})
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{})
return BuildManager(context.Background(), nil, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
}

View File

@@ -136,7 +136,10 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
}
if expired {
am.networkMapController.OnPeerUpdated(accountID, peer)
err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID})
if err != nil {
return fmt.Errorf("notify network map controller of peer update: %w", err)
}
}
return nil
@@ -169,7 +172,7 @@ func updatePeerStatusAndLocation(ctx context.Context, geo geolocation.Geolocatio
}
}
log.WithContext(ctx).Tracef("saving peer status for peer %s is connected: %t", peer.ID, connected)
log.WithContext(ctx).Debugf("saving peer status for peer %s is connected: %t", peer.ID, connected)
err := transaction.SavePeerStatus(ctx, accountID, peer.ID, *newStatus)
if err != nil {
@@ -309,7 +312,10 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
}
}
am.networkMapController.OnPeerUpdated(accountID, peer)
err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID})
if err != nil {
return nil, fmt.Errorf("notify network map controller of peer update: %w", err)
}
return peer, nil
}
@@ -365,13 +371,8 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
storeEvent()
}
err = am.networkMapController.DeletePeer(ctx, accountID, peer.ID)
if err != nil {
log.WithContext(ctx).Errorf("failed to delete peer %s from network map: %v", peer.ID, err)
}
if err := am.networkMapController.OnPeerDeleted(ctx, accountID, peerID); err != nil {
log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", peerID, err)
if err := am.networkMapController.OnPeersDeleted(ctx, accountID, []string{peerID}); err != nil {
log.WithContext(ctx).Errorf("failed to delete peer %s from network map: %v", peerID, err)
}
return nil
@@ -583,11 +584,6 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
return fmt.Errorf("failed adding peer to All group: %w", err)
}
if temporary {
// we are running the on disconnect handler so that it is considered not connected as we are adding the peer manually
am.ephemeralManager.OnPeerDisconnected(ctx, newPeer)
}
if addedByUser {
err := transaction.SaveUserLastLogin(ctx, accountID, userID, newPeer.GetLastLogin())
if err != nil {
@@ -645,11 +641,11 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta)
if err := am.networkMapController.OnPeerAdded(ctx, accountID, newPeer.ID); err != nil {
if err := am.networkMapController.OnPeersAdded(ctx, accountID, []string{newPeer.ID}); err != nil {
log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", newPeer.ID, err)
}
p, nmap, pc, _, err := am.networkMapController.GetValidatedPeerWithMap(ctx, false, accountID, newPeer)
p, nmap, pc, _, err := am.networkMapController.GetValidatedPeerWithMap(ctx, false, accountID, newPeer, 0)
return p, nmap, pc, err
}
@@ -729,10 +725,13 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
}
if isStatusChanged || sync.UpdateAccountPeers || (updated && (len(postureChecks) > 0 || versionChanged)) {
am.networkMapController.OnPeerUpdated(accountID, peer)
err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID})
if err != nil {
return nil, nil, nil, 0, fmt.Errorf("notify network map controller of peer update: %w", err)
}
}
return am.networkMapController.GetValidatedPeerWithMap(ctx, peerNotValid, accountID, peer)
return am.networkMapController.GetValidatedPeerWithMap(ctx, peerNotValid, accountID, peer, sync.NetworkMapSerial)
}
func (am *DefaultAccountManager) handlePeerLoginNotFound(ctx context.Context, login types.PeerLogin, err error) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
@@ -784,7 +783,6 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
return nil, nil, nil, err
}
startTransaction := time.Now()
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
peer, err = transaction.GetPeerByPeerPubKey(ctx, store.LockingStrengthUpdate, login.WireGuardPubKey)
if err != nil {
@@ -854,13 +852,14 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
return nil, nil, nil, err
}
log.WithContext(ctx).Debugf("LoginPeer: transaction took %v", time.Since(startTransaction))
if updateRemotePeers || isStatusChanged || (isPeerUpdated && len(postureChecks) > 0) {
am.networkMapController.OnPeerUpdated(accountID, peer)
err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID})
if err != nil {
return nil, nil, nil, fmt.Errorf("notify network map controller of peer update: %w", err)
}
}
p, nmap, pc, _, err := am.networkMapController.GetValidatedPeerWithMap(ctx, isRequiresApproval, accountID, peer)
p, nmap, pc, _, err := am.networkMapController.GetValidatedPeerWithMap(ctx, isRequiresApproval, accountID, peer, 0)
return p, nmap, pc, err
}

View File

@@ -28,6 +28,8 @@ import (
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers"
ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
@@ -1058,6 +1060,7 @@ func testUpdateAccountPeers(t *testing.T) {
for _, channel := range peerChannels {
update := <-channel
assert.Nil(t, update.Update.NetbirdConfig)
assert.Equal(t, tc.peers, len(update.Update.NetworkMap.RemotePeers))
assert.Equal(t, tc.peers*2, len(update.Update.NetworkMap.FirewallRules))
}
@@ -1290,7 +1293,7 @@ func Test_RegisterPeerByUser(t *testing.T) {
ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, s)
networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), &config.Config{})
networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s, permissionsManager)), &config.Config{})
am, err := BuildManager(context.Background(), nil, s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
assert.NoError(t, err)
@@ -1375,7 +1378,7 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, s)
networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), &config.Config{})
networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s, permissionsManager)), &config.Config{})
am, err := BuildManager(context.Background(), nil, s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
assert.NoError(t, err)
@@ -1528,7 +1531,7 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, s)
networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), &config.Config{})
networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s, permissionsManager)), &config.Config{})
am, err := BuildManager(context.Background(), nil, s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
assert.NoError(t, err)
@@ -1608,7 +1611,7 @@ func Test_LoginPeer(t *testing.T) {
ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, s)
networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), &config.Config{})
networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s, permissionsManager)), &config.Config{})
am, err := BuildManager(context.Background(), nil, s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
assert.NoError(t, err)

View File

@@ -1,68 +0,0 @@
package peers
//go:generate go run github.com/golang/mock/mockgen -package peers -destination=manager_mock.go -source=./manager.go -build_flags=-mod=mod
import (
"context"
"fmt"
"github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/shared/management/status"
)
type Manager interface {
GetPeer(ctx context.Context, accountID, userID, peerID string) (*peer.Peer, error)
GetPeerAccountID(ctx context.Context, peerID string) (string, error)
GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error)
GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) ([]*peer.Peer, error)
}
type managerImpl struct {
store store.Store
permissionsManager permissions.Manager
}
func NewManager(store store.Store, permissionsManager permissions.Manager) Manager {
return &managerImpl{
store: store,
permissionsManager: permissionsManager,
}
}
func (m *managerImpl) GetPeer(ctx context.Context, accountID, userID, peerID string) (*peer.Peer, error) {
allowed, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Read)
if err != nil {
return nil, fmt.Errorf("failed to validate user permissions: %w", err)
}
if !allowed {
return nil, status.NewPermissionDeniedError()
}
return m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
}
func (m *managerImpl) GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error) {
allowed, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Read)
if err != nil {
return nil, fmt.Errorf("failed to validate user permissions: %w", err)
}
if !allowed {
return m.store.GetUserPeers(ctx, store.LockingStrengthNone, accountID, userID)
}
return m.store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
}
func (m *managerImpl) GetPeerAccountID(ctx context.Context, peerID string) (string, error) {
return m.store.GetAccountIDByPeerID(ctx, store.LockingStrengthNone, peerID)
}
func (m *managerImpl) GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) ([]*peer.Peer, error) {
return m.store.GetPeersByGroupIDs(ctx, accountID, groupsIDs)
}

View File

@@ -82,7 +82,7 @@ func (c *OSVersionCheck) Validate() error {
func checkMinVersion(ctx context.Context, peerGoOS, peerVersion string, check *MinVersionCheck) (bool, error) {
if check == nil {
log.WithContext(ctx).Debugf("peer %s OS is not allowed in the check", peerGoOS)
log.WithContext(ctx).Tracef("peer %s OS is not allowed in the check", peerGoOS)
return false, nil
}
@@ -107,7 +107,7 @@ func checkMinVersion(ctx context.Context, peerGoOS, peerVersion string, check *M
func checkMinKernelVersion(ctx context.Context, peerGoOS, peerVersion string, check *MinKernelVersionCheck) (bool, error) {
if check == nil {
log.WithContext(ctx).Debugf("peer %s OS is not allowed in the check", peerGoOS)
log.WithContext(ctx).Tracef("peer %s OS is not allowed in the check", peerGoOS)
return false, nil
}

View File

@@ -16,6 +16,8 @@ import (
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers"
ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
@@ -1291,7 +1293,7 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, *update_channel.
ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), &config.Config{})
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{})
am, err := BuildManager(context.Background(), nil, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
if err != nil {

View File

@@ -5,9 +5,6 @@ package settings
import (
"context"
"fmt"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/integrations/extra_settings"
@@ -48,11 +45,6 @@ func (m *managerImpl) GetExtraSettingsManager() extra_settings.Manager {
}
func (m *managerImpl) GetSettings(ctx context.Context, accountID, userID string) (*types.Settings, error) {
start := time.Now()
defer func() {
log.WithContext(ctx).Debugf("GetSettings took %s", time.Since(start))
}()
if userID != activity.SystemInitiator {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Read)
if err != nil {

View File

@@ -27,7 +27,6 @@ import (
"gorm.io/gorm/logger"
nbdns "github.com/netbirdio/netbird/dns"
nbcontext "github.com/netbirdio/netbird/management/server/context"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
@@ -288,7 +287,7 @@ func (s *SqlStore) DeleteAccount(ctx context.Context, account *types.Account) er
if s.metrics != nil {
s.metrics.StoreMetrics().CountPersistenceDuration(took)
}
log.WithContext(ctx).Debugf("took %d ms to delete an account to the store", took.Milliseconds())
log.WithContext(ctx).Tracef("took %d ms to delete an account to the store", took.Milliseconds())
return err
}
@@ -413,6 +412,18 @@ func (s *SqlStore) SavePeerLocation(ctx context.Context, accountID string, peerW
return nil
}
// ApproveAccountPeers marks all peers that currently require approval in the given account as approved.
func (s *SqlStore) ApproveAccountPeers(ctx context.Context, accountID string) (int, error) {
result := s.db.Model(&nbpeer.Peer{}).
Where("account_id = ? AND peer_status_requires_approval = ?", accountID, true).
Update("peer_status_requires_approval", false)
if result.Error != nil {
return 0, status.Errorf(status.Internal, "failed to approve pending account peers: %v", result.Error)
}
return int(result.RowsAffected), nil
}
// SaveUsers saves the given list of users to the database.
func (s *SqlStore) SaveUsers(ctx context.Context, users []*types.User) error {
if len(users) == 0 {
@@ -583,16 +594,13 @@ func (s *SqlStore) GetUserByPATID(ctx context.Context, lockStrength LockingStren
}
func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error) {
ctx, cancel := getDebuggingCtx(ctx)
defer cancel()
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var user types.User
result := tx.WithContext(ctx).Take(&user, idQueryCondition, userID)
result := tx.Take(&user, idQueryCondition, userID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.NewUserNotFoundError(userID)
@@ -2152,16 +2160,13 @@ func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength Lock
}
func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Network, error) {
ctx, cancel := getDebuggingCtx(ctx)
defer cancel()
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var accountNetwork types.AccountNetwork
if err := tx.WithContext(ctx).Model(&types.Account{}).Where(idQueryCondition, accountID).Take(&accountNetwork).Error; err != nil {
if err := tx.Model(&types.Account{}).Where(idQueryCondition, accountID).Take(&accountNetwork).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.NewAccountNotFoundError(accountID)
}
@@ -2171,16 +2176,13 @@ func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingSt
}
func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) {
ctx, cancel := getDebuggingCtx(ctx)
defer cancel()
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var peer nbpeer.Peer
result := tx.WithContext(ctx).Take(&peer, GetKeyQueryCondition(s), peerKey)
result := tx.Take(&peer, GetKeyQueryCondition(s), peerKey)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
@@ -2229,11 +2231,8 @@ func (s *SqlStore) GetAccountCreatedBy(ctx context.Context, lockStrength Locking
// SaveUserLastLogin stores the last login time for a user in DB.
func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error {
ctx, cancel := getDebuggingCtx(ctx)
defer cancel()
var user types.User
result := s.db.WithContext(ctx).Take(&user, accountAndIDQueryCondition, accountID, userID)
result := s.db.Take(&user, accountAndIDQueryCondition, accountID, userID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.NewUserNotFoundError(userID)
@@ -2491,16 +2490,13 @@ func NewMysqlStoreFromSqlStore(ctx context.Context, sqliteStore *SqlStore, dsn s
}
func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*types.SetupKey, error) {
ctx, cancel := getDebuggingCtx(ctx)
defer cancel()
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var setupKey types.SetupKey
result := tx.WithContext(ctx).
result := tx.
Take(&setupKey, GetKeyQueryCondition(s), key)
if result.Error != nil {
@@ -2514,10 +2510,7 @@ func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength Locking
}
func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error {
ctx, cancel := getDebuggingCtx(ctx)
defer cancel()
result := s.db.WithContext(ctx).Model(&types.SetupKey{}).
result := s.db.Model(&types.SetupKey{}).
Where(idQueryCondition, setupKeyID).
Updates(map[string]interface{}{
"used_times": gorm.Expr("used_times + 1"),
@@ -2537,11 +2530,8 @@ func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string
// AddPeerToAllGroup adds a peer to the 'All' group. Method always needs to run in a transaction
func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error {
ctx, cancel := getDebuggingCtx(ctx)
defer cancel()
var groupID string
_ = s.db.WithContext(ctx).Model(types.Group{}).
_ = s.db.Model(types.Group{}).
Select("id").
Where("account_id = ? AND name = ?", accountID, "All").
Limit(1).
@@ -2569,9 +2559,6 @@ func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peer
// AddPeerToGroup adds a peer to a group
func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountID, peerID, groupID string) error {
ctx, cancel := getDebuggingCtx(ctx)
defer cancel()
peer := &types.GroupPeer{
AccountID: accountID,
GroupID: groupID,
@@ -2768,10 +2755,7 @@ func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrengt
}
func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error {
ctx, cancel := getDebuggingCtx(ctx)
defer cancel()
if err := s.db.WithContext(ctx).Create(peer).Error; err != nil {
if err := s.db.Create(peer).Error; err != nil {
return status.Errorf(status.Internal, "issue adding peer to account: %s", err)
}
@@ -2897,10 +2881,7 @@ func (s *SqlStore) DeletePeer(ctx context.Context, accountID string, peerID stri
}
func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error {
ctx, cancel := getDebuggingCtx(ctx)
defer cancel()
result := s.db.WithContext(ctx).Model(&types.Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
result := s.db.Model(&types.Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to increment network serial count in store: %v", result.Error)
return status.Errorf(status.Internal, "failed to increment network serial count in store")
@@ -4022,36 +4003,6 @@ func (s *SqlStore) GetAccountGroupPeers(ctx context.Context, lockStrength Lockin
return groupPeers, nil
}
func getDebuggingCtx(grpcCtx context.Context) (context.Context, context.CancelFunc) {
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
userID, ok := grpcCtx.Value(nbcontext.UserIDKey).(string)
if ok {
//nolint
ctx = context.WithValue(ctx, nbcontext.UserIDKey, userID)
}
requestID, ok := grpcCtx.Value(nbcontext.RequestIDKey).(string)
if ok {
//nolint
ctx = context.WithValue(ctx, nbcontext.RequestIDKey, requestID)
}
accountID, ok := grpcCtx.Value(nbcontext.AccountIDKey).(string)
if ok {
//nolint
ctx = context.WithValue(ctx, nbcontext.AccountIDKey, accountID)
}
go func() {
select {
case <-ctx.Done():
case <-grpcCtx.Done():
log.WithContext(grpcCtx).Warnf("grpc context ended early, error: %v", grpcCtx.Err())
}
}()
return ctx, cancel
}
func (s *SqlStore) IsPrimaryAccount(ctx context.Context, accountID string) (bool, string, error) {
var info types.PrimaryAccountInfo
result := s.db.Model(&types.Account{}).
@@ -4091,7 +4042,7 @@ func (s *SqlStore) UpdateAccountNetwork(ctx context.Context, accountID string, i
Network: &types.Network{Net: ipNet},
}
result := s.db.WithContext(ctx).
result := s.db.
Model(&types.Account{}).
Where(idQueryCondition, accountID).
Updates(&patch)

View File

@@ -3717,3 +3717,80 @@ func TestSqlStore_GetPeersByGroupIDs(t *testing.T) {
})
}
}
func TestSqlStore_ApproveAccountPeers(t *testing.T) {
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
accountID := "test-account"
ctx := context.Background()
account := newAccountWithId(ctx, accountID, "testuser", "example.com")
err := store.SaveAccount(ctx, account)
require.NoError(t, err)
peers := []*nbpeer.Peer{
{
ID: "peer1",
AccountID: accountID,
DNSLabel: "peer1.netbird.cloud",
Key: "peer1-key",
IP: net.ParseIP("100.64.0.1"),
Status: &nbpeer.PeerStatus{
RequiresApproval: true,
LastSeen: time.Now().UTC(),
},
},
{
ID: "peer2",
AccountID: accountID,
DNSLabel: "peer2.netbird.cloud",
Key: "peer2-key",
IP: net.ParseIP("100.64.0.2"),
Status: &nbpeer.PeerStatus{
RequiresApproval: true,
LastSeen: time.Now().UTC(),
},
},
{
ID: "peer3",
AccountID: accountID,
DNSLabel: "peer3.netbird.cloud",
Key: "peer3-key",
IP: net.ParseIP("100.64.0.3"),
Status: &nbpeer.PeerStatus{
RequiresApproval: false,
LastSeen: time.Now().UTC(),
},
},
}
for _, peer := range peers {
err = store.AddPeerToAccount(ctx, peer)
require.NoError(t, err)
}
t.Run("approve all pending peers", func(t *testing.T) {
count, err := store.ApproveAccountPeers(ctx, accountID)
require.NoError(t, err)
assert.Equal(t, 2, count)
allPeers, err := store.GetAccountPeers(ctx, LockingStrengthNone, accountID, "", "")
require.NoError(t, err)
for _, peer := range allPeers {
assert.False(t, peer.Status.RequiresApproval, "peer %s should not require approval", peer.ID)
}
})
t.Run("no peers to approve", func(t *testing.T) {
count, err := store.ApproveAccountPeers(ctx, accountID)
require.NoError(t, err)
assert.Equal(t, 0, count)
})
t.Run("non-existent account", func(t *testing.T) {
count, err := store.ApproveAccountPeers(ctx, "non-existent")
require.NoError(t, err)
assert.Equal(t, 0, count)
})
})
}

View File

@@ -143,6 +143,7 @@ type Store interface {
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
SavePeerStatus(ctx context.Context, accountID, peerID string, status nbpeer.PeerStatus) error
SavePeerLocation(ctx context.Context, accountID string, peer *nbpeer.Peer) error
ApproveAccountPeers(ctx context.Context, accountID string) (int, error)
DeletePeer(ctx context.Context, accountID string, peerID string) error
GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*types.SetupKey, error)

View File

@@ -16,7 +16,6 @@ type GRPCMetrics struct {
meter metric.Meter
syncRequestsCounter metric.Int64Counter
syncRequestsBlockedCounter metric.Int64Counter
syncRequestHighLatencyCounter metric.Int64Counter
loginRequestsCounter metric.Int64Counter
loginRequestsBlockedCounter metric.Int64Counter
loginRequestHighLatencyCounter metric.Int64Counter
@@ -46,14 +45,6 @@ func NewGRPCMetrics(ctx context.Context, meter metric.Meter) (*GRPCMetrics, erro
return nil, err
}
syncRequestHighLatencyCounter, err := meter.Int64Counter("management.grpc.sync.request.high.latency.counter",
metric.WithUnit("1"),
metric.WithDescription("Number of sync gRPC requests from the peers that took longer than the threshold to establish a connection and receive network map updates (update channel)"),
)
if err != nil {
return nil, err
}
loginRequestsCounter, err := meter.Int64Counter("management.grpc.login.request.counter",
metric.WithUnit("1"),
metric.WithDescription("Number of login gRPC requests from the peers to authenticate and receive initial configuration and relay credentials"),
@@ -126,7 +117,6 @@ func NewGRPCMetrics(ctx context.Context, meter metric.Meter) (*GRPCMetrics, erro
meter: meter,
syncRequestsCounter: syncRequestsCounter,
syncRequestsBlockedCounter: syncRequestsBlockedCounter,
syncRequestHighLatencyCounter: syncRequestHighLatencyCounter,
loginRequestsCounter: loginRequestsCounter,
loginRequestsBlockedCounter: loginRequestsBlockedCounter,
loginRequestHighLatencyCounter: loginRequestHighLatencyCounter,
@@ -175,9 +165,6 @@ func (grpcMetrics *GRPCMetrics) CountLoginRequestDuration(duration time.Duration
// CountSyncRequestDuration counts the duration of the sync gRPC requests
func (grpcMetrics *GRPCMetrics) CountSyncRequestDuration(duration time.Duration, accountID string) {
grpcMetrics.syncRequestDuration.Record(grpcMetrics.ctx, duration.Milliseconds())
if duration > HighLatencyThreshold {
grpcMetrics.syncRequestHighLatencyCounter.Add(grpcMetrics.ctx, 1, metric.WithAttributes(attribute.String(AccountIDLabel, accountID)))
}
}
// RegisterConnectedStreams registers a function that collects number of active streams and feeds it to the metrics gauge.

View File

@@ -185,6 +185,18 @@ func (m *HTTPMiddleware) Handler(h http.Handler) http.Handler {
h.ServeHTTP(w, r.WithContext(ctx))
userAuth, err := nbContext.GetUserAuthFromContext(r.Context())
if err == nil {
if userAuth.AccountId != "" {
//nolint
ctx = context.WithValue(ctx, nbContext.AccountIDKey, userAuth.AccountId)
}
if userAuth.UserId != "" {
//nolint
ctx = context.WithValue(ctx, nbContext.UserIDKey, userAuth.UserId)
}
}
if w.Status() > 399 {
log.WithContext(ctx).Errorf("HTTP response %v: %v %v status %v", reqID, r.Method, r.URL, w.Status())
} else {

View File

@@ -15,6 +15,9 @@ type PeerSync struct {
// UpdateAccountPeers indicate updating account peers,
// which occurs when the peer's metadata is updated
UpdateAccountPeers bool
// NetworkMapSerial is the last known network map serial number on the client.
// Used to skip network map recalculation if client already has the latest.
NetworkMapSerial uint64
}
// PeerLogin used as a data object between the gRPC API and Manager on Login request.

View File

@@ -263,15 +263,11 @@ func (am *DefaultAccountManager) DeleteUser(ctx context.Context, accountID, init
return err
}
updateAccountPeers, err := am.deleteRegularUser(ctx, accountID, initiatorUserID, userInfo)
_, err = am.deleteRegularUser(ctx, accountID, initiatorUserID, userInfo)
if err != nil {
return err
}
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID)
}
return nil
}
@@ -596,9 +592,15 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
}
// prepareUserUpdateEvents prepares a list user update events based on the changes between the old and new user data.
func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, accountID string, initiatorUserID string, oldUser, newUser *types.User, transferredOwnerRole bool, removedGroupIDs, addedGroupIDs []string, tx store.Store) []func() {
func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, accountID string, initiatorUserID string, oldUser, newUser *types.User, transferredOwnerRole bool, isNewUser bool, removedGroupIDs, addedGroupIDs []string, tx store.Store) []func() {
var eventsToStore []func()
if isNewUser {
eventsToStore = append(eventsToStore, func() {
am.StoreEvent(ctx, initiatorUserID, newUser.Id, accountID, activity.UserCreated, nil)
})
}
if oldUser.IsBlocked() != newUser.IsBlocked() {
if newUser.IsBlocked() {
eventsToStore = append(eventsToStore, func() {
@@ -661,7 +663,7 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact
return false, nil, nil, nil, status.Errorf(status.InvalidArgument, "provided user update is nil")
}
oldUser, err := getUserOrCreateIfNotExists(ctx, transaction, accountID, update, addIfNotExists)
oldUser, isNewUser, err := getUserOrCreateIfNotExists(ctx, transaction, accountID, update, addIfNotExists)
if err != nil {
return false, nil, nil, nil, err
}
@@ -716,30 +718,30 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact
}
updateAccountPeers := len(userPeers) > 0
userEventsToAdd := am.prepareUserUpdateEvents(ctx, updatedUser.AccountID, initiatorUserId, oldUser, updatedUser, transferredOwnerRole, removedGroups, addedGroups, transaction)
userEventsToAdd := am.prepareUserUpdateEvents(ctx, updatedUser.AccountID, initiatorUserId, oldUser, updatedUser, transferredOwnerRole, isNewUser, removedGroups, addedGroups, transaction)
return updateAccountPeers, updatedUser, peersToExpire, userEventsToAdd, nil
}
// getUserOrCreateIfNotExists retrieves the existing user or creates a new one if it doesn't exist.
func getUserOrCreateIfNotExists(ctx context.Context, transaction store.Store, accountID string, update *types.User, addIfNotExists bool) (*types.User, error) {
func getUserOrCreateIfNotExists(ctx context.Context, transaction store.Store, accountID string, update *types.User, addIfNotExists bool) (*types.User, bool, error) {
existingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthNone, update.Id)
if err != nil {
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
if !addIfNotExists {
return nil, status.Errorf(status.NotFound, "user to update doesn't exist: %s", update.Id)
return nil, false, status.Errorf(status.NotFound, "user to update doesn't exist: %s", update.Id)
}
update.AccountID = accountID
return update, nil // use all fields from update if addIfNotExists is true
return update, true, nil // use all fields from update if addIfNotExists is true
}
return nil, err
return nil, false, err
}
if existingUser.AccountID != accountID {
return nil, status.Errorf(status.InvalidArgument, "user account ID mismatch")
return nil, false, status.Errorf(status.InvalidArgument, "user account ID mismatch")
}
return existingUser, nil
return existingUser, false, nil
}
func handleOwnerRoleTransfer(ctx context.Context, transaction store.Store, initiatorUser, update *types.User) (bool, error) {
@@ -992,14 +994,17 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou
peer.UserID, peer.ID, accountID,
activity.PeerLoginExpired, peer.EventMeta(dnsDomain),
)
}
am.networkMapController.OnPeerUpdated(accountID, peer)
err = am.networkMapController.OnPeersUpdated(ctx, accountID, peerIDs)
if err != nil {
return fmt.Errorf("notify network map controller of peer update: %w", err)
}
if len(peerIDs) != 0 {
// this will trigger peer disconnect from the management service
log.Debugf("Expiring %d peers for account %s", len(peerIDs), accountID)
am.networkMapController.DisconnectPeers(ctx, peerIDs)
am.networkMapController.DisconnectPeers(ctx, accountID, peerIDs)
}
return nil
}
@@ -1045,7 +1050,6 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account
}
var allErrors error
var updateAccountPeers bool
for _, targetUserID := range targetUserIDs {
if initiatorUserID == targetUserID {
@@ -1076,19 +1080,11 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account
continue
}
userHadPeers, err := am.deleteRegularUser(ctx, accountID, initiatorUserID, userInfo)
_, err = am.deleteRegularUser(ctx, accountID, initiatorUserID, userInfo)
if err != nil {
allErrors = errors.Join(allErrors, err)
continue
}
if userHadPeers {
updateAccountPeers = true
}
}
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID)
}
return allErrors
@@ -1146,15 +1142,12 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI
return false, err
}
var peerIDs []string
for _, peer := range userPeers {
err = am.networkMapController.DeletePeer(ctx, accountID, peer.ID)
if err != nil {
log.WithContext(ctx).Errorf("failed to delete peer %s from network map: %v", peer.ID, err)
}
if err := am.networkMapController.OnPeerDeleted(ctx, accountID, peer.ID); err != nil {
log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", peer.ID, err)
}
peerIDs = append(peerIDs, peer.ID)
}
if err := am.networkMapController.OnPeersDeleted(ctx, accountID, peerIDs); err != nil {
log.WithContext(ctx).Errorf("failed to delete peers %s from network map: %v", peerIDs, err)
}
for _, addPeerRemovedEvent := range addPeerRemovedEvents {

View File

@@ -8,8 +8,10 @@ import (
"time"
"github.com/google/go-cmp/cmp"
"go.uber.org/mock/gomock"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
nbcache "github.com/netbirdio/netbird/management/server/cache"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
@@ -547,7 +549,7 @@ func TestUser_InviteNewUser(t *testing.T) {
permissionsManager: permissionsManager,
}
cs, err := nbcache.NewStore(context.Background(), nbcache.DefaultIDPCacheExpirationMax, nbcache.DefaultIDPCacheCleanupInterval)
cs, err := nbcache.NewStore(context.Background(), nbcache.DefaultIDPCacheExpirationMax, nbcache.DefaultIDPCacheCleanupInterval, nbcache.DefaultIDPCacheOpenConn)
require.NoError(t, err)
am.cacheManager = nbcache.NewAccountUserDataCache(am.loadAccount, cs)
@@ -739,11 +741,18 @@ func TestUser_DeleteUser_regularUser(t *testing.T) {
t.Fatalf("Error when saving account: %s", err)
}
ctrl := gomock.NewController(t)
networkMapControllerMock := network_map.NewMockController(ctrl)
networkMapControllerMock.EXPECT().
OnPeersDeleted(gomock.Any(), gomock.Any(), gomock.Any()).
Return(nil)
permissionsManager := permissions.NewManager(store)
am := DefaultAccountManager{
Store: store,
eventStore: &activity.InMemoryEventStore{},
permissionsManager: permissionsManager,
Store: store,
eventStore: &activity.InMemoryEventStore{},
permissionsManager: permissionsManager,
networkMapController: networkMapControllerMock,
}
testCases := []struct {
@@ -848,12 +857,20 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) {
t.Fatalf("Error when saving account: %s", err)
}
ctrl := gomock.NewController(t)
networkMapControllerMock := network_map.NewMockController(ctrl)
networkMapControllerMock.EXPECT().
OnPeersDeleted(gomock.Any(), gomock.Any(), gomock.Any()).
Return(nil).
AnyTimes()
permissionsManager := permissions.NewManager(store)
am := DefaultAccountManager{
Store: store,
eventStore: &activity.InMemoryEventStore{},
integratedPeerValidator: MockIntegratedValidator{},
permissionsManager: permissionsManager,
networkMapController: networkMapControllerMock,
}
testCases := []struct {
@@ -1056,7 +1073,7 @@ func TestDefaultAccountManager_ExternalCache(t *testing.T) {
permissionsManager: permissionsManager,
}
cacheStore, err := nbcache.NewStore(context.Background(), nbcache.DefaultIDPCacheExpirationMax, nbcache.DefaultIDPCacheCleanupInterval)
cacheStore, err := nbcache.NewStore(context.Background(), nbcache.DefaultIDPCacheExpirationMax, nbcache.DefaultIDPCacheCleanupInterval, nbcache.DefaultIDPCacheOpenConn)
assert.NoError(t, err)
am.externalCacheManager = nbcache.NewUserDataCache(cacheStore)
am.cacheManager = nbcache.NewAccountUserDataCache(am.loadAccount, cacheStore)
@@ -1412,7 +1429,7 @@ func TestUserAccountPeersUpdate(t *testing.T) {
t.Run("deleting user with no linked peers", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
peerShouldReceiveUpdate(t, updMsg)
close(done)
}()

View File

@@ -160,7 +160,8 @@ func execute(cmd *cobra.Command, args []string) error {
log.Debugf("failed to create relay server: %v", err)
return fmt.Errorf("failed to create relay server: %v", err)
}
log.Infof("server will be available on: %s", srv.InstanceURL())
instanceURL := srv.InstanceURL()
log.Infof("server will be available on: %s", instanceURL.String())
wg.Add(1)
go func() {
defer wg.Done()

View File

@@ -6,14 +6,14 @@ import (
"errors"
"net"
"net/http"
"net/url"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/protocol"
"github.com/netbirdio/netbird/relay/server/listener/quic"
"github.com/netbirdio/netbird/relay/server/listener/ws"
"github.com/netbirdio/netbird/relay/server"
)
const (
@@ -27,7 +27,7 @@ const (
type ServiceChecker interface {
ListenerProtocols() []protocol.Protocol
ListenAddress() string
InstanceURL() url.URL
}
type HealthStatus struct {
@@ -135,7 +135,11 @@ func (s *Server) getHealthStatus(ctx context.Context) (*HealthStatus, bool) {
}
status.Listeners = listeners
if ok := s.validateCertificate(ctx); !ok {
if s.config.ServiceChecker.InstanceURL().Scheme != server.SchemeRELS {
status.CertificateValid = false
}
if ok := s.validateConnection(ctx); !ok {
status.Status = statusUnhealthy
status.CertificateValid = false
healthy = false
@@ -152,32 +156,13 @@ func (s *Server) validateListeners() ([]protocol.Protocol, bool) {
return listeners, true
}
func (s *Server) validateCertificate(ctx context.Context) bool {
listenAddress := s.config.ServiceChecker.ListenAddress()
if listenAddress == "" {
log.Warn("listen address is empty")
func (s *Server) validateConnection(ctx context.Context) bool {
addr := s.config.ServiceChecker.InstanceURL()
if err := dialWS(ctx, addr); err != nil {
log.Errorf("failed to dial WebSocket listener at %s: %v", addr.String(), err)
return false
}
dAddr := dialAddress(listenAddress)
for _, proto := range s.config.ServiceChecker.ListenerProtocols() {
switch proto {
case ws.Proto:
if err := dialWS(ctx, dAddr); err != nil {
log.Errorf("failed to dial WebSocket listener: %v", err)
return false
}
case quic.Proto:
if err := dialQUIC(ctx, dAddr); err != nil {
log.Errorf("failed to dial QUIC listener: %v", err)
return false
}
default:
log.Warnf("unknown protocol for healthcheck: %s", proto)
return false
}
}
return true
}
@@ -187,8 +172,9 @@ func dialAddress(listenAddress string) string {
return listenAddress // fallback, might be invalid for dialing
}
// When listening on all interfaces, show localhost for better readability
if host == "" || host == "::" || host == "0.0.0.0" {
host = "0.0.0.0"
host = "localhost"
}
return net.JoinHostPort(host, port)

Some files were not shown because too many files have changed in this diff Show More