Compare commits

..

4 Commits

Author SHA1 Message Date
Zoltán Papp
8ac5e9d866 Fix log 2024-10-31 19:07:38 +01:00
Zoltán Papp
954e038da0 Add more logs 2024-10-31 18:19:57 +01:00
Zoltán Papp
9ccc6c6547 Add nil value check 2024-10-31 16:48:10 +01:00
Zoltan Papp
2a3262f5a8 Print debug info 2024-10-29 13:54:35 +01:00
162 changed files with 3531 additions and 8930 deletions

3
.github/FUNDING.yml vendored
View File

@@ -1,3 +0,0 @@
# These are supported funding model platforms
github: [netbirdio]

View File

@@ -13,7 +13,6 @@ concurrency:
jobs:
test:
strategy:
fail-fast: false
matrix:
arch: [ '386','amd64' ]
store: [ 'sqlite', 'postgres']
@@ -52,47 +51,6 @@ jobs:
- name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m -p 1 ./...
benchmark:
strategy:
fail-fast: false
matrix:
arch: [ '386','amd64' ]
store: [ 'sqlite', 'postgres' ]
runs-on: ubuntu-22.04
steps:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
- name: Cache Go modules
uses: actions/cache@v4
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-go-
- name: Checkout code
uses: actions/checkout@v4
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install 32-bit libpcap
if: matrix.arch == '386'
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
- name: Install modules
run: go mod tidy
- name: check git status
run: git --no-pager diff --exit-code
- name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -run=^$ -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m -p 1 ./...
test_client_on_docker:
runs-on: ubuntu-20.04
steps:

View File

@@ -9,7 +9,7 @@ on:
pull_request:
env:
SIGN_PIPE_VER: "v0.0.17"
SIGN_PIPE_VER: "v0.0.16"
GORELEASER_VER: "v2.3.2"
PRODUCT_NAME: "NetBird"
COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)"

View File

@@ -17,12 +17,8 @@
<img src="https://img.shields.io/badge/license-BSD--3-blue" />
</a>
<br>
<a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2utg2ncdz-W7LEB6toRBLE1Jca37dYpg">
<a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2p5zwhm4g-8fHollzrQa5y4PZF5AEpvQ">
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
</a>
<br>
<a href="https://gurubase.io/g/netbird">
<img src="https://img.shields.io/badge/Gurubase-Ask%20NetBird%20Guru-006BFF"/>
</a>
</p>
</div>
@@ -34,7 +30,7 @@
<br/>
See <a href="https://netbird.io/docs/">Documentation</a>
<br/>
Join our <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2utg2ncdz-W7LEB6toRBLE1Jca37dYpg">Slack channel</a>
Join our <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2p5zwhm4g-8fHollzrQa5y4PZF5AEpvQ">Slack channel</a>
<br/>
</strong>

View File

@@ -12,8 +12,6 @@ import (
"strings"
)
const anonTLD = ".domain"
type Anonymizer struct {
ipAnonymizer map[netip.Addr]netip.Addr
domainAnonymizer map[string]string
@@ -85,39 +83,29 @@ func (a *Anonymizer) AnonymizeIPString(ip string) string {
}
func (a *Anonymizer) AnonymizeDomain(domain string) string {
baseDomain := domain
hasDot := strings.HasSuffix(domain, ".")
if hasDot {
baseDomain = domain[:len(domain)-1]
}
if strings.HasSuffix(baseDomain, "netbird.io") ||
strings.HasSuffix(baseDomain, "netbird.selfhosted") ||
strings.HasSuffix(baseDomain, "netbird.cloud") ||
strings.HasSuffix(baseDomain, "netbird.stage") ||
strings.HasSuffix(baseDomain, anonTLD) {
if strings.HasSuffix(domain, "netbird.io") ||
strings.HasSuffix(domain, "netbird.selfhosted") ||
strings.HasSuffix(domain, "netbird.cloud") ||
strings.HasSuffix(domain, "netbird.stage") ||
strings.HasSuffix(domain, ".domain") {
return domain
}
parts := strings.Split(baseDomain, ".")
parts := strings.Split(domain, ".")
if len(parts) < 2 {
return domain
}
baseForLookup := parts[len(parts)-2] + "." + parts[len(parts)-1]
baseDomain := parts[len(parts)-2] + "." + parts[len(parts)-1]
anonymized, ok := a.domainAnonymizer[baseForLookup]
anonymized, ok := a.domainAnonymizer[baseDomain]
if !ok {
anonymizedBase := "anon-" + generateRandomString(5) + anonTLD
a.domainAnonymizer[baseForLookup] = anonymizedBase
anonymizedBase := "anon-" + generateRandomString(5) + ".domain"
a.domainAnonymizer[baseDomain] = anonymizedBase
anonymized = anonymizedBase
}
result := strings.Replace(baseDomain, baseForLookup, anonymized, 1)
if hasDot {
result += "."
}
return result
return strings.Replace(domain, baseDomain, anonymized, 1)
}
func (a *Anonymizer) AnonymizeURI(uri string) string {
@@ -164,9 +152,9 @@ func (a *Anonymizer) AnonymizeString(str string) string {
return str
}
// AnonymizeSchemeURI finds and anonymizes URIs with ws, wss, rel, rels, stun, stuns, turn, and turns schemes.
// AnonymizeSchemeURI finds and anonymizes URIs with stun, stuns, turn, and turns schemes.
func (a *Anonymizer) AnonymizeSchemeURI(text string) string {
re := regexp.MustCompile(`(?i)\b(wss?://|rels?://|stuns?:|turns?:|https?://)\S+\b`)
re := regexp.MustCompile(`(?i)\b(stuns?:|turns?:|https?://)\S+\b`)
return re.ReplaceAllStringFunc(text, a.AnonymizeURI)
}
@@ -180,10 +168,10 @@ func (a *Anonymizer) AnonymizeDNSLogLine(logEntry string) string {
parts := strings.Split(match, `"`)
if len(parts) >= 2 {
domain := parts[1]
if strings.HasSuffix(domain, anonTLD) {
if strings.HasSuffix(domain, ".domain") {
return match
}
randomDomain := generateRandomString(10) + anonTLD
randomDomain := generateRandomString(10) + ".domain"
return strings.Replace(match, domain, randomDomain, 1)
}
return match
@@ -213,8 +201,6 @@ func isWellKnown(addr netip.Addr) bool {
"2606:4700:4700::1111", "2606:4700:4700::1001", // Cloudflare DNS IPv6
"9.9.9.9", "149.112.112.112", // Quad9 DNS IPv4
"2620:fe::fe", "2620:fe::9", // Quad9 DNS IPv6
"128.0.0.0", "8000::", // 2nd split subnet for default routes
}
if slices.Contains(wellKnown, addr.String()) {

View File

@@ -67,36 +67,18 @@ func TestAnonymizeDomain(t *testing.T) {
`^anon-[a-zA-Z0-9]+\.domain$`,
true,
},
{
"Domain with Trailing Dot",
"example.com.",
`^anon-[a-zA-Z0-9]+\.domain.$`,
true,
},
{
"Subdomain",
"sub.example.com",
`^sub\.anon-[a-zA-Z0-9]+\.domain$`,
true,
},
{
"Subdomain with Trailing Dot",
"sub.example.com.",
`^sub\.anon-[a-zA-Z0-9]+\.domain.$`,
true,
},
{
"Protected Domain",
"netbird.io",
`^netbird\.io$`,
false,
},
{
"Protected Domain with Trailing Dot",
"netbird.io.",
`^netbird\.io.$`,
false,
},
}
for _, tc := range tests {
@@ -158,16 +140,8 @@ func TestAnonymizeSchemeURI(t *testing.T) {
expect string
}{
{"STUN URI in text", "Connection made via stun:example.com", `Connection made via stun:anon-[a-zA-Z0-9]+\.domain`},
{"STUNS URI in message", "Secure connection to stuns:example.com:443", `Secure connection to stuns:anon-[a-zA-Z0-9]+\.domain:443`},
{"TURN URI in log", "Failed attempt turn:some.example.com:3478?transport=tcp: retrying", `Failed attempt turn:some.anon-[a-zA-Z0-9]+\.domain:3478\?transport=tcp: retrying`},
{"TURNS URI in message", "Secure connection to turns:example.com:5349", `Secure connection to turns:anon-[a-zA-Z0-9]+\.domain:5349`},
{"HTTP URI in text", "Visit http://example.com for more", `Visit http://anon-[a-zA-Z0-9]+\.domain for more`},
{"HTTPS URI in CAPS", "Visit HTTPS://example.com for more", `Visit https://anon-[a-zA-Z0-9]+\.domain for more`},
{"HTTPS URI in message", "Visit https://example.com for more", `Visit https://anon-[a-zA-Z0-9]+\.domain for more`},
{"WS URI in log", "Connection established to ws://example.com:8080", `Connection established to ws://anon-[a-zA-Z0-9]+\.domain:8080`},
{"WSS URI in message", "Secure connection to wss://example.com", `Secure connection to wss://anon-[a-zA-Z0-9]+\.domain`},
{"Rel URI in text", "Relaying to rel://example.com", `Relaying to rel://anon-[a-zA-Z0-9]+\.domain`},
{"Rels URI in message", "Relaying to rels://example.com", `Relaying to rels://anon-[a-zA-Z0-9]+\.domain`},
}
for _, tc := range tests {

View File

@@ -3,7 +3,6 @@ package cmd
import (
"context"
"fmt"
"strings"
"time"
log "github.com/sirupsen/logrus"
@@ -62,15 +61,6 @@ var forCmd = &cobra.Command{
RunE: runForDuration,
}
var persistenceCmd = &cobra.Command{
Use: "persistence [on|off]",
Short: "Set network map memory persistence",
Long: `Configure whether the latest network map should persist in memory. When enabled, the last known network map will be kept in memory.`,
Example: " netbird debug persistence on",
Args: cobra.ExactArgs(1),
RunE: setNetworkMapPersistence,
}
func debugBundle(cmd *cobra.Command, _ []string) error {
conn, err := getClient(cmd)
if err != nil {
@@ -181,13 +171,6 @@ func runForDuration(cmd *cobra.Command, args []string) error {
time.Sleep(1 * time.Second)
// Enable network map persistence before bringing the service up
if _, err := client.SetNetworkMapPersistence(cmd.Context(), &proto.SetNetworkMapPersistenceRequest{
Enabled: true,
}); err != nil {
return fmt.Errorf("failed to enable network map persistence: %v", status.Convert(err).Message())
}
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
return fmt.Errorf("failed to up: %v", status.Convert(err).Message())
}
@@ -217,13 +200,6 @@ func runForDuration(cmd *cobra.Command, args []string) error {
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
}
// Disable network map persistence after creating the debug bundle
if _, err := client.SetNetworkMapPersistence(cmd.Context(), &proto.SetNetworkMapPersistenceRequest{
Enabled: false,
}); err != nil {
return fmt.Errorf("failed to disable network map persistence: %v", status.Convert(err).Message())
}
if stateWasDown {
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
@@ -243,34 +219,6 @@ func runForDuration(cmd *cobra.Command, args []string) error {
return nil
}
func setNetworkMapPersistence(cmd *cobra.Command, args []string) error {
conn, err := getClient(cmd)
if err != nil {
return err
}
defer func() {
if err := conn.Close(); err != nil {
log.Errorf(errCloseConnection, err)
}
}()
persistence := strings.ToLower(args[0])
if persistence != "on" && persistence != "off" {
return fmt.Errorf("invalid persistence value: %s. Use 'on' or 'off'", args[0])
}
client := proto.NewDaemonServiceClient(conn)
_, err = client.SetNetworkMapPersistence(cmd.Context(), &proto.SetNetworkMapPersistenceRequest{
Enabled: persistence == "on",
})
if err != nil {
return fmt.Errorf("failed to set network map persistence: %v", status.Convert(err).Message())
}
cmd.Printf("Network map persistence set to: %s\n", persistence)
return nil
}
func getStatusOutput(cmd *cobra.Command) string {
var statusOutputString string
statusResp, err := getStatus(cmd.Context())

View File

@@ -1,33 +0,0 @@
//go:build pprof
// +build pprof
package cmd
import (
"net/http"
_ "net/http/pprof"
"os"
log "github.com/sirupsen/logrus"
)
func init() {
addr := pprofAddr()
go pprof(addr)
}
func pprofAddr() string {
listenAddr := os.Getenv("NB_PPROF_ADDR")
if listenAddr == "" {
return "localhost:6969"
}
return listenAddr
}
func pprof(listenAddr string) {
log.Infof("listening pprof on: %s\n", listenAddr)
if err := http.ListenAndServe(listenAddr, nil); err != nil {
log.Fatalf("Failed to start pprof: %v", err)
}
}

View File

@@ -155,7 +155,6 @@ func init() {
debugCmd.AddCommand(logCmd)
logCmd.AddCommand(logLevelCmd)
debugCmd.AddCommand(forCmd)
debugCmd.AddCommand(persistenceCmd)
upCmd.PersistentFlags().StringSliceVar(&natExternalIPs, externalIPMapFlag, nil,
`Sets external IPs maps between local addresses and interfaces.`+

View File

@@ -2,7 +2,6 @@ package cmd
import (
"context"
"sync"
"github.com/kardianos/service"
log "github.com/sirupsen/logrus"
@@ -14,11 +13,10 @@ import (
)
type program struct {
ctx context.Context
cancel context.CancelFunc
serv *grpc.Server
serverInstance *server.Server
serverInstanceMu sync.Mutex
ctx context.Context
cancel context.CancelFunc
serv *grpc.Server
serverInstance *server.Server
}
func newProgram(ctx context.Context, cancel context.CancelFunc) *program {

View File

@@ -61,9 +61,7 @@ func (p *program) Start(svc service.Service) error {
}
proto.RegisterDaemonServiceServer(p.serv, serverInstance)
p.serverInstanceMu.Lock()
p.serverInstance = serverInstance
p.serverInstanceMu.Unlock()
log.Printf("started daemon server: %v", split[1])
if err := p.serv.Serve(listen); err != nil {
@@ -74,7 +72,6 @@ func (p *program) Start(svc service.Service) error {
}
func (p *program) Stop(srv service.Service) error {
p.serverInstanceMu.Lock()
if p.serverInstance != nil {
in := new(proto.DownRequest)
_, err := p.serverInstance.Down(p.ctx, in)
@@ -82,7 +79,6 @@ func (p *program) Stop(srv service.Service) error {
log.Errorf("failed to stop daemon: %v", err)
}
}
p.serverInstanceMu.Unlock()
p.cancel()

View File

@@ -1,181 +0,0 @@
package cmd
import (
"fmt"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/proto"
)
var (
allFlag bool
)
var stateCmd = &cobra.Command{
Use: "state",
Short: "Manage daemon state",
Long: "Provides commands for managing and inspecting the Netbird daemon state.",
}
var stateListCmd = &cobra.Command{
Use: "list",
Aliases: []string{"ls"},
Short: "List all stored states",
Long: "Lists all registered states with their status and basic information.",
Example: " netbird state list",
RunE: stateList,
}
var stateCleanCmd = &cobra.Command{
Use: "clean [state-name]",
Short: "Clean stored states",
Long: `Clean specific state or all states. The daemon must not be running.
This will perform cleanup operations and remove the state.`,
Example: ` netbird state clean dns_state
netbird state clean --all`,
RunE: stateClean,
PreRunE: func(cmd *cobra.Command, args []string) error {
// Check mutual exclusivity between --all flag and state-name argument
if allFlag && len(args) > 0 {
return fmt.Errorf("cannot specify both --all flag and state name")
}
if !allFlag && len(args) != 1 {
return fmt.Errorf("requires a state name argument or --all flag")
}
return nil
},
}
var stateDeleteCmd = &cobra.Command{
Use: "delete [state-name]",
Short: "Delete stored states",
Long: `Delete specific state or all states from storage. The daemon must not be running.
This will remove the state without performing any cleanup operations.`,
Example: ` netbird state delete dns_state
netbird state delete --all`,
RunE: stateDelete,
PreRunE: func(cmd *cobra.Command, args []string) error {
// Check mutual exclusivity between --all flag and state-name argument
if allFlag && len(args) > 0 {
return fmt.Errorf("cannot specify both --all flag and state name")
}
if !allFlag && len(args) != 1 {
return fmt.Errorf("requires a state name argument or --all flag")
}
return nil
},
}
func init() {
rootCmd.AddCommand(stateCmd)
stateCmd.AddCommand(stateListCmd, stateCleanCmd, stateDeleteCmd)
stateCleanCmd.Flags().BoolVarP(&allFlag, "all", "a", false, "Clean all states")
stateDeleteCmd.Flags().BoolVarP(&allFlag, "all", "a", false, "Delete all states")
}
func stateList(cmd *cobra.Command, _ []string) error {
conn, err := getClient(cmd)
if err != nil {
return err
}
defer func() {
if err := conn.Close(); err != nil {
log.Errorf(errCloseConnection, err)
}
}()
client := proto.NewDaemonServiceClient(conn)
resp, err := client.ListStates(cmd.Context(), &proto.ListStatesRequest{})
if err != nil {
return fmt.Errorf("failed to list states: %v", status.Convert(err).Message())
}
cmd.Printf("\nStored states:\n\n")
for _, state := range resp.States {
cmd.Printf("- %s\n", state.Name)
}
return nil
}
func stateClean(cmd *cobra.Command, args []string) error {
var stateName string
if !allFlag {
stateName = args[0]
}
conn, err := getClient(cmd)
if err != nil {
return err
}
defer func() {
if err := conn.Close(); err != nil {
log.Errorf(errCloseConnection, err)
}
}()
client := proto.NewDaemonServiceClient(conn)
resp, err := client.CleanState(cmd.Context(), &proto.CleanStateRequest{
StateName: stateName,
All: allFlag,
})
if err != nil {
return fmt.Errorf("failed to clean state: %v", status.Convert(err).Message())
}
if resp.CleanedStates == 0 {
cmd.Println("No states were cleaned")
return nil
}
if allFlag {
cmd.Printf("Successfully cleaned %d states\n", resp.CleanedStates)
} else {
cmd.Printf("Successfully cleaned state %q\n", stateName)
}
return nil
}
func stateDelete(cmd *cobra.Command, args []string) error {
var stateName string
if !allFlag {
stateName = args[0]
}
conn, err := getClient(cmd)
if err != nil {
return err
}
defer func() {
if err := conn.Close(); err != nil {
log.Errorf(errCloseConnection, err)
}
}()
client := proto.NewDaemonServiceClient(conn)
resp, err := client.DeleteState(cmd.Context(), &proto.DeleteStateRequest{
StateName: stateName,
All: allFlag,
})
if err != nil {
return fmt.Errorf("failed to delete state: %v", status.Convert(err).Message())
}
if resp.DeletedStates == 0 {
cmd.Println("No states were deleted")
return nil
}
if allFlag {
cmd.Printf("Successfully deleted %d states\n", resp.DeletedStates)
} else {
cmd.Printf("Successfully deleted state %q\n", stateName)
}
return nil
}

View File

@@ -680,7 +680,7 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo
func skipDetailByFilters(peerState *proto.PeerState, isConnected bool) bool {
statusEval := false
ipEval := false
nameEval := true
nameEval := false
if statusFilter != "" {
lowerStatusFilter := strings.ToLower(statusFilter)
@@ -700,13 +700,11 @@ func skipDetailByFilters(peerState *proto.PeerState, isConnected bool) bool {
if len(prefixNamesFilter) > 0 {
for prefixNameFilter := range prefixNamesFilterMap {
if strings.HasPrefix(peerState.Fqdn, prefixNameFilter) {
nameEval = false
if !strings.HasPrefix(peerState.Fqdn, prefixNameFilter) {
nameEval = true
break
}
}
} else {
nameEval = false
}
return statusEval || ipEval || nameEval

View File

@@ -352,14 +352,14 @@ func (m *aclManager) seedInitialEntries() {
func (m *aclManager) seedInitialOptionalEntries() {
m.optionalEntries["FORWARD"] = []entry{
{
spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected), "-j", chainNameInputRules},
spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmark), "-j", chainNameInputRules},
position: 2,
},
}
m.optionalEntries["PREROUTING"] = []entry{
{
spec: []string{"-t", "mangle", "-i", m.wgIface.Name(), "-m", "addrtype", "--dst-type", "LOCAL", "-j", "MARK", "--set-mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected)},
spec: []string{"-t", "mangle", "-i", m.wgIface.Name(), "-m", "addrtype", "--dst-type", "LOCAL", "-j", "MARK", "--set-mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmark)},
position: 1,
},
}

View File

@@ -83,11 +83,9 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
}
// persist early to ensure cleanup of chains
go func() {
if err := stateManager.PersistState(context.Background()); err != nil {
log.Errorf("failed to persist state: %v", err)
}
}()
if err := stateManager.PersistState(context.Background()); err != nil {
log.Errorf("failed to persist state: %v", err)
}
return nil
}

View File

@@ -18,24 +18,22 @@ import (
"github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/util/net"
)
const (
ipv4Nat = "netbird-rt-nat"
)
// constants needed to manage and create iptable rules
const (
tableFilter = "filter"
tableNat = "nat"
tableMangle = "mangle"
chainPOSTROUTING = "POSTROUTING"
chainPREROUTING = "PREROUTING"
chainRTNAT = "NETBIRD-RT-NAT"
chainRTFWD = "NETBIRD-RT-FWD"
chainRTPRE = "NETBIRD-RT-PRE"
routingFinalForwardJump = "ACCEPT"
routingFinalNatJump = "MASQUERADE"
jumpPre = "jump-pre"
jumpNat = "jump-nat"
matchSet = "--match-set"
)
@@ -325,25 +323,24 @@ func (r *router) Reset() error {
}
func (r *router) cleanUpDefaultForwardRules() error {
if err := r.cleanJumpRules(); err != nil {
return fmt.Errorf("clean jump rules: %w", err)
err := r.cleanJumpRules()
if err != nil {
return err
}
log.Debug("flushing routing related tables")
for _, chainInfo := range []struct {
chain string
table string
}{
{chainRTFWD, tableFilter},
{chainRTNAT, tableNat},
{chainRTPRE, tableMangle},
} {
ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain)
for _, chain := range []string{chainRTFWD, chainRTNAT} {
table := r.getTableForChain(chain)
ok, err := r.iptablesClient.ChainExists(table, chain)
if err != nil {
return fmt.Errorf("check chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
log.Errorf("failed check chain %s, error: %v", chain, err)
return err
} else if ok {
if err = r.iptablesClient.ClearAndDeleteChain(chainInfo.table, chainInfo.chain); err != nil {
return fmt.Errorf("clear and delete chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
err = r.iptablesClient.ClearAndDeleteChain(table, chain)
if err != nil {
log.Errorf("failed cleaning chain %s, error: %v", chain, err)
return err
}
}
}
@@ -352,16 +349,9 @@ func (r *router) cleanUpDefaultForwardRules() error {
}
func (r *router) createContainers() error {
for _, chainInfo := range []struct {
chain string
table string
}{
{chainRTFWD, tableFilter},
{chainRTPRE, tableMangle},
{chainRTNAT, tableNat},
} {
if err := r.createAndSetupChain(chainInfo.chain); err != nil {
return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
for _, chain := range []string{chainRTFWD, chainRTNAT} {
if err := r.createAndSetupChain(chain); err != nil {
return fmt.Errorf("create chain %s: %w", chain, err)
}
}
@@ -369,10 +359,6 @@ func (r *router) createContainers() error {
return fmt.Errorf("insert established rule: %w", err)
}
if err := r.addPostroutingRules(); err != nil {
return fmt.Errorf("add static nat rules: %w", err)
}
if err := r.addJumpRules(); err != nil {
return fmt.Errorf("add jump rules: %w", err)
}
@@ -380,32 +366,6 @@ func (r *router) createContainers() error {
return nil
}
func (r *router) addPostroutingRules() error {
// First rule for outbound masquerade
rule1 := []string{
"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasquerade),
"!", "-o", "lo",
"-j", routingFinalNatJump,
}
if err := r.iptablesClient.Append(tableNat, chainRTNAT, rule1...); err != nil {
return fmt.Errorf("add outbound masquerade rule: %v", err)
}
r.rules["static-nat-outbound"] = rule1
// Second rule for return traffic masquerade
rule2 := []string{
"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasqueradeReturn),
"-o", r.wgIface.Name(),
"-j", routingFinalNatJump,
}
if err := r.iptablesClient.Append(tableNat, chainRTNAT, rule2...); err != nil {
return fmt.Errorf("add return masquerade rule: %v", err)
}
r.rules["static-nat-return"] = rule2
return nil
}
func (r *router) createAndSetupChain(chain string) error {
table := r.getTableForChain(chain)
@@ -417,14 +377,10 @@ func (r *router) createAndSetupChain(chain string) error {
}
func (r *router) getTableForChain(chain string) string {
switch chain {
case chainRTNAT:
if chain == chainRTNAT {
return tableNat
case chainRTPRE:
return tableMangle
default:
return tableFilter
}
return tableFilter
}
func (r *router) insertEstablishedRule(chain string) error {
@@ -442,39 +398,25 @@ func (r *router) insertEstablishedRule(chain string) error {
}
func (r *router) addJumpRules() error {
// Jump to NAT chain
natRule := []string{"-j", chainRTNAT}
if err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, natRule...); err != nil {
return fmt.Errorf("add nat jump rule: %v", err)
rule := []string{"-j", chainRTNAT}
err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, rule...)
if err != nil {
return err
}
r.rules[jumpNat] = natRule
// Jump to prerouting chain
preRule := []string{"-j", chainRTPRE}
if err := r.iptablesClient.Insert(tableMangle, chainPREROUTING, 1, preRule...); err != nil {
return fmt.Errorf("add prerouting jump rule: %v", err)
}
r.rules[jumpPre] = preRule
r.rules[ipv4Nat] = rule
return nil
}
func (r *router) cleanJumpRules() error {
for _, ruleKey := range []string{jumpNat, jumpPre} {
if rule, exists := r.rules[ruleKey]; exists {
table := tableNat
chain := chainPOSTROUTING
if ruleKey == jumpPre {
table = tableMangle
chain = chainPREROUTING
}
if err := r.iptablesClient.DeleteIfExists(table, chain, rule...); err != nil {
return fmt.Errorf("delete rule from chain %s in table %s, err: %v", chain, table, err)
}
delete(r.rules, ruleKey)
rule, found := r.rules[ipv4Nat]
if found {
err := r.iptablesClient.DeleteIfExists(tableNat, chainPOSTROUTING, rule...)
if err != nil {
return fmt.Errorf("failed cleaning rule from chain %s, err: %v", chainPOSTROUTING, err)
}
}
return nil
}
@@ -482,35 +424,19 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.NatFormat, pair)
if rule, exists := r.rules[ruleKey]; exists {
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainRTPRE, rule...); err != nil {
return fmt.Errorf("error while removing existing marking rule for %s: %v", pair.Destination, err)
if err := r.iptablesClient.DeleteIfExists(tableNat, chainRTNAT, rule...); err != nil {
return fmt.Errorf("error while removing existing NAT rule for %s: %v", pair.Destination, err)
}
delete(r.rules, ruleKey)
}
markValue := nbnet.PreroutingFwmarkMasquerade
if pair.Inverse {
markValue = nbnet.PreroutingFwmarkMasqueradeReturn
}
rule := []string{"-i", r.wgIface.Name()}
if pair.Inverse {
rule = []string{"!", "-i", r.wgIface.Name()}
}
rule = append(rule,
"-m", "conntrack",
"--ctstate", "NEW",
"-s", pair.Source.String(),
"-d", pair.Destination.String(),
"-j", "MARK", "--set-mark", fmt.Sprintf("%#x", markValue),
)
if err := r.iptablesClient.Append(tableMangle, chainRTPRE, rule...); err != nil {
return fmt.Errorf("error while adding marking rule for %s: %v", pair.Destination, err)
rule := genRuleSpec(routingFinalNatJump, pair.Source, pair.Destination, r.wgIface.Name(), pair.Inverse)
if err := r.iptablesClient.Append(tableNat, chainRTNAT, rule...); err != nil {
return fmt.Errorf("error while appending new NAT rule for %s: %v", pair.Destination, err)
}
r.rules[ruleKey] = rule
return nil
}
@@ -518,12 +444,13 @@ func (r *router) removeNatRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.NatFormat, pair)
if rule, exists := r.rules[ruleKey]; exists {
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainRTPRE, rule...); err != nil {
return fmt.Errorf("error while removing marking rule for %s: %v", pair.Destination, err)
if err := r.iptablesClient.DeleteIfExists(tableNat, chainRTNAT, rule...); err != nil {
return fmt.Errorf("error while removing existing nat rule for %s: %v", pair.Destination, err)
}
delete(r.rules, ruleKey)
} else {
log.Debugf("marking rule %s not found", ruleKey)
log.Debugf("nat rule %s not found", ruleKey)
}
return nil
@@ -555,6 +482,16 @@ func (r *router) updateState() {
}
}
func genRuleSpec(jump string, source, destination netip.Prefix, intf string, inverse bool) []string {
intdir := "-i"
lointdir := "-o"
if inverse {
intdir = "-o"
lointdir = "-i"
}
return []string{intdir, intf, "!", lointdir, "lo", "-s", source.String(), "-d", destination.String(), "-j", jump}
}
func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string {
var rule []string

View File

@@ -3,18 +3,17 @@
package iptables
import (
"fmt"
"net/netip"
"os/exec"
"testing"
"github.com/coreos/go-iptables/iptables"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/test"
nbnet "github.com/netbirdio/netbird/util/net"
)
func isIptablesSupported() bool {
@@ -35,24 +34,14 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
require.NoError(t, manager.init(nil))
defer func() {
assert.NoError(t, manager.Reset(), "shouldn't return error")
_ = manager.Reset()
}()
// Now 5 rules:
// 1. established rule in forward chain
// 2. jump rule to NAT chain
// 3. jump rule to PRE chain
// 4. static outbound masquerade rule
// 5. static return masquerade rule
require.Len(t, manager.rules, 5, "should have created rules map")
require.Len(t, manager.rules, 2, "should have created rules map")
exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, "-j", chainRTNAT)
exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, manager.rules[ipv4Nat]...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING)
require.True(t, exists, "postrouting jump rule should exist")
exists, err = manager.iptablesClient.Exists(tableMangle, chainPREROUTING, "-j", chainRTPRE)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainPREROUTING)
require.True(t, exists, "prerouting jump rule should exist")
require.True(t, exists, "postrouting rule should exist")
pair := firewall.RouterPair{
ID: "abc",
@@ -60,15 +49,22 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
Destination: netip.MustParsePrefix("100.100.100.0/24"),
Masquerade: true,
}
forward4Rule := []string{"-s", pair.Source.String(), "-d", pair.Destination.String(), "-j", routingFinalForwardJump}
err = manager.AddNatRule(pair)
require.NoError(t, err, "adding NAT rule should not return error")
err = manager.iptablesClient.Insert(tableFilter, chainRTFWD, 1, forward4Rule...)
require.NoError(t, err, "inserting rule should not return error")
nat4Rule := genRuleSpec(routingFinalNatJump, pair.Source, pair.Destination, ifaceMock.Name(), false)
err = manager.iptablesClient.Insert(tableNat, chainRTNAT, 1, nat4Rule...)
require.NoError(t, err, "inserting rule should not return error")
err = manager.Reset()
require.NoError(t, err, "shouldn't return error")
}
func TestIptablesManager_AddNatRule(t *testing.T) {
if !isIptablesSupported() {
t.SkipNow()
}
@@ -83,66 +79,52 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
require.NoError(t, manager.init(nil))
defer func() {
assert.NoError(t, manager.Reset(), "shouldn't return error")
err := manager.Reset()
if err != nil {
log.Errorf("failed to reset iptables manager: %s", err)
}
}()
err = manager.AddNatRule(testCase.InputPair)
require.NoError(t, err, "marking rule should be inserted")
require.NoError(t, err, "forwarding pair should be inserted")
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
markingRule := []string{
"-i", ifaceMock.Name(),
"-m", "conntrack",
"--ctstate", "NEW",
"-s", testCase.InputPair.Source.String(),
"-d", testCase.InputPair.Destination.String(),
"-j", "MARK", "--set-mark",
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasquerade),
}
natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination, ifaceMock.Name(), false)
exists, err := iptablesClient.Exists(tableMangle, chainRTPRE, markingRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
exists, err := iptablesClient.Exists(tableNat, chainRTNAT, natRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
if testCase.InputPair.Masquerade {
require.True(t, exists, "marking rule should be created")
foundRule, found := manager.rules[natRuleKey]
require.True(t, found, "marking rule should exist in the map")
require.Equal(t, markingRule, foundRule, "stored marking rule should match")
require.True(t, exists, "nat rule should be created")
foundNatRule, foundNat := manager.rules[natRuleKey]
require.True(t, foundNat, "nat rule should exist in the map")
require.Equal(t, natRule[:4], foundNatRule[:4], "stored nat rule should match")
} else {
require.False(t, exists, "marking rule should not be created")
_, found := manager.rules[natRuleKey]
require.False(t, found, "marking rule should not exist in the map")
require.False(t, exists, "nat rule should not be created")
_, foundNat := manager.rules[natRuleKey]
require.False(t, foundNat, "nat rule should not exist in the map")
}
// Check inverse rule
inversePair := firewall.GetInversePair(testCase.InputPair)
inverseRuleKey := firewall.GenKey(firewall.NatFormat, inversePair)
inverseMarkingRule := []string{
"!", "-i", ifaceMock.Name(),
"-m", "conntrack",
"--ctstate", "NEW",
"-s", inversePair.Source.String(),
"-d", inversePair.Destination.String(),
"-j", "MARK", "--set-mark",
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasqueradeReturn),
}
inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair))
inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInversePair(testCase.InputPair).Source, firewall.GetInversePair(testCase.InputPair).Destination, ifaceMock.Name(), true)
exists, err = iptablesClient.Exists(tableMangle, chainRTPRE, inverseMarkingRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
if testCase.InputPair.Masquerade {
require.True(t, exists, "inverse marking rule should be created")
foundRule, found := manager.rules[inverseRuleKey]
require.True(t, found, "inverse marking rule should exist in the map")
require.Equal(t, inverseMarkingRule, foundRule, "stored inverse marking rule should match")
require.True(t, exists, "income nat rule should be created")
foundNatRule, foundNat := manager.rules[inNatRuleKey]
require.True(t, foundNat, "income nat rule should exist in the map")
require.Equal(t, inNatRule[:4], foundNatRule[:4], "stored income nat rule should match")
} else {
require.False(t, exists, "inverse marking rule should not be created")
_, found := manager.rules[inverseRuleKey]
require.False(t, found, "inverse marking rule should not exist in the map")
require.False(t, exists, "nat rule should not be created")
_, foundNat := manager.rules[inNatRuleKey]
require.False(t, foundNat, "income nat rule should not exist in the map")
}
})
}
}
func TestIptablesManager_RemoveNatRule(t *testing.T) {
if !isIptablesSupported() {
t.SkipNow()
}
@@ -155,52 +137,42 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
require.NoError(t, err, "shouldn't return error")
require.NoError(t, manager.init(nil))
defer func() {
assert.NoError(t, manager.Reset(), "shouldn't return error")
_ = manager.Reset()
}()
err = manager.AddNatRule(testCase.InputPair)
require.NoError(t, err, "should add NAT rule without error")
require.NoError(t, err, "shouldn't return error")
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination, ifaceMock.Name(), false)
err = iptablesClient.Insert(tableNat, chainRTNAT, 1, natRule...)
require.NoError(t, err, "inserting rule should not return error")
inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair))
inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInversePair(testCase.InputPair).Source, firewall.GetInversePair(testCase.InputPair).Destination, ifaceMock.Name(), true)
err = iptablesClient.Insert(tableNat, chainRTNAT, 1, inNatRule...)
require.NoError(t, err, "inserting rule should not return error")
err = manager.Reset()
require.NoError(t, err, "shouldn't return error")
err = manager.RemoveNatRule(testCase.InputPair)
require.NoError(t, err, "shouldn't return error")
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
markingRule := []string{
"-i", ifaceMock.Name(),
"-m", "conntrack",
"--ctstate", "NEW",
"-s", testCase.InputPair.Source.String(),
"-d", testCase.InputPair.Destination.String(),
"-j", "MARK", "--set-mark",
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasquerade),
}
exists, err := iptablesClient.Exists(tableMangle, chainRTPRE, markingRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
require.False(t, exists, "marking rule should not exist")
exists, err := iptablesClient.Exists(tableNat, chainRTNAT, natRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
require.False(t, exists, "nat rule should not exist")
_, found := manager.rules[natRuleKey]
require.False(t, found, "marking rule should not exist in the manager map")
require.False(t, found, "nat rule should exist in the manager map")
// Check inverse rule removal
inversePair := firewall.GetInversePair(testCase.InputPair)
inverseRuleKey := firewall.GenKey(firewall.NatFormat, inversePair)
inverseMarkingRule := []string{
"!", "-i", ifaceMock.Name(),
"-m", "conntrack",
"--ctstate", "NEW",
"-s", inversePair.Source.String(),
"-d", inversePair.Destination.String(),
"-j", "MARK", "--set-mark",
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasqueradeReturn),
}
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
require.False(t, exists, "income nat rule should not exist")
exists, err = iptablesClient.Exists(tableMangle, chainRTPRE, inverseMarkingRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
require.False(t, exists, "inverse marking rule should not exist")
_, found = manager.rules[inverseRuleKey]
require.False(t, found, "inverse marking rule should not exist in the map")
_, found = manager.rules[inNatRuleKey]
require.False(t, found, "income nat rule should exist in the manager map")
})
}
}

View File

@@ -37,11 +37,6 @@ func (s *ipList) UnmarshalJSON(data []byte) error {
return err
}
s.ips = temp.IPs
if temp.IPs == nil {
temp.IPs = make(map[string]struct{})
}
return nil
}
@@ -94,10 +89,5 @@ func (s *ipsetStore) UnmarshalJSON(data []byte) error {
return err
}
s.ipsets = temp.IPSets
if temp.IPSets == nil {
temp.IPSets = make(map[string]*ipList)
}
return nil
}

View File

@@ -17,7 +17,6 @@ import (
const (
ForwardingFormatPrefix = "netbird-fwd-"
ForwardingFormat = "netbird-fwd-%s-%t"
PreroutingFormat = "netbird-prerouting-%s-%t"
NatFormat = "netbird-nat-%s-%t"
)

View File

@@ -520,7 +520,7 @@ func (m *AclManager) addPreroutingRule(preroutingChain *nftables.Chain) {
},
&expr.Immediate{
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmark),
},
&expr.Meta{
Key: expr.MetaKeyMARK,
@@ -543,7 +543,7 @@ func (m *AclManager) addFwmarkToForward(chainFwFilter *nftables.Chain) {
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmark),
},
&expr.Verdict{
Kind: expr.VerdictJump,

View File

@@ -99,11 +99,9 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
}
// persist early
go func() {
if err := stateManager.PersistState(context.Background()); err != nil {
log.Errorf("failed to persist state: %v", err)
}
}()
if err := stateManager.PersistState(context.Background()); err != nil {
log.Errorf("failed to persist state: %v", err)
}
return nil
}
@@ -199,7 +197,7 @@ func (m *Manager) AllowNetbird() error {
var chain *nftables.Chain
for _, c := range chains {
if c.Table.Name == tableNameFilter && c.Name == chainNameInput {
if c.Table.Name == tableNameFilter && c.Name == chainNameForward {
chain = c
break
}
@@ -276,7 +274,7 @@ func (m *Manager) resetNetbirdInputRules() error {
func (m *Manager) deleteNetbirdInputRules(chains []*nftables.Chain) {
for _, c := range chains {
if c.Table.Name == tableNameFilter && c.Name == chainNameInput {
if c.Table.Name == "filter" && c.Name == "INPUT" {
rules, err := m.rConn.GetRules(c.Table, c)
if err != nil {
log.Errorf("get rules for chain %q: %v", c.Name, err)
@@ -351,9 +349,7 @@ func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) {
Register: 1,
Data: ifname(m.wgIface.Name()),
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
&expr.Verdict{},
},
UserData: []byte(allowNetbirdInputRuleID),
}

View File

@@ -1,11 +1,9 @@
package nftables
import (
"bytes"
"fmt"
"net"
"net/netip"
"os/exec"
"testing"
"time"
@@ -227,105 +225,3 @@ func TestNFtablesCreatePerformance(t *testing.T) {
})
}
}
func runIptablesSave(t *testing.T) (string, string) {
t.Helper()
var stdout, stderr bytes.Buffer
cmd := exec.Command("iptables-save")
cmd.Stdout = &stdout
cmd.Stderr = &stderr
err := cmd.Run()
require.NoError(t, err, "iptables-save failed to run")
return stdout.String(), stderr.String()
}
func verifyIptablesOutput(t *testing.T, stdout, stderr string) {
t.Helper()
// Check for any incompatibility warnings
require.NotContains(t,
stderr,
"incompatible",
"iptables-save produced compatibility warning. Full stderr: %s",
stderr,
)
// Verify standard tables are present
expectedTables := []string{
"*filter",
"*nat",
"*mangle",
}
for _, table := range expectedTables {
require.Contains(t,
stdout,
table,
"iptables-save output missing expected table: %s\nFull stdout: %s",
table,
stdout,
)
}
}
func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
}
if _, err := exec.LookPath("iptables-save"); err != nil {
t.Skipf("iptables-save not available on this system: %v", err)
}
// First ensure iptables-nft tables exist by running iptables-save
stdout, stderr := runIptablesSave(t)
verifyIptablesOutput(t, stdout, stderr)
manager, err := Create(ifaceMock)
require.NoError(t, err, "failed to create manager")
require.NoError(t, manager.Init(nil))
t.Cleanup(func() {
err := manager.Reset(nil)
require.NoError(t, err, "failed to reset manager state")
// Verify iptables output after reset
stdout, stderr := runIptablesSave(t)
verifyIptablesOutput(t, stdout, stderr)
})
ip := net.ParseIP("100.96.0.1")
_, err = manager.AddPeerFiltering(
ip,
fw.ProtocolTCP,
nil,
&fw.Port{Values: []int{80}},
fw.RuleDirectionIN,
fw.ActionAccept,
"",
"test rule",
)
require.NoError(t, err, "failed to add peer filtering rule")
_, err = manager.AddRouteFiltering(
[]netip.Prefix{netip.MustParsePrefix("192.168.2.0/24")},
netip.MustParsePrefix("10.1.0.0/24"),
fw.ProtocolTCP,
nil,
&fw.Port{Values: []int{443}},
fw.ActionAccept,
)
require.NoError(t, err, "failed to add route filtering rule")
pair := fw.RouterPair{
Source: netip.MustParsePrefix("192.168.1.0/24"),
Destination: netip.MustParsePrefix("10.0.0.0/24"),
Masquerade: true,
}
err = manager.AddNatRule(pair)
require.NoError(t, err, "failed to add NAT rule")
stdout, stderr = runIptablesSave(t)
verifyIptablesOutput(t, stdout, stderr)
}

View File

@@ -21,7 +21,6 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
nbnet "github.com/netbirdio/netbird/util/net"
)
const (
@@ -125,6 +124,7 @@ func (r *router) createContainers() error {
insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw])
prio := *nftables.ChainPriorityNATSource - 1
r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{
Name: chainNameRoutingNat,
Table: r.workTable,
@@ -133,21 +133,6 @@ func (r *router) createContainers() error {
Type: nftables.ChainTypeNAT,
})
// Chain is created by acl manager
// TODO: move creation to a common place
r.chains[chainNamePrerouting] = &nftables.Chain{
Name: chainNamePrerouting,
Table: r.workTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityMangle,
}
// Add the single NAT rule that matches on mark
if err := r.addPostroutingRules(); err != nil {
return fmt.Errorf("add single nat rule: %v", err)
}
if err := r.acceptForwardRules(); err != nil {
log.Errorf("failed to add accept rules for the forward chain: %s", err)
}
@@ -437,149 +422,59 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
op := expr.CmpOpEq
dir := expr.MetaKeyIIFNAME
notDir := expr.MetaKeyOIFNAME
if pair.Inverse {
op = expr.CmpOpNeq
dir = expr.MetaKeyOIFNAME
notDir = expr.MetaKeyIIFNAME
}
lo := ifname("lo")
intf := ifname(r.wgIface.Name())
exprs := []expr.Any{
// We only care about NEW connections to mark them and later identify them in the postrouting chain for masquerading.
// Masquerading will take care of the conntrack state, which means we won't need to mark established connections.
&expr.Ct{
Key: expr.CtKeySTATE,
&expr.Meta{
Key: dir,
Register: 1,
},
&expr.Bitwise{
SourceRegister: 1,
DestRegister: 1,
Len: 4,
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitNEW),
Xor: binaryutil.NativeEndian.PutUint32(0),
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: intf,
},
// We need to exclude the loopback interface as this changes the ebpf proxy port
&expr.Meta{
Key: notDir,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: []byte{0, 0, 0, 0},
},
// interface matching
&expr.Meta{
Key: expr.MetaKeyIIFNAME,
Register: 1,
},
&expr.Cmp{
Op: op,
Register: 1,
Data: ifname(r.wgIface.Name()),
Data: lo,
},
}
exprs = append(exprs, sourceExp...)
exprs = append(exprs, destExp...)
var markValue uint32 = nbnet.PreroutingFwmarkMasquerade
if pair.Inverse {
markValue = nbnet.PreroutingFwmarkMasqueradeReturn
}
exprs = append(exprs,
&expr.Immediate{
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(markValue),
},
&expr.Meta{
Key: expr.MetaKeyMARK,
SourceRegister: true,
Register: 1,
},
&expr.Counter{}, &expr.Masq{},
)
ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
ruleKey := firewall.GenKey(firewall.NatFormat, pair)
if _, exists := r.rules[ruleKey]; exists {
if err := r.removeNatRule(pair); err != nil {
return fmt.Errorf("remove prerouting rule: %w", err)
return fmt.Errorf("remove routing rule: %w", err)
}
}
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNamePrerouting],
Chain: r.chains[chainNameRoutingNat],
Exprs: exprs,
UserData: []byte(ruleKey),
})
return nil
}
// addPostroutingRules adds the masquerade rules
func (r *router) addPostroutingRules() error {
// First masquerade rule for traffic coming in from WireGuard interface
exprs := []expr.Any{
// Match on the first fwmark
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkMasquerade),
},
// We need to exclude the loopback interface as this changes the ebpf proxy port
&expr.Meta{
Key: expr.MetaKeyOIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: ifname("lo"),
},
&expr.Counter{},
&expr.Masq{},
}
r.conn.AddRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingNat],
Exprs: exprs,
})
// Second masquerade rule for traffic going out through WireGuard interface
exprs2 := []expr.Any{
// Match on the second fwmark
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkMasqueradeReturn),
},
// Match WireGuard interface
&expr.Meta{
Key: expr.MetaKeyOIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
&expr.Counter{},
&expr.Masq{},
}
r.conn.AddRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingNat],
Exprs: exprs2,
})
return nil
}
@@ -828,18 +723,18 @@ func (r *router) removeAcceptForwardRulesIptables(ipt *iptables.IPTables) error
return nberrors.FormatErrorOrNil(merr)
}
// RemoveNatRule removes the prerouting mark rule
// RemoveNatRule removes a nftables rule pair from nat chains
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
if err := r.removeNatRule(pair); err != nil {
return fmt.Errorf("remove prerouting rule: %w", err)
return fmt.Errorf("remove nat rule: %w", err)
}
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
return fmt.Errorf("remove inverse prerouting rule: %w", err)
return fmt.Errorf("remove inverse nat rule: %w", err)
}
if err := r.removeLegacyRouteRule(pair); err != nil {
@@ -854,20 +749,21 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
return nil
}
// removeNatRule adds a nftables rule to the removal queue and deletes it from the rules map
func (r *router) removeNatRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
ruleKey := firewall.GenKey(firewall.NatFormat, pair)
if rule, exists := r.rules[ruleKey]; exists {
err := r.conn.DelRule(rule)
if err != nil {
return fmt.Errorf("remove prerouting rule %s -> %s: %v", pair.Source, pair.Destination, err)
return fmt.Errorf("remove nat rule %s -> %s: %v", pair.Source, pair.Destination, err)
}
log.Debugf("nftables: removed prerouting rule %s -> %s", pair.Source, pair.Destination)
log.Debugf("nftables: removed nat rule %s -> %s", pair.Source, pair.Destination)
delete(r.rules, ruleKey)
} else {
log.Debugf("nftables: prerouting rule %s not found", ruleKey)
log.Debugf("nftables: nat rule %s not found", ruleKey)
}
return nil

View File

@@ -10,7 +10,6 @@ import (
"github.com/coreos/go-iptables/iptables"
"github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -33,87 +32,100 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
t.Skip("nftables not supported on this OS")
}
table, err := createWorkTable()
require.NoError(t, err, "Failed to create work table")
defer deleteWorkTable()
for _, testCase := range test.InsertRuleTestCases {
t.Run(testCase.Name, func(t *testing.T) {
// need fw manager to init both acl mgr and router for all chains to be present
manager, err := Create(ifaceMock)
t.Cleanup(func() {
require.NoError(t, manager.Reset(nil))
})
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
manager, err := newRouter(table, ifaceMock)
require.NoError(t, err, "failed to create router")
require.NoError(t, manager.init(table))
nftablesTestingClient := &nftables.Conn{}
rtr := manager.router
err = rtr.AddNatRule(testCase.InputPair)
defer func(manager *router) {
require.NoError(t, manager.Reset(), "failed to reset rules")
}(manager)
require.NoError(t, err, "shouldn't return error")
err = manager.AddNatRule(testCase.InputPair)
require.NoError(t, err, "pair should be inserted")
t.Cleanup(func() {
require.NoError(t, rtr.RemoveNatRule(testCase.InputPair), "failed to remove rule")
})
defer func(manager *router, pair firewall.RouterPair) {
require.NoError(t, manager.RemoveNatRule(pair), "failed to remove rule")
}(manager, testCase.InputPair)
if testCase.InputPair.Masquerade {
// Build expected expressions for connection tracking
conntrackExprs := []expr.Any{
&expr.Ct{
Key: expr.CtKeySTATE,
Register: 1,
},
&expr.Bitwise{
SourceRegister: 1,
DestRegister: 1,
Len: 4,
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitNEW),
Xor: binaryutil.NativeEndian.PutUint32(0),
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: []byte{0, 0, 0, 0},
},
}
// Build interface matching expression
ifaceExprs := []expr.Any{
&expr.Meta{
Key: expr.MetaKeyIIFNAME,
Register: 1,
},
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
testingExpression := append(sourceExp, destExp...) //nolint:gocritic
testingExpression = append(testingExpression,
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(ifaceMock.Name()),
},
}
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: ifname("lo"),
},
)
// Build CIDR matching expressions
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
// Combine all expressions in the correct order
// nolint:gocritic
testingExpression := append(conntrackExprs, ifaceExprs...)
testingExpression = append(testingExpression, sourceExp...)
testingExpression = append(testingExpression, destExp...)
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair)
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
found := 0
for _, chain := range rtr.chains {
if chain.Name == chainNamePrerouting {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
// Compare expressions up to the mark setting expressions
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "prerouting nat rule elements should match")
found = 1
}
for _, chain := range manager.chains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "nat rule elements should match")
found = 1
}
}
}
require.Equal(t, 1, found, "should find at least 1 rule in prerouting chain")
require.Equal(t, 1, found, "should find at least 1 rule to test")
}
if testCase.InputPair.Masquerade {
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
testingExpression := append(sourceExp, destExp...) //nolint:gocritic
testingExpression = append(testingExpression,
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(ifaceMock.Name()),
},
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: ifname("lo"),
},
)
inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair))
found := 0
for _, chain := range manager.chains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == inNatRuleKey {
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "income nat rule elements should match")
found = 1
}
}
}
require.Equal(t, 1, found, "should find at least 1 rule to test")
}
})
}
}
@@ -123,66 +135,68 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
t.Skip("nftables not supported on this OS")
}
table, err := createWorkTable()
require.NoError(t, err, "Failed to create work table")
defer deleteWorkTable()
for _, testCase := range test.RemoveRuleTestCases {
t.Run(testCase.Name, func(t *testing.T) {
manager, err := Create(ifaceMock)
t.Cleanup(func() {
require.NoError(t, manager.Reset(nil))
manager, err := newRouter(table, ifaceMock)
require.NoError(t, err, "failed to create router")
require.NoError(t, manager.init(table))
nftablesTestingClient := &nftables.Conn{}
defer func(manager *router) {
require.NoError(t, manager.Reset(), "failed to reset rules")
}(manager)
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
natExp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
insertedNat := nftablesTestingClient.InsertRule(&nftables.Rule{
Table: manager.workTable,
Chain: manager.chains[chainNameRoutingNat],
Exprs: natExp,
UserData: []byte(natRuleKey),
})
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
rtr := manager.router
sourceExp = generateCIDRMatcherExpressions(true, firewall.GetInversePair(testCase.InputPair).Source)
destExp = generateCIDRMatcherExpressions(false, firewall.GetInversePair(testCase.InputPair).Destination)
// First add the NAT rule using the router's method
err = rtr.AddNatRule(testCase.InputPair)
require.NoError(t, err, "should add NAT rule")
natExp = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic
inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair))
// Verify the rule was added
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair)
found := false
rules, err := rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNamePrerouting])
require.NoError(t, err, "should list rules")
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
found = true
break
}
}
require.True(t, found, "NAT rule should exist before removal")
insertedInNat := nftablesTestingClient.InsertRule(&nftables.Rule{
Table: manager.workTable,
Chain: manager.chains[chainNameRoutingNat],
Exprs: natExp,
UserData: []byte(inNatRuleKey),
})
// Now remove the rule
err = rtr.RemoveNatRule(testCase.InputPair)
require.NoError(t, err, "shouldn't return error when removing rule")
err = nftablesTestingClient.Flush()
require.NoError(t, err, "shouldn't return error")
// Verify the rule was removed
found = false
rules, err = rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNamePrerouting])
require.NoError(t, err, "should list rules after removal")
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
found = true
break
}
}
require.False(t, found, "NAT rule should not exist after removal")
err = manager.Reset()
require.NoError(t, err, "shouldn't return error")
// Verify the static postrouting rules still exist
rules, err = rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameRoutingNat])
require.NoError(t, err, "should list postrouting rules")
foundCounter := false
for _, rule := range rules {
for _, e := range rule.Exprs {
if _, ok := e.(*expr.Counter); ok {
foundCounter = true
break
err = manager.RemoveNatRule(testCase.InputPair)
require.NoError(t, err, "shouldn't return error")
for _, chain := range manager.chains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 {
require.NotEqual(t, insertedNat.UserData, rule.UserData, "nat rule should not exist")
require.NotEqual(t, insertedInNat.UserData, rule.UserData, "income nat rule should not exist")
}
}
if foundCounter {
break
}
}
require.True(t, foundCounter, "static postrouting rule should remain")
})
}
}

View File

@@ -0,0 +1 @@
package nftables

View File

@@ -239,7 +239,7 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
// SetLegacyManagement doesn't need to be implemented for this manager
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
if m.nativeFirewall == nil {
return nil
return errRouteNotSupported
}
return m.nativeFirewall.SetLegacyManagement(isLegacy)
}

View File

@@ -1,12 +0,0 @@
package bind
import (
wireguard "golang.zx2c4.com/wireguard/conn"
nbnet "github.com/netbirdio/netbird/util/net"
)
func init() {
// ControlFns is not thread safe and should only be modified during init.
*wireguard.ControlFns = append(*wireguard.ControlFns, nbnet.ControlProtectSocket)
}

View File

@@ -12,7 +12,6 @@ import (
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
wgConn "golang.zx2c4.com/wireguard/conn"
)
@@ -25,8 +24,8 @@ type receiverCreator struct {
iceBind *ICEBind
}
func (rc receiverCreator) CreateIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool, msgPool *sync.Pool) wgConn.ReceiveFunc {
return rc.iceBind.createIPv4ReceiverFn(pc, conn, rxOffload, msgPool)
func (rc receiverCreator) CreateIPv4ReceiverFn(msgPool *sync.Pool, pc *ipv4.PacketConn, conn *net.UDPConn) wgConn.ReceiveFunc {
return rc.iceBind.createIPv4ReceiverFn(msgPool, pc, conn)
}
// ICEBind is a bind implementation with two main features:
@@ -155,7 +154,7 @@ func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
return nil
}
func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool, msgsPool *sync.Pool) wgConn.ReceiveFunc {
func (s *ICEBind) createIPv4ReceiverFn(ipv4MsgsPool *sync.Pool, pc *ipv4.PacketConn, conn *net.UDPConn) wgConn.ReceiveFunc {
s.muUDPMux.Lock()
defer s.muUDPMux.Unlock()
@@ -167,30 +166,16 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
},
)
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
msgs := getMessages(msgsPool)
msgs := ipv4MsgsPool.Get().(*[]ipv4.Message)
defer ipv4MsgsPool.Put(msgs)
for i := range bufs {
(*msgs)[i].Buffers[0] = bufs[i]
(*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)]
}
defer putMessages(msgs, msgsPool)
var numMsgs int
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
if rxOffload {
readAt := len(*msgs) - (wgConn.IdealBatchSize / wgConn.UdpSegmentMaxDatagrams)
//nolint
numMsgs, err = pc.ReadBatch((*msgs)[readAt:], 0)
if err != nil {
return 0, err
}
numMsgs, err = wgConn.SplitCoalescedMessages(*msgs, readAt, wgConn.GetGSOSize)
if err != nil {
return 0, err
}
} else {
numMsgs, err = pc.ReadBatch(*msgs, 0)
if err != nil {
return 0, err
}
if runtime.GOOS == "linux" {
numMsgs, err = pc.ReadBatch(*msgs, 0)
if err != nil {
return 0, err
}
} else {
msg := &(*msgs)[0]
@@ -206,12 +191,11 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
// todo: handle err
ok, _ := s.filterOutStunMessages(msg.Buffers, msg.N, msg.Addr)
if ok {
continue
}
sizes[i] = msg.N
if sizes[i] == 0 {
continue
sizes[i] = 0
} else {
sizes[i] = msg.N
}
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
ep := &wgConn.StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
wgConn.GetSrcFromControl(msg.OOB[:msg.NN], ep)
@@ -289,15 +273,3 @@ func fakeAddress(peerAddress *net.UDPAddr) (*net.UDPAddr, error) {
}
return newAddr, nil
}
func getMessages(msgsPool *sync.Pool) *[]ipv6.Message {
return msgsPool.Get().(*[]ipv6.Message)
}
func putMessages(msgs *[]ipv6.Message, msgsPool *sync.Pool) {
for i := range *msgs {
(*msgs)[i].OOB = (*msgs)[i].OOB[:0]
(*msgs)[i] = ipv6.Message{Buffers: (*msgs)[i].Buffers, OOB: (*msgs)[i].OOB}
}
msgsPool.Put(msgs)
}

View File

@@ -162,13 +162,12 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
params.Logger.Warn("UDPMuxDefault should not listening on unspecified address, use NewMultiUDPMuxFromPort instead")
var networks []ice.NetworkType
switch {
case addr.IP.To4() != nil:
networks = []ice.NetworkType{ice.NetworkTypeUDP4}
case addr.IP.To16() != nil:
networks = []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6}
case addr.IP.To4() != nil:
networks = []ice.NetworkType{ice.NetworkTypeUDP4}
default:
params.Logger.Errorf("LocalAddr expected IPV4 or IPV6, got %T", params.UDPConn.LocalAddr())
}

View File

@@ -2,7 +2,6 @@ package bind
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
@@ -95,10 +94,7 @@ func (p *ProxyBind) close() error {
p.Bind.RemoveEndpoint(p.wgAddr)
if rErr := p.remoteConn.Close(); rErr != nil && !errors.Is(rErr, net.ErrClosed) {
return rErr
}
return nil
return p.remoteConn.Close()
}
func (p *ProxyBind) proxyToLocal(ctx context.Context) {
@@ -108,8 +104,8 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) {
}
}()
buf := make([]byte, 1500)
for {
buf := make([]byte, 1500)
n, err := p.remoteConn.Read(buf)
if err != nil {
if ctx.Err() != nil {

View File

@@ -77,7 +77,7 @@ func (e *ProxyWrapper) CloseConn() error {
e.cancel()
if err := e.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
if err := e.remoteConn.Close(); err != nil {
return fmt.Errorf("failed to close remote conn: %w", err)
}
return nil

View File

@@ -116,7 +116,7 @@ func (p *WGUDPProxy) close() error {
p.cancel()
var result *multierror.Error
if err := p.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
if err := p.remoteConn.Close(); err != nil {
result = multierror.Append(result, fmt.Errorf("remote conn: %s", err))
}

View File

@@ -164,7 +164,7 @@ func UpdateOrCreateConfig(input ConfigInput) (*Config, error) {
if err != nil {
return nil, err
}
err = util.WriteJsonWithRestrictedPermission(context.Background(), input.ConfigPath, cfg)
err = util.WriteJsonWithRestrictedPermission(input.ConfigPath, cfg)
return cfg, err
}
@@ -185,7 +185,7 @@ func CreateInMemoryConfig(input ConfigInput) (*Config, error) {
// WriteOutConfig write put the prepared config to the given path
func WriteOutConfig(path string, config *Config) error {
return util.WriteJson(context.Background(), path, config)
return util.WriteJson(path, config)
}
// createNewConfig creates a new config generating a new Wireguard key and saving to file
@@ -215,7 +215,7 @@ func update(input ConfigInput) (*Config, error) {
}
if updated {
if err := util.WriteJson(context.Background(), input.ConfigPath, config); err != nil {
if err := util.WriteJson(input.ConfigPath, config); err != nil {
return nil, err
}
}

View File

@@ -40,8 +40,6 @@ type ConnectClient struct {
statusRecorder *peer.Status
engine *Engine
engineMutex sync.Mutex
persistNetworkMap bool
}
func NewConnectClient(
@@ -159,8 +157,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
engineCtx, cancel := context.WithCancel(c.ctx)
defer func() {
_, err := state.Status()
c.statusRecorder.MarkManagementDisconnected(err)
c.statusRecorder.MarkManagementDisconnected(state.err)
c.statusRecorder.CleanLocalPeerState()
cancel()
}()
@@ -210,8 +207,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
c.statusRecorder.MarkSignalDisconnected(nil)
defer func() {
_, err := state.Status()
c.statusRecorder.MarkSignalDisconnected(err)
c.statusRecorder.MarkSignalDisconnected(state.err)
}()
// with the global Wiretrustee config in hand connect (just a connection, no stream yet) Signal
@@ -234,7 +230,6 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
relayURLs, token := parseRelayInfo(loginResp)
relayManager := relayClient.NewManager(engineCtx, relayURLs, myPrivateKey.PublicKey().String())
c.statusRecorder.SetRelayMgr(relayManager)
if len(relayURLs) > 0 {
if token != nil {
if err := relayManager.UpdateToken(token); err != nil {
@@ -245,7 +240,9 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
log.Infof("connecting to the Relay service(s): %s", strings.Join(relayURLs, ", "))
if err = relayManager.Serve(); err != nil {
log.Error(err)
return wrapErr(err)
}
c.statusRecorder.SetRelayMgr(relayManager)
}
peerConfig := loginResp.GetPeerConfig()
@@ -260,7 +257,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
c.engineMutex.Lock()
c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, probes, checks)
c.engine.SetNetworkMapPersistence(c.persistNetworkMap)
c.engineMutex.Unlock()
if err := c.engine.Start(); err != nil {
@@ -338,19 +335,6 @@ func (c *ConnectClient) Engine() *Engine {
return e
}
// Status returns the current client status
func (c *ConnectClient) Status() StatusType {
if c == nil {
return StatusIdle
}
status, err := CtxGetState(c.ctx).Status()
if err != nil {
return StatusIdle
}
return status
}
func (c *ConnectClient) Stop() error {
if c == nil {
return nil
@@ -377,22 +361,6 @@ func (c *ConnectClient) isContextCancelled() bool {
}
}
// SetNetworkMapPersistence enables or disables network map persistence.
// When enabled, the last received network map will be stored and can be retrieved
// through the Engine's getLatestNetworkMap method. When disabled, any stored
// network map will be cleared. This functionality is primarily used for debugging
// and should not be enabled during normal operation.
func (c *ConnectClient) SetNetworkMapPersistence(enabled bool) {
c.engineMutex.Lock()
c.persistNetworkMap = enabled
c.engineMutex.Unlock()
engine := c.Engine()
if engine != nil {
engine.SetNetworkMapPersistence(enabled)
}
}
// createEngineConfig converts configuration received from Management Service to EngineConfig
func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig) (*EngineConfig, error) {
nm := false

View File

@@ -7,6 +7,7 @@ import (
"runtime"
"strings"
"sync"
"time"
"github.com/miekg/dns"
"github.com/mitchellh/hashstructure/v2"
@@ -322,12 +323,12 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
log.Error(err)
}
go func() {
// persist dns state right away
if err := s.stateManager.PersistState(s.ctx); err != nil {
log.Errorf("Failed to persist dns state: %v", err)
}
}()
// persist dns state right away
ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second)
defer cancel()
if err := s.stateManager.PersistState(ctx); err != nil {
log.Errorf("Failed to persist dns state: %v", err)
}
if s.searchDomainNotifier != nil {
s.searchDomainNotifier.onNewSearchDomains(s.SearchDomains())
@@ -532,11 +533,12 @@ func (s *DefaultServer) upstreamCallbacks(
l.Errorf("Failed to apply nameserver deactivation on the host: %v", err)
}
go func() {
if err := s.stateManager.PersistState(s.ctx); err != nil {
l.Errorf("Failed to persist dns state: %v", err)
}
}()
// persist dns state right away
ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second)
defer cancel()
if err := s.stateManager.PersistState(ctx); err != nil {
l.Errorf("Failed to persist dns state: %v", err)
}
if runtime.GOOS == "android" && nsGroup.Primary && len(s.hostsDNSHolder.get()) > 0 {
s.addHostRootZone()

View File

@@ -782,7 +782,7 @@ func TestDNSPermanent_matchOnly(t *testing.T) {
Port: 53,
},
},
Domains: []string{"google.com"},
Domains: []string{"customdomain.com"},
Primary: false,
},
},
@@ -804,7 +804,7 @@ func TestDNSPermanent_matchOnly(t *testing.T) {
if ips[0] != zoneRecords[0].RData {
t.Fatalf("invalid zone record: %v", err)
}
_, err = resolver.LookupHost(context.Background(), "google.com")
_, err = resolver.LookupHost(context.Background(), "customdomain.com")
if err != nil {
t.Errorf("failed to resolve: %s", err)
}

View File

@@ -11,7 +11,6 @@ import (
"reflect"
"runtime"
"slices"
"sort"
"strings"
"sync"
"sync/atomic"
@@ -21,7 +20,6 @@ import (
"github.com/pion/stun/v2"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/protobuf/proto"
"github.com/netbirdio/netbird/client/firewall"
"github.com/netbirdio/netbird/client/firewall/manager"
@@ -40,6 +38,7 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
nbdns "github.com/netbirdio/netbird/dns"
@@ -172,11 +171,7 @@ type Engine struct {
relayManager *relayClient.Manager
stateManager *statemanager.Manager
srWatcher *guard.SRWatcher
// Network map persistence
persistNetworkMap bool
latestNetworkMap *mgmProto.NetworkMap
srWatcher *guard.SRWatcher
}
// Peer is an instance of the Connection Peer
@@ -302,7 +297,7 @@ func (e *Engine) Stop() error {
if err := e.stateManager.Stop(ctx); err != nil {
return fmt.Errorf("failed to stop state manager: %w", err)
}
if err := e.stateManager.PersistState(context.Background()); err != nil {
if err := e.stateManager.PersistState(ctx); err != nil {
log.Errorf("failed to persist state: %v", err)
}
@@ -354,17 +349,8 @@ func (e *Engine) Start() error {
}
e.dnsServer = dnsServer
e.routeManager = routemanager.NewManager(
e.ctx,
e.config.WgPrivateKey.PublicKey().String(),
e.config.DNSRouteInterval,
e.wgInterface,
e.statusRecorder,
e.relayManager,
initialRoutes,
e.stateManager,
)
beforePeerHook, afterPeerHook, err := e.routeManager.Init()
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.config.DNSRouteInterval, e.wgInterface, e.statusRecorder, e.relayManager, initialRoutes)
beforePeerHook, afterPeerHook, err := e.routeManager.Init(e.stateManager)
if err != nil {
log.Errorf("Failed to initialize route manager: %s", err)
} else {
@@ -552,7 +538,6 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
relayMsg := wCfg.GetRelay()
if relayMsg != nil {
// when we receive token we expect valid address list too
c := &auth.Token{
Payload: relayMsg.GetTokenPayload(),
Signature: relayMsg.GetTokenSignature(),
@@ -561,16 +546,9 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
log.Errorf("failed to update relay token: %v", err)
return fmt.Errorf("update relay token: %w", err)
}
e.relayManager.UpdateServerURLs(relayMsg.Urls)
// Just in case the agent started with an MGM server where the relay was disabled but was later enabled.
// We can ignore all errors because the guard will manage the reconnection retries.
_ = e.relayManager.Serve()
} else {
e.relayManager.UpdateServerURLs(nil)
}
// todo update relay address in the relay manager
// todo update signal
}
@@ -578,22 +556,13 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
return err
}
nm := update.GetNetworkMap()
if nm == nil {
return nil
if update.GetNetworkMap() != nil {
// only apply new changes and ignore old ones
err := e.updateNetworkMap(update.GetNetworkMap())
if err != nil {
return err
}
}
// Store network map if persistence is enabled
if e.persistNetworkMap {
e.latestNetworkMap = nm
log.Debugf("network map persisted with serial %d", nm.GetSerial())
}
// only apply new changes and ignore old ones
if err := e.updateNetworkMap(nm); err != nil {
return err
}
return nil
}
@@ -672,10 +641,6 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
}
func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
if e.wgInterface == nil {
return errors.New("wireguard interface is not initialized")
}
if e.wgInterface.Address().String() != conf.Address {
oldAddr := e.wgInterface.Address().String()
log.Debugf("updating peer address from %s to %s", oldAddr, conf.Address)
@@ -1514,59 +1479,8 @@ func (e *Engine) stopDNSServer() {
e.statusRecorder.UpdateDNSStates(nsGroupStates)
}
// SetNetworkMapPersistence enables or disables network map persistence
func (e *Engine) SetNetworkMapPersistence(enabled bool) {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
if enabled == e.persistNetworkMap {
return
}
e.persistNetworkMap = enabled
log.Debugf("Network map persistence is set to %t", enabled)
if !enabled {
e.latestNetworkMap = nil
}
}
// GetLatestNetworkMap returns the stored network map if persistence is enabled
func (e *Engine) GetLatestNetworkMap() (*mgmProto.NetworkMap, error) {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
if !e.persistNetworkMap {
return nil, errors.New("network map persistence is disabled")
}
if e.latestNetworkMap == nil {
//nolint:nilnil
return nil, nil
}
// Create a deep copy to avoid external modifications
nm, ok := proto.Clone(e.latestNetworkMap).(*mgmProto.NetworkMap)
if !ok {
return nil, fmt.Errorf("failed to clone network map")
}
return nm, nil
}
// isChecksEqual checks if two slices of checks are equal.
func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool {
for _, check := range checks {
sort.Slice(check.Files, func(i, j int) bool {
return check.Files[i] < check.Files[j]
})
}
for _, oCheck := range oChecks {
sort.Slice(oCheck.Files, func(i, j int) bool {
return oCheck.Files[i] < oCheck.Files[j]
})
}
return slices.EqualFunc(checks, oChecks, func(checks, oChecks *mgmProto.Checks) bool {
return slices.Equal(checks.Files, oChecks.Files)
})

View File

@@ -245,15 +245,12 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
nil)
wgIface := &iface.MockWGIface{
NameFunc: func() string { return "utun102" },
RemovePeerFunc: func(peerKey string) error {
return nil
},
}
engine.wgInterface = wgIface
engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), time.Minute, engine.wgInterface, engine.statusRecorder, relayMgr, nil, nil)
_, _, err = engine.routeManager.Init()
require.NoError(t, err)
engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), time.Minute, engine.wgInterface, engine.statusRecorder, relayMgr, nil)
engine.dnsServer = &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
}
@@ -1009,99 +1006,6 @@ func Test_ParseNATExternalIPMappings(t *testing.T) {
}
}
func Test_CheckFilesEqual(t *testing.T) {
testCases := []struct {
name string
inputChecks1 []*mgmtProto.Checks
inputChecks2 []*mgmtProto.Checks
expectedBool bool
}{
{
name: "Equal Files In Equal Order Should Return True",
inputChecks1: []*mgmtProto.Checks{
{
Files: []string{
"testfile1",
"testfile2",
},
},
},
inputChecks2: []*mgmtProto.Checks{
{
Files: []string{
"testfile1",
"testfile2",
},
},
},
expectedBool: true,
},
{
name: "Equal Files In Reverse Order Should Return True",
inputChecks1: []*mgmtProto.Checks{
{
Files: []string{
"testfile1",
"testfile2",
},
},
},
inputChecks2: []*mgmtProto.Checks{
{
Files: []string{
"testfile2",
"testfile1",
},
},
},
expectedBool: true,
},
{
name: "Unequal Files Should Return False",
inputChecks1: []*mgmtProto.Checks{
{
Files: []string{
"testfile1",
"testfile2",
},
},
},
inputChecks2: []*mgmtProto.Checks{
{
Files: []string{
"testfile1",
"testfile3",
},
},
},
expectedBool: false,
},
{
name: "Compared With Empty Should Return False",
inputChecks1: []*mgmtProto.Checks{
{
Files: []string{
"testfile1",
"testfile2",
},
},
},
inputChecks2: []*mgmtProto.Checks{
{
Files: []string{},
},
},
expectedBool: false,
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
result := isChecksEqual(testCase.inputChecks1, testCase.inputChecks2)
assert.Equal(t, testCase.expectedBool, result, "result should match expected bool")
})
}
}
func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey string, i int, mgmtAddr string, signalAddr string) (*Engine, error) {
key, err := wgtypes.GeneratePrivateKey()
if err != nil {

View File

@@ -83,6 +83,7 @@ type Conn struct {
signaler *Signaler
relayManager *relayClient.Manager
allowedIP net.IP
allowedNet string
handshaker *Handshaker
onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)
@@ -110,7 +111,7 @@ type Conn struct {
// NewConn creates a new not opened Conn to the remote peer.
// To establish a connection run Conn.Open
func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Status, signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, relayManager *relayClient.Manager, srWatcher *guard.SRWatcher) (*Conn, error) {
allowedIP, _, err := net.ParseCIDR(config.WgConfig.AllowedIps)
allowedIP, allowedNet, err := net.ParseCIDR(config.WgConfig.AllowedIps)
if err != nil {
log.Errorf("failed to parse allowedIPS: %v", err)
return nil, err
@@ -128,6 +129,7 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu
signaler: signaler,
relayManager: relayManager,
allowedIP: allowedIP,
allowedNet: allowedNet.String(),
statusRelay: NewAtomicConnStatus(),
statusICE: NewAtomicConnStatus(),
}
@@ -307,11 +309,6 @@ func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICECon
return
}
if remoteConnNil(conn.log, iceConnInfo.RemoteConn) {
conn.log.Errorf("remote ICE connection is nil")
return
}
conn.log.Debugf("ICE connection is ready")
if conn.currentConnPriority > priority {
@@ -336,9 +333,12 @@ func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICECon
ep = wgProxy.EndpointAddr()
conn.wgProxyICE = wgProxy
} else {
conn.log.Infof("direct iceConnInfo: %v", iceConnInfo.RemoteConn)
agentCheck(conn.log, iceConnInfo.Agent)
nilCheck(conn.log, iceConnInfo.RemoteConn)
directEp, err := net.ResolveUDPAddr("udp", iceConnInfo.RemoteConn.RemoteAddr().String())
if err != nil {
log.Errorf("failed to resolveUDPaddr")
conn.log.Errorf("failed to resolveUDPaddr")
conn.handleConfigurationFailure(err, nil)
return
}
@@ -440,7 +440,7 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
if conn.iceP2PIsActive() {
conn.log.Debugf("do not switch to relay because current priority is: %v", conn.currentConnPriority)
conn.setRelayedProxy(wgProxy)
conn.wgProxyRelay = wgProxy
conn.statusRelay.Set(StatusConnected)
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
return
@@ -463,7 +463,7 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
wgConfigWorkaround()
conn.currentConnPriority = connPriorityRelay
conn.statusRelay.Set(StatusConnected)
conn.setRelayedProxy(wgProxy)
conn.wgProxyRelay = wgProxy
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
conn.log.Infof("start to communicate with peer via relay")
conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr)
@@ -592,7 +592,7 @@ func (conn *Conn) doOnConnected(remoteRosenpassPubKey []byte, remoteRosenpassAdd
}
if conn.onConnected != nil {
conn.onConnected(conn.config.Key, remoteRosenpassPubKey, conn.allowedIP.String(), remoteRosenpassAddr)
conn.onConnected(conn.config.Key, remoteRosenpassPubKey, conn.allowedNet, remoteRosenpassAddr)
}
}
@@ -734,15 +734,6 @@ func (conn *Conn) logTraceConnState() {
}
}
func (conn *Conn) setRelayedProxy(proxy wgproxy.Proxy) {
if conn.wgProxyRelay != nil {
if err := conn.wgProxyRelay.CloseConn(); err != nil {
conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err)
}
}
conn.wgProxyRelay = proxy
}
func isController(config ConnConfig) bool {
return config.LocalKey > config.Key
}

View File

@@ -2,20 +2,53 @@ package peer
import (
"net"
"reflect"
"github.com/pion/ice/v3"
log "github.com/sirupsen/logrus"
)
func remoteConnNil(log *log.Entry, conn net.Conn) bool {
func nilCheck(log *log.Entry, conn net.Conn) {
if conn == nil {
log.Errorf("ice conn is nil")
return true
log.Infof("conn is nil")
return
}
if conn.RemoteAddr() == nil {
log.Errorf("ICE remote address is nil")
return true
log.Infof("conn.RemoteAddr() is nil")
return
}
return false
if reflect.ValueOf(conn.RemoteAddr()).IsNil() {
log.Infof("value of conn.RemoteAddr() is nil")
return
}
}
func agentCheck(log *log.Entry, agent *ice.Agent) {
if agent == nil {
log.Errorf("agent is nil")
return
}
pair, err := agent.GetSelectedCandidatePair()
if err != nil {
log.Errorf("error getting selected candidate pair: %v", err)
return
}
if pair == nil {
log.Errorf("pair is nil")
return
}
if pair.Remote == nil {
log.Errorf("pair.Remote is nil")
return
}
if pair.Remote.Address() == "" {
log.Errorf("address is empty")
return
}
}

View File

@@ -67,7 +67,7 @@ func (s *State) DeleteRoute(network string) {
func (s *State) GetRoutes() map[string]struct{} {
s.Mux.RLock()
defer s.Mux.RUnlock()
return maps.Clone(s.routes)
return s.routes
}
// LocalPeerState contains the latest state of the local peer
@@ -237,6 +237,10 @@ func (d *Status) UpdatePeerState(receivedState State) error {
peerState.IP = receivedState.IP
}
if receivedState.GetRoutes() != nil {
peerState.SetRoutes(receivedState.GetRoutes())
}
skipNotification := shouldSkipNotify(receivedState.ConnStatus, peerState)
if receivedState.ConnStatus != peerState.ConnStatus {
@@ -257,40 +261,12 @@ func (d *Status) UpdatePeerState(receivedState State) error {
return nil
}
d.notifyPeerListChanged()
return nil
}
func (d *Status) AddPeerStateRoute(peer string, route string) error {
d.mux.Lock()
defer d.mux.Unlock()
peerState, ok := d.peers[peer]
if !ok {
return errors.New("peer doesn't exist")
ch, found := d.changeNotify[receivedState.PubKey]
if found && ch != nil {
close(ch)
d.changeNotify[receivedState.PubKey] = nil
}
peerState.AddRoute(route)
d.peers[peer] = peerState
// todo: consider to make sense of this notification or not
d.notifyPeerListChanged()
return nil
}
func (d *Status) RemovePeerStateRoute(peer string, route string) error {
d.mux.Lock()
defer d.mux.Unlock()
peerState, ok := d.peers[peer]
if !ok {
return errors.New("peer doesn't exist")
}
peerState.DeleteRoute(route)
d.peers[peer] = peerState
// todo: consider to make sense of this notification or not
d.notifyPeerListChanged()
return nil
}
@@ -325,7 +301,12 @@ func (d *Status) UpdatePeerICEState(receivedState State) error {
return nil
}
d.notifyPeerStateChangeListeners(receivedState.PubKey)
ch, found := d.changeNotify[receivedState.PubKey]
if found && ch != nil {
close(ch)
d.changeNotify[receivedState.PubKey] = nil
}
d.notifyPeerListChanged()
return nil
}
@@ -353,7 +334,12 @@ func (d *Status) UpdatePeerRelayedState(receivedState State) error {
return nil
}
d.notifyPeerStateChangeListeners(receivedState.PubKey)
ch, found := d.changeNotify[receivedState.PubKey]
if found && ch != nil {
close(ch)
d.changeNotify[receivedState.PubKey] = nil
}
d.notifyPeerListChanged()
return nil
}
@@ -380,7 +366,12 @@ func (d *Status) UpdatePeerRelayedStateToDisconnected(receivedState State) error
return nil
}
d.notifyPeerStateChangeListeners(receivedState.PubKey)
ch, found := d.changeNotify[receivedState.PubKey]
if found && ch != nil {
close(ch)
d.changeNotify[receivedState.PubKey] = nil
}
d.notifyPeerListChanged()
return nil
}
@@ -410,7 +401,12 @@ func (d *Status) UpdatePeerICEStateToDisconnected(receivedState State) error {
return nil
}
d.notifyPeerStateChangeListeners(receivedState.PubKey)
ch, found := d.changeNotify[receivedState.PubKey]
if found && ch != nil {
close(ch)
d.changeNotify[receivedState.PubKey] = nil
}
d.notifyPeerListChanged()
return nil
}
@@ -481,14 +477,11 @@ func (d *Status) FinishPeerListModifications() {
func (d *Status) GetPeerStateChangeNotifier(peer string) <-chan struct{} {
d.mux.Lock()
defer d.mux.Unlock()
ch, found := d.changeNotify[peer]
if found {
return ch
if !found || ch == nil {
ch = make(chan struct{})
d.changeNotify[peer] = ch
}
ch = make(chan struct{})
d.changeNotify[peer] = ch
return ch
}
@@ -676,23 +669,25 @@ func (d *Status) GetRelayStates() []relay.ProbeResult {
// extend the list of stun, turn servers with relay address
relayStates := slices.Clone(d.relayStates)
var relayState relay.ProbeResult
// if the server connection is not established then we will use the general address
// in case of connection we will use the instance specific address
instanceAddr, err := d.relayMgr.RelayInstanceAddress()
if err != nil {
// TODO add their status
for _, r := range d.relayMgr.ServerURLs() {
relayStates = append(relayStates, relay.ProbeResult{
URI: r,
Err: err,
})
if errors.Is(err, relayClient.ErrRelayClientNotConnected) {
for _, r := range d.relayMgr.ServerURLs() {
relayStates = append(relayStates, relay.ProbeResult{
URI: r,
})
}
return relayStates
}
return relayStates
relayState.Err = err
}
relayState := relay.ProbeResult{
URI: instanceAddr,
}
relayState.URI = instanceAddr
return append(relayStates, relayState)
}
@@ -760,17 +755,6 @@ func (d *Status) onConnectionChanged() {
d.notifier.updateServerStates(d.managementState, d.signalState)
}
// notifyPeerStateChangeListeners notifies route manager about the change in peer state
func (d *Status) notifyPeerStateChangeListeners(peerID string) {
ch, found := d.changeNotify[peerID]
if !found {
return
}
close(ch)
delete(d.changeNotify, peerID)
}
func (d *Status) notifyPeerListChanged() {
d.notifier.peerListChanged(d.numOfPeers())
}

View File

@@ -93,7 +93,7 @@ func TestGetPeerStateChangeNotifierLogic(t *testing.T) {
peerState.IP = ip
err := status.UpdatePeerRelayedStateToDisconnected(peerState)
err := status.UpdatePeerState(peerState)
assert.NoError(t, err, "shouldn't return error")
select {

View File

@@ -29,6 +29,7 @@ type ICEConnInfo struct {
LocalIceCandidateEndpoint string
Relayed bool
RelayedOnLocal bool
Agent *ice.Agent
}
type WorkerICECallbacks struct {
@@ -46,6 +47,8 @@ type WorkerICE struct {
hasRelayOnLocally bool
conn WorkerICECallbacks
selectedPriority ConnPriority
agent *ice.Agent
muxAgent sync.Mutex
@@ -55,9 +58,6 @@ type WorkerICE struct {
localUfrag string
localPwd string
// we record the last known state of the ICE agent to avoid duplicate on disconnected events
lastKnownState ice.ConnectionState
}
func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, signaler *Signaler, ifaceDiscover stdnet.ExternalIFaceDiscover, statusRecorder *Status, hasRelayOnLocally bool, callBacks WorkerICECallbacks) (*WorkerICE, error) {
@@ -93,8 +93,10 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
var preferredCandidateTypes []ice.CandidateType
if w.hasRelayOnLocally && remoteOfferAnswer.RelaySrvAddress != "" {
w.selectedPriority = connPriorityICEP2P
preferredCandidateTypes = icemaker.CandidateTypesP2P()
} else {
w.selectedPriority = connPriorityICETurn
preferredCandidateTypes = icemaker.CandidateTypes()
}
@@ -125,6 +127,9 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
w.log.Debugf("failed to dial the remote peer: %s", err)
return
}
w.log.Infof("check remoteConn: %v", remoteConn)
w.log.Infof("check remoteConn.RemoteAddr: %v", remoteConn.RemoteAddr())
nilCheck(w.log, remoteConn)
w.log.Debugf("agent dial succeeded")
pair, err := w.agent.GetSelectedCandidatePair()
@@ -153,9 +158,10 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
RemoteIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Remote.Address(), pair.Remote.Port()),
Relayed: isRelayed(pair),
RelayedOnLocal: isRelayCandidate(pair.Local),
Agent: agent,
}
w.log.Debugf("on ICE conn read to use ready")
go w.conn.OnConnReady(selectedPriority(pair), ci)
go w.conn.OnConnReady(w.selectedPriority, ci)
}
// OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer.
@@ -193,7 +199,8 @@ func (w *WorkerICE) Close() {
return
}
if err := w.agent.Close(); err != nil {
err := w.agent.Close()
if err != nil {
w.log.Warnf("failed to close ICE agent: %s", err)
}
}
@@ -213,18 +220,15 @@ func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, candidates []i
err = agent.OnConnectionStateChange(func(state ice.ConnectionState) {
w.log.Debugf("ICE ConnectionState has changed to %s", state.String())
switch state {
case ice.ConnectionStateConnected:
w.lastKnownState = ice.ConnectionStateConnected
return
case ice.ConnectionStateFailed, ice.ConnectionStateDisconnected:
if w.lastKnownState != ice.ConnectionStateDisconnected {
w.lastKnownState = ice.ConnectionStateDisconnected
w.conn.OnStatusChanged(StatusDisconnected)
}
w.closeAgent(agentCancel)
default:
return
if state == ice.ConnectionStateFailed || state == ice.ConnectionStateDisconnected {
w.conn.OnStatusChanged(StatusDisconnected)
w.muxAgent.Lock()
agentCancel()
_ = agent.Close()
w.agent = nil
w.muxAgent.Unlock()
}
})
if err != nil {
@@ -250,17 +254,6 @@ func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, candidates []i
return agent, nil
}
func (w *WorkerICE) closeAgent(cancel context.CancelFunc) {
w.muxAgent.Lock()
defer w.muxAgent.Unlock()
cancel()
if err := w.agent.Close(); err != nil {
w.log.Warnf("failed to close ICE agent: %s", err)
}
w.agent = nil
}
func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) {
// wait local endpoint configuration
time.Sleep(time.Second)
@@ -334,8 +327,10 @@ func (w *WorkerICE) shouldSendExtraSrflxCandidate(candidate ice.Candidate) bool
func (w *WorkerICE) turnAgentDial(ctx context.Context, remoteOfferAnswer *OfferAnswer) (*ice.Conn, error) {
isControlling := w.config.LocalKey > w.config.Key
if isControlling {
w.log.Infof("dialing remote peer %s as controlling", w.config.Key)
return w.agent.Dial(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
} else {
w.log.Infof("dialing remote peer %s as controlled", w.config.Key)
return w.agent.Accept(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
}
}
@@ -390,11 +385,3 @@ func isRelayed(pair *ice.CandidatePair) bool {
}
return false
}
func selectedPriority(pair *ice.CandidatePair) ConnPriority {
if isRelayed(pair) {
return connPriorityICETurn
} else {
return connPriorityICEP2P
}
}

View File

@@ -122,20 +122,13 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]
tempScore = float64(metricDiff) * 10
}
// in some temporal cases, latency can be 0, so we set it to 999ms to not block but try to avoid this route
latency := 999 * time.Millisecond
// in some temporal cases, latency can be 0, so we set it to 1s to not block but try to avoid this route
latency := time.Second
if peerStatus.latency != 0 {
latency = peerStatus.latency
} else {
log.Tracef("peer %s has 0 latency, range %s", r.Peer, c.handler)
log.Warnf("peer %s has 0 latency", r.Peer)
}
// avoid negative tempScore on the higher latency calculation
if latency > 1*time.Second {
latency = 999 * time.Millisecond
}
// higher latency is worse score
tempScore += 1 - latency.Seconds()
if !peerStatus.relayed {
@@ -157,8 +150,6 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]
}
}
log.Debugf("chosen route: %s, chosen score: %f, current route: %s, current score: %f", chosen, chosenScore, currID, currScore)
switch {
case chosen == "":
var peers []string
@@ -204,20 +195,15 @@ func (c *clientNetwork) watchPeerStatusChanges(ctx context.Context, peerKey stri
func (c *clientNetwork) startPeersStatusChangeWatcher() {
for _, r := range c.routes {
_, found := c.routePeersNotifiers[r.Peer]
if found {
continue
if !found {
c.routePeersNotifiers[r.Peer] = make(chan struct{})
go c.watchPeerStatusChanges(c.ctx, r.Peer, c.peerStateUpdate, c.routePeersNotifiers[r.Peer])
}
closerChan := make(chan struct{})
c.routePeersNotifiers[r.Peer] = closerChan
go c.watchPeerStatusChanges(c.ctx, r.Peer, c.peerStateUpdate, closerChan)
}
}
func (c *clientNetwork) removeRouteFromWireGuardPeer() error {
if err := c.statusRecorder.RemovePeerStateRoute(c.currentChosen.Peer, c.handler.String()); err != nil {
log.Warnf("Failed to update peer state: %v", err)
}
func (c *clientNetwork) removeRouteFromWireguardPeer() error {
c.removeStateRoute()
if err := c.handler.RemoveAllowedIPs(); err != nil {
return fmt.Errorf("remove allowed IPs: %w", err)
@@ -232,7 +218,7 @@ func (c *clientNetwork) removeRouteFromPeerAndSystem() error {
var merr *multierror.Error
if err := c.removeRouteFromWireGuardPeer(); err != nil {
if err := c.removeRouteFromWireguardPeer(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove allowed IPs for peer %s: %w", c.currentChosen.Peer, err))
}
if err := c.handler.RemoveRoute(); err != nil {
@@ -271,7 +257,7 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
}
} else {
// Otherwise, remove the allowed IPs from the previous peer first
if err := c.removeRouteFromWireGuardPeer(); err != nil {
if err := c.removeRouteFromWireguardPeer(); err != nil {
return fmt.Errorf("remove allowed IPs for peer %s: %w", c.currentChosen.Peer, err)
}
}
@@ -282,13 +268,37 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
return fmt.Errorf("add allowed IPs for peer %s: %w", c.currentChosen.Peer, err)
}
err := c.statusRecorder.AddPeerStateRoute(c.currentChosen.Peer, c.handler.String())
if err != nil {
return fmt.Errorf("add peer state route: %w", err)
}
c.addStateRoute()
return nil
}
func (c *clientNetwork) addStateRoute() {
state, err := c.statusRecorder.GetPeer(c.currentChosen.Peer)
if err != nil {
log.Errorf("Failed to get peer state: %v", err)
return
}
state.AddRoute(c.handler.String())
if err := c.statusRecorder.UpdatePeerState(state); err != nil {
log.Warnf("Failed to update peer state: %v", err)
}
}
func (c *clientNetwork) removeStateRoute() {
state, err := c.statusRecorder.GetPeer(c.currentChosen.Peer)
if err != nil {
log.Errorf("Failed to get peer state: %v", err)
return
}
state.DeleteRoute(c.handler.String())
if err := c.statusRecorder.UpdatePeerState(state); err != nil {
log.Warnf("Failed to update peer state: %v", err)
}
}
func (c *clientNetwork) sendUpdateToClientNetworkWatcher(update routesUpdate) {
go func() {
c.routeUpdate <- update

View File

@@ -1,7 +1,6 @@
package routemanager
import (
"fmt"
"net/netip"
"testing"
"time"
@@ -228,64 +227,6 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
currentRoute: "route1",
expectedRouteID: "route1",
},
{
name: "relayed routes with latency 0 should maintain previous choice",
statuses: map[route.ID]routerPeerStatus{
"route1": {
connected: true,
relayed: true,
latency: 0 * time.Millisecond,
},
"route2": {
connected: true,
relayed: true,
latency: 0 * time.Millisecond,
},
},
existingRoutes: map[route.ID]*route.Route{
"route1": {
ID: "route1",
Metric: route.MaxMetric,
Peer: "peer1",
},
"route2": {
ID: "route2",
Metric: route.MaxMetric,
Peer: "peer2",
},
},
currentRoute: "route1",
expectedRouteID: "route1",
},
{
name: "p2p routes with latency 0 should maintain previous choice",
statuses: map[route.ID]routerPeerStatus{
"route1": {
connected: true,
relayed: false,
latency: 0 * time.Millisecond,
},
"route2": {
connected: true,
relayed: false,
latency: 0 * time.Millisecond,
},
},
existingRoutes: map[route.ID]*route.Route{
"route1": {
ID: "route1",
Metric: route.MaxMetric,
Peer: "peer1",
},
"route2": {
ID: "route2",
Metric: route.MaxMetric,
Peer: "peer2",
},
},
currentRoute: "route1",
expectedRouteID: "route1",
},
{
name: "current route with bad score should be changed to route with better score",
statuses: map[route.ID]routerPeerStatus{
@@ -346,45 +287,6 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
},
}
// fill the test data with random routes
for _, tc := range testCases {
for i := 0; i < 50; i++ {
dummyRoute := &route.Route{
ID: route.ID(fmt.Sprintf("dummy_p1_%d", i)),
Metric: route.MinMetric,
Peer: fmt.Sprintf("dummy_p1_%d", i),
}
tc.existingRoutes[dummyRoute.ID] = dummyRoute
}
for i := 0; i < 50; i++ {
dummyRoute := &route.Route{
ID: route.ID(fmt.Sprintf("dummy_p2_%d", i)),
Metric: route.MinMetric,
Peer: fmt.Sprintf("dummy_p1_%d", i),
}
tc.existingRoutes[dummyRoute.ID] = dummyRoute
}
for i := 0; i < 50; i++ {
id := route.ID(fmt.Sprintf("dummy_p1_%d", i))
dummyStatus := routerPeerStatus{
connected: false,
relayed: true,
latency: 0,
}
tc.statuses[id] = dummyStatus
}
for i := 0; i < 50; i++ {
id := route.ID(fmt.Sprintf("dummy_p2_%d", i))
dummyStatus := routerPeerStatus{
connected: false,
relayed: true,
latency: 0,
}
tc.statuses[id] = dummyStatus
}
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
currentRoute := &route.Route{

View File

@@ -32,7 +32,7 @@ import (
// Manager is a route manager interface
type Manager interface {
Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error)
Init(*statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error)
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error)
TriggerSelection(route.HAMap)
GetRouteSelector() *routeselector.RouteSelector
@@ -59,7 +59,6 @@ type DefaultManager struct {
routeRefCounter *refcounter.RouteRefCounter
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter
dnsRouteInterval time.Duration
stateManager *statemanager.Manager
}
func NewManager(
@@ -70,7 +69,6 @@ func NewManager(
statusRecorder *peer.Status,
relayMgr *relayClient.Manager,
initialRoutes []*route.Route,
stateManager *statemanager.Manager,
) *DefaultManager {
mCTX, cancel := context.WithCancel(ctx)
notifier := notifier.NewNotifier()
@@ -82,12 +80,12 @@ func NewManager(
dnsRouteInterval: dnsRouteInterval,
clientNetworks: make(map[route.HAUniqueID]*clientNetwork),
relayMgr: relayMgr,
routeSelector: routeselector.NewRouteSelector(),
sysOps: sysOps,
statusRecorder: statusRecorder,
wgInterface: wgInterface,
pubKey: pubKey,
notifier: notifier,
stateManager: stateManager,
}
dm.routeRefCounter = refcounter.New(
@@ -123,7 +121,7 @@ func NewManager(
}
// Init sets up the routing
func (m *DefaultManager) Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
func (m *DefaultManager) Init(stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
if nbnet.CustomRoutingDisabled() {
return nil, nil, nil
}
@@ -139,38 +137,14 @@ func (m *DefaultManager) Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error)
ips := resolveURLsToIPs(initialAddresses)
beforePeerHook, afterPeerHook, err := m.sysOps.SetupRouting(ips, m.stateManager)
beforePeerHook, afterPeerHook, err := m.sysOps.SetupRouting(ips, stateManager)
if err != nil {
return nil, nil, fmt.Errorf("setup routing: %w", err)
}
m.routeSelector = m.initSelector()
log.Info("Routing setup complete")
return beforePeerHook, afterPeerHook, nil
}
func (m *DefaultManager) initSelector() *routeselector.RouteSelector {
var state *SelectorState
m.stateManager.RegisterState(state)
// restore selector state if it exists
if err := m.stateManager.LoadState(state); err != nil {
log.Warnf("failed to load state: %v", err)
return routeselector.NewRouteSelector()
}
if state := m.stateManager.GetState(state); state != nil {
if selector, ok := state.(*SelectorState); ok {
return (*routeselector.RouteSelector)(selector)
}
log.Warnf("failed to convert state with type %T to SelectorState", state)
}
return routeselector.NewRouteSelector()
}
func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error {
var err error
m.serverRouter, err = newServerRouter(m.ctx, m.wgInterface, firewall, m.statusRecorder)
@@ -278,10 +252,6 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) {
go clientNetworkWatcher.peersStateAndUpdateWatcher()
clientNetworkWatcher.sendUpdateToClientNetworkWatcher(routesUpdate{routes: routes})
}
if err := m.stateManager.UpdateState((*SelectorState)(m.routeSelector)); err != nil {
log.Errorf("failed to update state: %v", err)
}
}
// stopObsoleteClients stops the client network watcher for the networks that are not in the new list

View File

@@ -424,9 +424,9 @@ func TestManagerUpdateRoutes(t *testing.T) {
statusRecorder := peer.NewRecorder("https://mgm")
ctx := context.TODO()
routeManager := NewManager(ctx, localPeerKey, 0, wgInterface, statusRecorder, nil, nil, nil)
routeManager := NewManager(ctx, localPeerKey, 0, wgInterface, statusRecorder, nil, nil)
_, _, err = routeManager.Init()
_, _, err = routeManager.Init(nil)
require.NoError(t, err, "should init route manager")
defer routeManager.Stop(nil)

View File

@@ -21,7 +21,7 @@ type MockManager struct {
StopFunc func(manager *statemanager.Manager)
}
func (m *MockManager) Init() (net.AddHookFunc, net.RemoveHookFunc, error) {
func (m *MockManager) Init(*statemanager.Manager) (net.AddHookFunc, net.RemoveHookFunc, error) {
return nil, nil, nil
}

View File

@@ -47,9 +47,10 @@ type RemoveFunc[Key, O any] func(key Key, out O) error
type Counter[Key comparable, I, O any] struct {
// refCountMap keeps track of the reference Ref for keys
refCountMap map[Key]Ref[O]
mu sync.Mutex
refCountMu sync.Mutex
// idMap keeps track of the keys associated with an ID for removal
idMap map[string][]Key
idMu sync.Mutex
add AddFunc[Key, I, O]
remove RemoveFunc[Key, O]
}
@@ -71,14 +72,13 @@ func New[Key comparable, I, O any](add AddFunc[Key, I, O], remove RemoveFunc[Key
}
// LoadData loads the data from the existing counter
// The passed counter should not be used any longer after calling this function.
func (rm *Counter[Key, I, O]) LoadData(
existingCounter *Counter[Key, I, O],
) {
rm.mu.Lock()
defer rm.mu.Unlock()
existingCounter.mu.Lock()
defer existingCounter.mu.Unlock()
rm.refCountMu.Lock()
defer rm.refCountMu.Unlock()
rm.idMu.Lock()
defer rm.idMu.Unlock()
rm.refCountMap = existingCounter.refCountMap
rm.idMap = existingCounter.idMap
@@ -87,8 +87,8 @@ func (rm *Counter[Key, I, O]) LoadData(
// Get retrieves the current reference count and associated data for a key.
// If the key doesn't exist, it returns a zero value Ref and false.
func (rm *Counter[Key, I, O]) Get(key Key) (Ref[O], bool) {
rm.mu.Lock()
defer rm.mu.Unlock()
rm.refCountMu.Lock()
defer rm.refCountMu.Unlock()
ref, ok := rm.refCountMap[key]
return ref, ok
@@ -97,13 +97,9 @@ func (rm *Counter[Key, I, O]) Get(key Key) (Ref[O], bool) {
// Increment increments the reference count for the given key.
// If this is the first reference to the key, the AddFunc is called.
func (rm *Counter[Key, I, O]) Increment(key Key, in I) (Ref[O], error) {
rm.mu.Lock()
defer rm.mu.Unlock()
rm.refCountMu.Lock()
defer rm.refCountMu.Unlock()
return rm.increment(key, in)
}
func (rm *Counter[Key, I, O]) increment(key Key, in I) (Ref[O], error) {
ref := rm.refCountMap[key]
logCallerF("Increasing ref count [%d -> %d] for key %v with In [%v] Out [%v]", ref.Count, ref.Count+1, key, in, ref.Out)
@@ -130,10 +126,10 @@ func (rm *Counter[Key, I, O]) increment(key Key, in I) (Ref[O], error) {
// IncrementWithID increments the reference count for the given key and groups it under the given ID.
// If this is the first reference to the key, the AddFunc is called.
func (rm *Counter[Key, I, O]) IncrementWithID(id string, key Key, in I) (Ref[O], error) {
rm.mu.Lock()
defer rm.mu.Unlock()
rm.idMu.Lock()
defer rm.idMu.Unlock()
ref, err := rm.increment(key, in)
ref, err := rm.Increment(key, in)
if err != nil {
return ref, fmt.Errorf("with ID: %w", err)
}
@@ -145,12 +141,9 @@ func (rm *Counter[Key, I, O]) IncrementWithID(id string, key Key, in I) (Ref[O],
// Decrement decrements the reference count for the given key.
// If the reference count reaches 0, the RemoveFunc is called.
func (rm *Counter[Key, I, O]) Decrement(key Key) (Ref[O], error) {
rm.mu.Lock()
defer rm.mu.Unlock()
return rm.decrement(key)
}
rm.refCountMu.Lock()
defer rm.refCountMu.Unlock()
func (rm *Counter[Key, I, O]) decrement(key Key) (Ref[O], error) {
ref, ok := rm.refCountMap[key]
if !ok {
logCallerF("No reference found for key %v", key)
@@ -175,12 +168,12 @@ func (rm *Counter[Key, I, O]) decrement(key Key) (Ref[O], error) {
// DecrementWithID decrements the reference count for all keys associated with the given ID.
// If the reference count reaches 0, the RemoveFunc is called.
func (rm *Counter[Key, I, O]) DecrementWithID(id string) error {
rm.mu.Lock()
defer rm.mu.Unlock()
rm.idMu.Lock()
defer rm.idMu.Unlock()
var merr *multierror.Error
for _, key := range rm.idMap[id] {
if _, err := rm.decrement(key); err != nil {
if _, err := rm.Decrement(key); err != nil {
merr = multierror.Append(merr, err)
}
}
@@ -191,8 +184,10 @@ func (rm *Counter[Key, I, O]) DecrementWithID(id string) error {
// Flush removes all references and calls RemoveFunc for each key.
func (rm *Counter[Key, I, O]) Flush() error {
rm.mu.Lock()
defer rm.mu.Unlock()
rm.refCountMu.Lock()
defer rm.refCountMu.Unlock()
rm.idMu.Lock()
defer rm.idMu.Unlock()
var merr *multierror.Error
for key := range rm.refCountMap {
@@ -211,8 +206,10 @@ func (rm *Counter[Key, I, O]) Flush() error {
// Clear removes all references without calling RemoveFunc.
func (rm *Counter[Key, I, O]) Clear() {
rm.mu.Lock()
defer rm.mu.Unlock()
rm.refCountMu.Lock()
defer rm.refCountMu.Unlock()
rm.idMu.Lock()
defer rm.idMu.Unlock()
clear(rm.refCountMap)
clear(rm.idMap)
@@ -220,9 +217,6 @@ func (rm *Counter[Key, I, O]) Clear() {
// MarshalJSON implements the json.Marshaler interface for Counter.
func (rm *Counter[Key, I, O]) MarshalJSON() ([]byte, error) {
rm.mu.Lock()
defer rm.mu.Unlock()
return json.Marshal(struct {
RefCountMap map[Key]Ref[O] `json:"refCountMap"`
IDMap map[string][]Key `json:"idMap"`
@@ -234,9 +228,6 @@ func (rm *Counter[Key, I, O]) MarshalJSON() ([]byte, error) {
// UnmarshalJSON implements the json.Unmarshaler interface for Counter.
func (rm *Counter[Key, I, O]) UnmarshalJSON(data []byte) error {
rm.mu.Lock()
defer rm.mu.Unlock()
var temp struct {
RefCountMap map[Key]Ref[O] `json:"refCountMap"`
IDMap map[string][]Key `json:"idMap"`
@@ -247,13 +238,6 @@ func (rm *Counter[Key, I, O]) UnmarshalJSON(data []byte) error {
rm.refCountMap = temp.RefCountMap
rm.idMap = temp.IDMap
if temp.RefCountMap == nil {
temp.RefCountMap = map[Key]Ref[O]{}
}
if temp.IDMap == nil {
temp.IDMap = map[string][]Key{}
}
return nil
}

View File

@@ -1,19 +0,0 @@
package routemanager
import (
"github.com/netbirdio/netbird/client/internal/routeselector"
)
type SelectorState routeselector.RouteSelector
func (s *SelectorState) Name() string {
return "routeselector_state"
}
func (s *SelectorState) MarshalJSON() ([]byte, error) {
return (*routeselector.RouteSelector)(s).MarshalJSON()
}
func (s *SelectorState) UnmarshalJSON(data []byte) error {
return (*routeselector.RouteSelector)(s).UnmarshalJSON(data)
}

View File

@@ -2,28 +2,31 @@ package systemops
import (
"net/netip"
"sync"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
)
type ShutdownState ExclusionCounter
type ShutdownState struct {
Counter *ExclusionCounter `json:"counter,omitempty"`
mu sync.RWMutex
}
func (s *ShutdownState) Name() string {
return "route_state"
}
func (s *ShutdownState) Cleanup() error {
s.mu.RLock()
defer s.mu.RUnlock()
if s.Counter == nil {
return nil
}
sysops := NewSysOps(nil, nil)
sysops.refCounter = refcounter.New[netip.Prefix, struct{}, Nexthop](nil, sysops.removeFromRouteTable)
sysops.refCounter.LoadData((*ExclusionCounter)(s))
sysops.refCounter.LoadData(s.Counter)
return sysops.refCounter.Flush()
}
func (s *ShutdownState) MarshalJSON() ([]byte, error) {
return (*ExclusionCounter)(s).MarshalJSON()
}
func (s *ShutdownState) UnmarshalJSON(data []byte) error {
return (*ExclusionCounter)(s).UnmarshalJSON(data)
}

View File

@@ -57,19 +57,30 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemana
return nexthop, refcounter.ErrIgnore
}
r.updateState(stateManager)
return nexthop, err
},
r.removeFromRouteTable,
func(prefix netip.Prefix, nexthop Nexthop) error {
// remove from state even if we have trouble removing it from the route table
// it could be already gone
r.updateState(stateManager)
return r.removeFromRouteTable(prefix, nexthop)
},
)
r.refCounter = refCounter
return r.setupHooks(initAddresses, stateManager)
return r.setupHooks(initAddresses)
}
// updateState updates state on every change so it will be persisted regularly
func (r *SysOps) updateState(stateManager *statemanager.Manager) {
if err := stateManager.UpdateState((*ShutdownState)(r.refCounter)); err != nil {
state := getState(stateManager)
state.Counter = r.refCounter
if err := stateManager.UpdateState(state); err != nil {
log.Errorf("failed to update state: %v", err)
}
}
@@ -325,7 +336,7 @@ func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface)
return r.removeFromRouteTable(prefix, nextHop)
}
func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
func (r *SysOps) setupHooks(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error {
prefix, err := util.GetPrefixFromIP(ip)
if err != nil {
@@ -336,8 +347,6 @@ func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.M
return fmt.Errorf("adding route reference: %v", err)
}
r.updateState(stateManager)
return nil
}
afterHook := func(connID nbnet.ConnectionID) error {
@@ -345,8 +354,6 @@ func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.M
return fmt.Errorf("remove route reference: %w", err)
}
r.updateState(stateManager)
return nil
}
@@ -525,3 +532,14 @@ func isVpnRoute(addr netip.Addr, vpnRoutes []netip.Prefix, localRoutes []netip.P
// Return true if the longest matching prefix is from vpnRoutes
return isVpn, longestPrefix
}
func getState(stateManager *statemanager.Manager) *ShutdownState {
var shutdownState *ShutdownState
if state := stateManager.GetState(shutdownState); state != nil {
shutdownState = state.(*ShutdownState)
} else {
shutdownState = &ShutdownState{}
}
return shutdownState
}

View File

@@ -55,7 +55,7 @@ type ruleParams struct {
// isLegacy determines whether to use the legacy routing setup
func isLegacy() bool {
return os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled() || nbnet.SkipSocketMark()
return os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled()
}
// setIsLegacy sets the legacy routing setup
@@ -92,6 +92,17 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager
return r.setupRefCounter(initAddresses, stateManager)
}
if err = addRoutingTableName(); err != nil {
log.Errorf("Error adding routing table name: %v", err)
}
originalValues, err := sysctl.Setup(r.wgInterface)
if err != nil {
log.Errorf("Error setting up sysctl: %v", err)
sysctlFailed = true
}
originalSysctl = originalValues
defer func() {
if err != nil {
if cleanErr := r.CleanupRouting(stateManager); cleanErr != nil {
@@ -112,17 +123,6 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager
}
}
if err = addRoutingTableName(); err != nil {
log.Errorf("Error adding routing table name: %v", err)
}
originalValues, err := sysctl.Setup(r.wgInterface)
if err != nil {
log.Errorf("Error setting up sysctl: %v", err)
sysctlFailed = true
}
originalSysctl = originalValues
return nil, nil, nil
}
@@ -450,7 +450,7 @@ func addRule(params ruleParams) error {
rule.Invert = params.invert
rule.SuppressPrefixlen = params.suppressPrefix
if err := netlink.RuleAdd(rule); err != nil && !errors.Is(err, syscall.EEXIST) {
if err := netlink.RuleAdd(rule); err != nil && !errors.Is(err, syscall.EEXIST) && !errors.Is(err, syscall.EAFNOSUPPORT) {
return fmt.Errorf("add routing rule: %w", err)
}
@@ -467,7 +467,7 @@ func removeRule(params ruleParams) error {
rule.Priority = params.priority
rule.SuppressPrefixlen = params.suppressPrefix
if err := netlink.RuleDel(rule); err != nil && !errors.Is(err, syscall.ENOENT) {
if err := netlink.RuleDel(rule); err != nil && !errors.Is(err, syscall.ENOENT) && !errors.Is(err, syscall.EAFNOSUPPORT) {
return fmt.Errorf("remove routing rule: %w", err)
}

View File

@@ -230,13 +230,10 @@ func (rm *RouteMonitor) parseUpdate(row *MIB_IPFORWARD_ROW2, notificationType MI
if idx != 0 {
intf, err := net.InterfaceByIndex(idx)
if err != nil {
log.Warnf("failed to get interface name for index %d: %v", idx, err)
update.Interface = &net.Interface{
Index: idx,
}
} else {
update.Interface = intf
return update, fmt.Errorf("get interface name: %w", err)
}
update.Interface = intf
}
log.Tracef("Received route update with destination %v, next hop %v, interface %v", row.DestinationPrefix, row.NextHop, update.Interface)

View File

@@ -1,10 +1,8 @@
package routeselector
import (
"encoding/json"
"fmt"
"slices"
"sync"
"github.com/hashicorp/go-multierror"
"golang.org/x/exp/maps"
@@ -14,7 +12,6 @@ import (
)
type RouteSelector struct {
mu sync.RWMutex
selectedRoutes map[route.NetID]struct{}
selectAll bool
}
@@ -29,9 +26,6 @@ func NewRouteSelector() *RouteSelector {
// SelectRoutes updates the selected routes based on the provided route IDs.
func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, allRoutes []route.NetID) error {
rs.mu.Lock()
defer rs.mu.Unlock()
if !appendRoute {
rs.selectedRoutes = map[route.NetID]struct{}{}
}
@@ -52,9 +46,6 @@ func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, al
// SelectAllRoutes sets the selector to select all routes.
func (rs *RouteSelector) SelectAllRoutes() {
rs.mu.Lock()
defer rs.mu.Unlock()
rs.selectAll = true
rs.selectedRoutes = map[route.NetID]struct{}{}
}
@@ -62,9 +53,6 @@ func (rs *RouteSelector) SelectAllRoutes() {
// DeselectRoutes removes specific routes from the selection.
// If the selector is in "select all" mode, it will transition to "select specific" mode.
func (rs *RouteSelector) DeselectRoutes(routes []route.NetID, allRoutes []route.NetID) error {
rs.mu.Lock()
defer rs.mu.Unlock()
if rs.selectAll {
rs.selectAll = false
rs.selectedRoutes = map[route.NetID]struct{}{}
@@ -88,18 +76,12 @@ func (rs *RouteSelector) DeselectRoutes(routes []route.NetID, allRoutes []route.
// DeselectAllRoutes deselects all routes, effectively disabling route selection.
func (rs *RouteSelector) DeselectAllRoutes() {
rs.mu.Lock()
defer rs.mu.Unlock()
rs.selectAll = false
rs.selectedRoutes = map[route.NetID]struct{}{}
}
// IsSelected checks if a specific route is selected.
func (rs *RouteSelector) IsSelected(routeID route.NetID) bool {
rs.mu.RLock()
defer rs.mu.RUnlock()
if rs.selectAll {
return true
}
@@ -109,9 +91,6 @@ func (rs *RouteSelector) IsSelected(routeID route.NetID) bool {
// FilterSelected removes unselected routes from the provided map.
func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap {
rs.mu.RLock()
defer rs.mu.RUnlock()
if rs.selectAll {
return maps.Clone(routes)
}
@@ -124,49 +103,3 @@ func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap {
}
return filtered
}
// MarshalJSON implements the json.Marshaler interface
func (rs *RouteSelector) MarshalJSON() ([]byte, error) {
rs.mu.RLock()
defer rs.mu.RUnlock()
return json.Marshal(struct {
SelectedRoutes map[route.NetID]struct{} `json:"selected_routes"`
SelectAll bool `json:"select_all"`
}{
SelectAll: rs.selectAll,
SelectedRoutes: rs.selectedRoutes,
})
}
// UnmarshalJSON implements the json.Unmarshaler interface
// If the JSON is empty or null, it will initialize like a NewRouteSelector.
func (rs *RouteSelector) UnmarshalJSON(data []byte) error {
rs.mu.Lock()
defer rs.mu.Unlock()
// Check for null or empty JSON
if len(data) == 0 || string(data) == "null" {
rs.selectedRoutes = map[route.NetID]struct{}{}
rs.selectAll = true
return nil
}
var temp struct {
SelectedRoutes map[route.NetID]struct{} `json:"selected_routes"`
SelectAll bool `json:"select_all"`
}
if err := json.Unmarshal(data, &temp); err != nil {
return err
}
rs.selectedRoutes = temp.SelectedRoutes
rs.selectAll = temp.SelectAll
if rs.selectedRoutes == nil {
rs.selectedRoutes = map[route.NetID]struct{}{}
}
return nil
}

View File

@@ -16,39 +16,14 @@ import (
"golang.org/x/exp/maps"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/util"
)
const (
errStateNotRegistered = "state %s not registered"
errLoadStateFile = "load state file: %w"
)
// State interface defines the methods that all state types must implement
type State interface {
Name() string
}
// CleanableState interface extends State with cleanup capability
type CleanableState interface {
State
Cleanup() error
}
// RawState wraps raw JSON data for unregistered states
type RawState struct {
data json.RawMessage
}
func (r *RawState) Name() string {
return "" // This is a placeholder implementation
}
// MarshalJSON implements json.Marshaler to preserve the original JSON
func (r *RawState) MarshalJSON() ([]byte, error) {
return r.data, nil
}
// Manager handles the persistence and management of various states
type Manager struct {
mu sync.Mutex
@@ -98,15 +73,15 @@ func (m *Manager) Stop(ctx context.Context) error {
m.mu.Lock()
defer m.mu.Unlock()
if m.cancel == nil {
return nil
}
m.cancel()
if m.cancel != nil {
m.cancel()
select {
case <-ctx.Done():
return ctx.Err()
case <-m.done:
select {
case <-ctx.Done():
return ctx.Err()
case <-m.done:
return nil
}
}
return nil
@@ -164,7 +139,7 @@ func (m *Manager) setState(name string, state State) error {
defer m.mu.Unlock()
if _, exists := m.states[name]; !exists {
return fmt.Errorf(errStateNotRegistered, name)
return fmt.Errorf("state %s not registered", name)
}
m.states[name] = state
@@ -173,63 +148,6 @@ func (m *Manager) setState(name string, state State) error {
return nil
}
// DeleteStateByName handles deletion of states without cleanup.
// It doesn't require the state to be registered.
func (m *Manager) DeleteStateByName(stateName string) error {
if m == nil {
return nil
}
m.mu.Lock()
defer m.mu.Unlock()
rawStates, err := m.loadStateFile(false)
if err != nil {
return fmt.Errorf(errLoadStateFile, err)
}
if rawStates == nil {
return nil
}
if _, exists := rawStates[stateName]; !exists {
return fmt.Errorf("state %s not found", stateName)
}
// Mark state as deleted by setting it to nil and marking it dirty
m.states[stateName] = nil
m.dirty[stateName] = struct{}{}
return nil
}
// DeleteAllStates removes all states.
func (m *Manager) DeleteAllStates() (int, error) {
if m == nil {
return 0, nil
}
m.mu.Lock()
defer m.mu.Unlock()
rawStates, err := m.loadStateFile(false)
if err != nil {
return 0, fmt.Errorf(errLoadStateFile, err)
}
if rawStates == nil {
return 0, nil
}
count := len(rawStates)
// Mark all states as deleted and dirty
for name := range rawStates {
m.states[name] = nil
m.dirty[name] = struct{}{}
}
return count, nil
}
func (m *Manager) periodicStateSave(ctx context.Context) {
ticker := time.NewTicker(10 * time.Second)
defer ticker.Stop()
@@ -260,18 +178,25 @@ func (m *Manager) PersistState(ctx context.Context) error {
return nil
}
bs, err := marshalWithPanicRecovery(m.states)
if err != nil {
return fmt.Errorf("marshal states: %w", err)
}
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
defer cancel()
done := make(chan error, 1)
start := time.Now()
go func() {
done <- util.WriteBytesWithRestrictedPermission(ctx, m.filePath, bs)
data, err := json.MarshalIndent(m.states, "", " ")
if err != nil {
done <- fmt.Errorf("marshal states: %w", err)
return
}
// nolint:gosec
if err := os.WriteFile(m.filePath, data, 0640); err != nil {
done <- fmt.Errorf("write state file: %w", err)
return
}
done <- nil
}()
select {
@@ -283,175 +208,63 @@ func (m *Manager) PersistState(ctx context.Context) error {
}
}
log.Debugf("persisted states: %v, took %v", maps.Keys(m.dirty), time.Since(start))
log.Debugf("persisted shutdown states: %v", maps.Keys(m.dirty))
clear(m.dirty)
return nil
}
// loadStateFile reads and unmarshals the state file into a map of raw JSON messages
func (m *Manager) loadStateFile(deleteCorrupt bool) (map[string]json.RawMessage, error) {
// loadState loads the existing state from the state file
func (m *Manager) loadState() error {
data, err := os.ReadFile(m.filePath)
if err != nil {
if errors.Is(err, fs.ErrNotExist) {
log.Debug("state file does not exist")
return nil, nil // nolint:nilnil
return nil
}
return nil, fmt.Errorf("read state file: %w", err)
return fmt.Errorf("read state file: %w", err)
}
var rawStates map[string]json.RawMessage
if err := json.Unmarshal(data, &rawStates); err != nil {
if deleteCorrupt {
log.Warn("State file appears to be corrupted, attempting to delete it", err)
if err := os.Remove(m.filePath); err != nil {
log.Errorf("Failed to delete corrupted state file: %v", err)
} else {
log.Info("State file deleted")
}
log.Warn("State file appears to be corrupted, attempting to delete it")
if err := os.Remove(m.filePath); err != nil {
log.Errorf("Failed to delete corrupted state file: %v", err)
} else {
log.Info("State file deleted")
}
return nil, fmt.Errorf("unmarshal states: %w", err)
return fmt.Errorf("unmarshal states: %w", err)
}
return rawStates, nil
}
var merr *multierror.Error
// loadSingleRawState unmarshals a raw state into a concrete state object
func (m *Manager) loadSingleRawState(name string, rawState json.RawMessage) (State, error) {
stateType, ok := m.stateTypes[name]
if !ok {
return nil, fmt.Errorf(errStateNotRegistered, name)
}
for name, rawState := range rawStates {
stateType, ok := m.stateTypes[name]
if !ok {
merr = multierror.Append(merr, fmt.Errorf("unknown state type: %s", name))
continue
}
if string(rawState) == "null" {
return nil, nil //nolint:nilnil
}
if string(rawState) == "null" {
continue
}
statePtr := reflect.New(stateType).Interface().(State)
if err := json.Unmarshal(rawState, statePtr); err != nil {
return nil, fmt.Errorf("unmarshal state %s: %w", name, err)
}
statePtr := reflect.New(stateType).Interface().(State)
if err := json.Unmarshal(rawState, statePtr); err != nil {
merr = multierror.Append(merr, fmt.Errorf("unmarshal state %s: %w", name, err))
continue
}
return statePtr, nil
}
// LoadState loads a specific state from the state file
func (m *Manager) LoadState(state State) error {
if m == nil {
return nil
}
m.mu.Lock()
defer m.mu.Unlock()
rawStates, err := m.loadStateFile(false)
if err != nil {
return err
}
if rawStates == nil {
return nil
}
name := state.Name()
rawState, exists := rawStates[name]
if !exists {
return nil
}
loadedState, err := m.loadSingleRawState(name, rawState)
if err != nil {
return err
}
m.states[name] = loadedState
if loadedState != nil {
m.states[name] = statePtr
log.Debugf("loaded state: %s", name)
}
return nil
return nberrors.FormatErrorOrNil(merr)
}
// cleanupSingleState handles the cleanup of a specific state and returns any error.
// The caller must hold the mutex.
func (m *Manager) cleanupSingleState(name string, rawState json.RawMessage) error {
// For unregistered states, preserve the raw JSON
if _, registered := m.stateTypes[name]; !registered {
m.states[name] = &RawState{data: rawState}
return nil
}
// Load the state
loadedState, err := m.loadSingleRawState(name, rawState)
if err != nil {
return err
}
if loadedState == nil {
return nil
}
// Check if state supports cleanup
cleanableState, isCleanable := loadedState.(CleanableState)
if !isCleanable {
// If it doesn't support cleanup, keep it as-is
m.states[name] = loadedState
return nil
}
// Perform cleanup
log.Infof("cleaning up state %s", name)
if err := cleanableState.Cleanup(); err != nil {
// On cleanup error, preserve the state
m.states[name] = loadedState
return fmt.Errorf("cleanup state: %w", err)
}
// Successfully cleaned up - mark for deletion
m.states[name] = nil
m.dirty[name] = struct{}{}
return nil
}
// CleanupStateByName loads and cleans up a specific state by name if it implements CleanableState.
// Returns an error if the state doesn't exist, isn't registered, or cleanup fails.
func (m *Manager) CleanupStateByName(name string) error {
if m == nil {
return nil
}
m.mu.Lock()
defer m.mu.Unlock()
// Check if state is registered
if _, registered := m.stateTypes[name]; !registered {
return fmt.Errorf(errStateNotRegistered, name)
}
// Load raw states from file
rawStates, err := m.loadStateFile(false)
if err != nil {
return err
}
if rawStates == nil {
return nil
}
// Check if state exists in file
rawState, exists := rawStates[name]
if !exists {
return nil
}
if err := m.cleanupSingleState(name, rawState); err != nil {
return fmt.Errorf("%s: %w", name, err)
}
return nil
}
// PerformCleanup retrieves all states from the state file and calls Cleanup on registered states that support it.
// Unregistered states are preserved in their original state.
// PerformCleanup retrieves all states from the state file for the registered states and calls Cleanup on them.
// If the cleanup is successful, the state is marked for deletion.
func (m *Manager) PerformCleanup() error {
if m == nil {
return nil
@@ -460,63 +273,26 @@ func (m *Manager) PerformCleanup() error {
m.mu.Lock()
defer m.mu.Unlock()
// Load raw states from file
rawStates, err := m.loadStateFile(true)
if err != nil {
return fmt.Errorf(errLoadStateFile, err)
}
if rawStates == nil {
return nil
if err := m.loadState(); err != nil {
log.Warnf("Failed to load state during cleanup: %v", err)
}
var merr *multierror.Error
for name, state := range m.states {
if state == nil {
// If no state was found in the state file, we don't mark the state dirty nor return an error
continue
}
// Process each state in the file
for name, rawState := range rawStates {
if err := m.cleanupSingleState(name, rawState); err != nil {
merr = multierror.Append(merr, fmt.Errorf("%s: %w", name, err))
log.Infof("client was not shut down properly, cleaning up %s", name)
if err := state.Cleanup(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("cleanup state for %s: %w", name, err))
} else {
// mark for deletion on cleanup success
m.states[name] = nil
m.dirty[name] = struct{}{}
}
}
return nberrors.FormatErrorOrNil(merr)
}
// GetSavedStateNames returns all state names that are currently saved in the state file.
func (m *Manager) GetSavedStateNames() ([]string, error) {
if m == nil {
return nil, nil
}
rawStates, err := m.loadStateFile(false)
if err != nil {
return nil, fmt.Errorf(errLoadStateFile, err)
}
if rawStates == nil {
return nil, nil
}
var states []string
for name, state := range rawStates {
if len(state) != 0 && string(state) != "null" {
states = append(states, name)
}
}
return states, nil
}
func marshalWithPanicRecovery(v any) ([]byte, error) {
var bs []byte
var err error
func() {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("panic during marshal: %v", r)
}
}()
bs, err = json.Marshal(v)
}()
return bs, err
}

View File

@@ -4,20 +4,32 @@ import (
"os"
"path/filepath"
"runtime"
log "github.com/sirupsen/logrus"
)
// GetDefaultStatePath returns the path to the state file based on the operating system
// It returns an empty string if the path cannot be determined.
// It returns an empty string if the path cannot be determined. It also creates the directory if it does not exist.
func GetDefaultStatePath() string {
var path string
switch runtime.GOOS {
case "windows":
return filepath.Join(os.Getenv("PROGRAMDATA"), "Netbird", "state.json")
path = filepath.Join(os.Getenv("PROGRAMDATA"), "Netbird", "state.json")
case "darwin", "linux":
return "/var/lib/netbird/state.json"
path = "/var/lib/netbird/state.json"
case "freebsd", "openbsd", "netbsd", "dragonfly":
return "/var/db/netbird/state.json"
path = "/var/db/netbird/state.json"
// ios/android don't need state
default:
return ""
}
return ""
dir := filepath.Dir(path)
if err := os.MkdirAll(dir, 0755); err != nil {
log.Errorf("Error creating directory %s: %v. Continuing without state support.", dir, err)
return ""
}
return path
}

View File

@@ -1,7 +1,7 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.26.0
// protoc v4.23.4
// protoc v3.21.12
// source: daemon.proto
package proto
@@ -2103,434 +2103,6 @@ func (*SetLogLevelResponse) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{30}
}
// State represents a daemon state entry
type State struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"`
}
func (x *State) Reset() {
*x = State{}
if protoimpl.UnsafeEnabled {
mi := &file_daemon_proto_msgTypes[31]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *State) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*State) ProtoMessage() {}
func (x *State) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[31]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use State.ProtoReflect.Descriptor instead.
func (*State) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{31}
}
func (x *State) GetName() string {
if x != nil {
return x.Name
}
return ""
}
// ListStatesRequest is empty as it requires no parameters
type ListStatesRequest struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
}
func (x *ListStatesRequest) Reset() {
*x = ListStatesRequest{}
if protoimpl.UnsafeEnabled {
mi := &file_daemon_proto_msgTypes[32]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *ListStatesRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*ListStatesRequest) ProtoMessage() {}
func (x *ListStatesRequest) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[32]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use ListStatesRequest.ProtoReflect.Descriptor instead.
func (*ListStatesRequest) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{32}
}
// ListStatesResponse contains a list of states
type ListStatesResponse struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
States []*State `protobuf:"bytes,1,rep,name=states,proto3" json:"states,omitempty"`
}
func (x *ListStatesResponse) Reset() {
*x = ListStatesResponse{}
if protoimpl.UnsafeEnabled {
mi := &file_daemon_proto_msgTypes[33]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *ListStatesResponse) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*ListStatesResponse) ProtoMessage() {}
func (x *ListStatesResponse) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[33]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use ListStatesResponse.ProtoReflect.Descriptor instead.
func (*ListStatesResponse) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{33}
}
func (x *ListStatesResponse) GetStates() []*State {
if x != nil {
return x.States
}
return nil
}
// CleanStateRequest for cleaning states
type CleanStateRequest struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
StateName string `protobuf:"bytes,1,opt,name=state_name,json=stateName,proto3" json:"state_name,omitempty"`
All bool `protobuf:"varint,2,opt,name=all,proto3" json:"all,omitempty"`
}
func (x *CleanStateRequest) Reset() {
*x = CleanStateRequest{}
if protoimpl.UnsafeEnabled {
mi := &file_daemon_proto_msgTypes[34]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *CleanStateRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*CleanStateRequest) ProtoMessage() {}
func (x *CleanStateRequest) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[34]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use CleanStateRequest.ProtoReflect.Descriptor instead.
func (*CleanStateRequest) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{34}
}
func (x *CleanStateRequest) GetStateName() string {
if x != nil {
return x.StateName
}
return ""
}
func (x *CleanStateRequest) GetAll() bool {
if x != nil {
return x.All
}
return false
}
// CleanStateResponse contains the result of the clean operation
type CleanStateResponse struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
CleanedStates int32 `protobuf:"varint,1,opt,name=cleaned_states,json=cleanedStates,proto3" json:"cleaned_states,omitempty"`
}
func (x *CleanStateResponse) Reset() {
*x = CleanStateResponse{}
if protoimpl.UnsafeEnabled {
mi := &file_daemon_proto_msgTypes[35]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *CleanStateResponse) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*CleanStateResponse) ProtoMessage() {}
func (x *CleanStateResponse) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[35]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use CleanStateResponse.ProtoReflect.Descriptor instead.
func (*CleanStateResponse) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{35}
}
func (x *CleanStateResponse) GetCleanedStates() int32 {
if x != nil {
return x.CleanedStates
}
return 0
}
// DeleteStateRequest for deleting states
type DeleteStateRequest struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
StateName string `protobuf:"bytes,1,opt,name=state_name,json=stateName,proto3" json:"state_name,omitempty"`
All bool `protobuf:"varint,2,opt,name=all,proto3" json:"all,omitempty"`
}
func (x *DeleteStateRequest) Reset() {
*x = DeleteStateRequest{}
if protoimpl.UnsafeEnabled {
mi := &file_daemon_proto_msgTypes[36]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *DeleteStateRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*DeleteStateRequest) ProtoMessage() {}
func (x *DeleteStateRequest) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[36]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use DeleteStateRequest.ProtoReflect.Descriptor instead.
func (*DeleteStateRequest) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{36}
}
func (x *DeleteStateRequest) GetStateName() string {
if x != nil {
return x.StateName
}
return ""
}
func (x *DeleteStateRequest) GetAll() bool {
if x != nil {
return x.All
}
return false
}
// DeleteStateResponse contains the result of the delete operation
type DeleteStateResponse struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
DeletedStates int32 `protobuf:"varint,1,opt,name=deleted_states,json=deletedStates,proto3" json:"deleted_states,omitempty"`
}
func (x *DeleteStateResponse) Reset() {
*x = DeleteStateResponse{}
if protoimpl.UnsafeEnabled {
mi := &file_daemon_proto_msgTypes[37]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *DeleteStateResponse) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*DeleteStateResponse) ProtoMessage() {}
func (x *DeleteStateResponse) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[37]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use DeleteStateResponse.ProtoReflect.Descriptor instead.
func (*DeleteStateResponse) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{37}
}
func (x *DeleteStateResponse) GetDeletedStates() int32 {
if x != nil {
return x.DeletedStates
}
return 0
}
type SetNetworkMapPersistenceRequest struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Enabled bool `protobuf:"varint,1,opt,name=enabled,proto3" json:"enabled,omitempty"`
}
func (x *SetNetworkMapPersistenceRequest) Reset() {
*x = SetNetworkMapPersistenceRequest{}
if protoimpl.UnsafeEnabled {
mi := &file_daemon_proto_msgTypes[38]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *SetNetworkMapPersistenceRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*SetNetworkMapPersistenceRequest) ProtoMessage() {}
func (x *SetNetworkMapPersistenceRequest) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[38]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use SetNetworkMapPersistenceRequest.ProtoReflect.Descriptor instead.
func (*SetNetworkMapPersistenceRequest) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{38}
}
func (x *SetNetworkMapPersistenceRequest) GetEnabled() bool {
if x != nil {
return x.Enabled
}
return false
}
type SetNetworkMapPersistenceResponse struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
}
func (x *SetNetworkMapPersistenceResponse) Reset() {
*x = SetNetworkMapPersistenceResponse{}
if protoimpl.UnsafeEnabled {
mi := &file_daemon_proto_msgTypes[39]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *SetNetworkMapPersistenceResponse) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*SetNetworkMapPersistenceResponse) ProtoMessage() {}
func (x *SetNetworkMapPersistenceResponse) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[39]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use SetNetworkMapPersistenceResponse.ProtoReflect.Descriptor instead.
func (*SetNetworkMapPersistenceResponse) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{39}
}
var File_daemon_proto protoreflect.FileDescriptor
var file_daemon_proto_rawDesc = []byte{
@@ -2827,116 +2399,66 @@ var file_daemon_proto_rawDesc = []byte{
0x32, 0x10, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76,
0x65, 0x6c, 0x52, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x22, 0x15, 0x0a, 0x13, 0x53, 0x65, 0x74,
0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65,
0x22, 0x1b, 0x0a, 0x05, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d,
0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x22, 0x13, 0x0a,
0x11, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65,
0x73, 0x74, 0x22, 0x3b, 0x0a, 0x12, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73,
0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x25, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74,
0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x0d, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f,
0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x65, 0x73, 0x22,
0x44, 0x0a, 0x11, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71,
0x75, 0x65, 0x73, 0x74, 0x12, 0x1d, 0x0a, 0x0a, 0x73, 0x74, 0x61, 0x74, 0x65, 0x5f, 0x6e, 0x61,
0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x73, 0x74, 0x61, 0x74, 0x65, 0x4e,
0x61, 0x6d, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x61, 0x6c, 0x6c, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08,
0x52, 0x03, 0x61, 0x6c, 0x6c, 0x22, 0x3b, 0x0a, 0x12, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74,
0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x25, 0x0a, 0x0e, 0x63,
0x6c, 0x65, 0x61, 0x6e, 0x65, 0x64, 0x5f, 0x73, 0x74, 0x61, 0x74, 0x65, 0x73, 0x18, 0x01, 0x20,
0x01, 0x28, 0x05, 0x52, 0x0d, 0x63, 0x6c, 0x65, 0x61, 0x6e, 0x65, 0x64, 0x53, 0x74, 0x61, 0x74,
0x65, 0x73, 0x22, 0x45, 0x0a, 0x12, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53, 0x74, 0x61, 0x74,
0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1d, 0x0a, 0x0a, 0x73, 0x74, 0x61, 0x74,
0x65, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x73, 0x74,
0x61, 0x74, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x61, 0x6c, 0x6c, 0x18, 0x02,
0x20, 0x01, 0x28, 0x08, 0x52, 0x03, 0x61, 0x6c, 0x6c, 0x22, 0x3c, 0x0a, 0x13, 0x44, 0x65, 0x6c,
0x65, 0x74, 0x65, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65,
0x12, 0x25, 0x0a, 0x0e, 0x64, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x64, 0x5f, 0x73, 0x74, 0x61, 0x74,
0x65, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0d, 0x64, 0x65, 0x6c, 0x65, 0x74, 0x65,
0x64, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x22, 0x3b, 0x0a, 0x1f, 0x53, 0x65, 0x74, 0x4e, 0x65,
0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65,
0x6e, 0x63, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x18, 0x0a, 0x07, 0x65, 0x6e,
0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x65, 0x6e, 0x61,
0x62, 0x6c, 0x65, 0x64, 0x22, 0x22, 0x0a, 0x20, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f,
0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65,
0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x2a, 0x62, 0x0a, 0x08, 0x4c, 0x6f, 0x67, 0x4c,
0x65, 0x76, 0x65, 0x6c, 0x12, 0x0b, 0x0a, 0x07, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10,
0x00, 0x12, 0x09, 0x0a, 0x05, 0x50, 0x41, 0x4e, 0x49, 0x43, 0x10, 0x01, 0x12, 0x09, 0x0a, 0x05,
0x46, 0x41, 0x54, 0x41, 0x4c, 0x10, 0x02, 0x12, 0x09, 0x0a, 0x05, 0x45, 0x52, 0x52, 0x4f, 0x52,
0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x57, 0x41, 0x52, 0x4e, 0x10, 0x04, 0x12, 0x08, 0x0a, 0x04,
0x49, 0x4e, 0x46, 0x4f, 0x10, 0x05, 0x12, 0x09, 0x0a, 0x05, 0x44, 0x45, 0x42, 0x55, 0x47, 0x10,
0x06, 0x12, 0x09, 0x0a, 0x05, 0x54, 0x52, 0x41, 0x43, 0x45, 0x10, 0x07, 0x32, 0x81, 0x09, 0x0a,
0x0d, 0x44, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x36,
0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e,
0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x15, 0x2e,
0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70,
0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53,
0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e,
0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75,
0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69,
0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73,
0x65, 0x22, 0x00, 0x12, 0x2d, 0x0a, 0x02, 0x55, 0x70, 0x12, 0x11, 0x2e, 0x64, 0x61, 0x65, 0x6d,
0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x12, 0x2e, 0x64,
0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65,
0x22, 0x00, 0x12, 0x39, 0x0a, 0x06, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x15, 0x2e, 0x64,
0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x71, 0x75,
0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61,
0x74, 0x75, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a,
0x04, 0x44, 0x6f, 0x77, 0x6e, 0x12, 0x13, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44,
0x6f, 0x77, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x14, 0x2e, 0x64, 0x61, 0x65,
0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65,
0x22, 0x00, 0x12, 0x42, 0x0a, 0x09, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12,
0x18, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66,
0x69, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d,
0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x73, 0x70,
0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x45, 0x0a, 0x0a, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f,
0x75, 0x74, 0x65, 0x73, 0x12, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69,
0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a,
0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75,
0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a,
0x0c, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x1b, 0x2e,
0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75,
0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65,
0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73,
0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4d, 0x0a, 0x0e, 0x44, 0x65,
0x73, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x1b, 0x2e, 0x64,
0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74,
0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d,
0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52,
0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x44, 0x65, 0x62,
0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f,
0x6e, 0x2e, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x71,
0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65,
0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73,
0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76,
0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x4c,
0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b,
0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65,
0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a,
0x0b, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64,
0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65,
0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f,
0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73,
0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x45, 0x0a, 0x0a, 0x4c, 0x69, 0x73, 0x74, 0x53,
0x74, 0x61, 0x74, 0x65, 0x73, 0x12, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c,
0x69, 0x73, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74,
0x1a, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74,
0x61, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x45,
0x0a, 0x0a, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x19, 0x2e, 0x64,
0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65,
0x2a, 0x62, 0x0a, 0x08, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x0b, 0x0a, 0x07,
0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x09, 0x0a, 0x05, 0x50, 0x41, 0x4e,
0x49, 0x43, 0x10, 0x01, 0x12, 0x09, 0x0a, 0x05, 0x46, 0x41, 0x54, 0x41, 0x4c, 0x10, 0x02, 0x12,
0x09, 0x0a, 0x05, 0x45, 0x52, 0x52, 0x4f, 0x52, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x57, 0x41,
0x52, 0x4e, 0x10, 0x04, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x4e, 0x46, 0x4f, 0x10, 0x05, 0x12, 0x09,
0x0a, 0x05, 0x44, 0x45, 0x42, 0x55, 0x47, 0x10, 0x06, 0x12, 0x09, 0x0a, 0x05, 0x54, 0x52, 0x41,
0x43, 0x45, 0x10, 0x07, 0x32, 0xb8, 0x06, 0x0a, 0x0d, 0x44, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x53,
0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x36, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12,
0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65,
0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c,
0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b,
0x0a, 0x0c, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1b,
0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c,
0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61,
0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69,
0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x2d, 0x0a, 0x02, 0x55,
0x70, 0x12, 0x11, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x71,
0x75, 0x65, 0x73, 0x74, 0x1a, 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70,
0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x39, 0x0a, 0x06, 0x53, 0x74,
0x61, 0x74, 0x75, 0x73, 0x12, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74,
0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x64, 0x61,
0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f,
0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, 0x04, 0x44, 0x6f, 0x77, 0x6e, 0x12, 0x13, 0x2e,
0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65,
0x73, 0x74, 0x1a, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e,
0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x42, 0x0a, 0x09, 0x47, 0x65,
0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x18, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e,
0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73,
0x74, 0x1a, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f,
0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x45,
0x0a, 0x0a, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x19, 0x2e, 0x64,
0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73,
0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e,
0x2e, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f,
0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53,
0x74, 0x61, 0x74, 0x65, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65,
0x6c, 0x65, 0x74, 0x65, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74,
0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65,
0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12,
0x6f, 0x0a, 0x18, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70,
0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x12, 0x27, 0x2e, 0x64, 0x61,
0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d,
0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x71,
0x75, 0x65, 0x73, 0x74, 0x1a, 0x28, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65,
0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69,
0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00,
0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74,
0x6f, 0x33,
0x2e, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f,
0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52,
0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53,
0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65,
0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65,
0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65,
0x22, 0x00, 0x12, 0x4d, 0x0a, 0x0e, 0x44, 0x65, 0x73, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f,
0x75, 0x74, 0x65, 0x73, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65,
0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73,
0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63,
0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22,
0x00, 0x12, 0x48, 0x0a, 0x0b, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65,
0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42,
0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64,
0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c,
0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x47,
0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65,
0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52,
0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e,
0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f,
0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c,
0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65,
0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74,
0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67,
0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x42,
0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f,
0x33,
}
var (
@@ -2952,59 +2474,50 @@ func file_daemon_proto_rawDescGZIP() []byte {
}
var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 1)
var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 41)
var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 32)
var file_daemon_proto_goTypes = []interface{}{
(LogLevel)(0), // 0: daemon.LogLevel
(*LoginRequest)(nil), // 1: daemon.LoginRequest
(*LoginResponse)(nil), // 2: daemon.LoginResponse
(*WaitSSOLoginRequest)(nil), // 3: daemon.WaitSSOLoginRequest
(*WaitSSOLoginResponse)(nil), // 4: daemon.WaitSSOLoginResponse
(*UpRequest)(nil), // 5: daemon.UpRequest
(*UpResponse)(nil), // 6: daemon.UpResponse
(*StatusRequest)(nil), // 7: daemon.StatusRequest
(*StatusResponse)(nil), // 8: daemon.StatusResponse
(*DownRequest)(nil), // 9: daemon.DownRequest
(*DownResponse)(nil), // 10: daemon.DownResponse
(*GetConfigRequest)(nil), // 11: daemon.GetConfigRequest
(*GetConfigResponse)(nil), // 12: daemon.GetConfigResponse
(*PeerState)(nil), // 13: daemon.PeerState
(*LocalPeerState)(nil), // 14: daemon.LocalPeerState
(*SignalState)(nil), // 15: daemon.SignalState
(*ManagementState)(nil), // 16: daemon.ManagementState
(*RelayState)(nil), // 17: daemon.RelayState
(*NSGroupState)(nil), // 18: daemon.NSGroupState
(*FullStatus)(nil), // 19: daemon.FullStatus
(*ListRoutesRequest)(nil), // 20: daemon.ListRoutesRequest
(*ListRoutesResponse)(nil), // 21: daemon.ListRoutesResponse
(*SelectRoutesRequest)(nil), // 22: daemon.SelectRoutesRequest
(*SelectRoutesResponse)(nil), // 23: daemon.SelectRoutesResponse
(*IPList)(nil), // 24: daemon.IPList
(*Route)(nil), // 25: daemon.Route
(*DebugBundleRequest)(nil), // 26: daemon.DebugBundleRequest
(*DebugBundleResponse)(nil), // 27: daemon.DebugBundleResponse
(*GetLogLevelRequest)(nil), // 28: daemon.GetLogLevelRequest
(*GetLogLevelResponse)(nil), // 29: daemon.GetLogLevelResponse
(*SetLogLevelRequest)(nil), // 30: daemon.SetLogLevelRequest
(*SetLogLevelResponse)(nil), // 31: daemon.SetLogLevelResponse
(*State)(nil), // 32: daemon.State
(*ListStatesRequest)(nil), // 33: daemon.ListStatesRequest
(*ListStatesResponse)(nil), // 34: daemon.ListStatesResponse
(*CleanStateRequest)(nil), // 35: daemon.CleanStateRequest
(*CleanStateResponse)(nil), // 36: daemon.CleanStateResponse
(*DeleteStateRequest)(nil), // 37: daemon.DeleteStateRequest
(*DeleteStateResponse)(nil), // 38: daemon.DeleteStateResponse
(*SetNetworkMapPersistenceRequest)(nil), // 39: daemon.SetNetworkMapPersistenceRequest
(*SetNetworkMapPersistenceResponse)(nil), // 40: daemon.SetNetworkMapPersistenceResponse
nil, // 41: daemon.Route.ResolvedIPsEntry
(*durationpb.Duration)(nil), // 42: google.protobuf.Duration
(*timestamppb.Timestamp)(nil), // 43: google.protobuf.Timestamp
(LogLevel)(0), // 0: daemon.LogLevel
(*LoginRequest)(nil), // 1: daemon.LoginRequest
(*LoginResponse)(nil), // 2: daemon.LoginResponse
(*WaitSSOLoginRequest)(nil), // 3: daemon.WaitSSOLoginRequest
(*WaitSSOLoginResponse)(nil), // 4: daemon.WaitSSOLoginResponse
(*UpRequest)(nil), // 5: daemon.UpRequest
(*UpResponse)(nil), // 6: daemon.UpResponse
(*StatusRequest)(nil), // 7: daemon.StatusRequest
(*StatusResponse)(nil), // 8: daemon.StatusResponse
(*DownRequest)(nil), // 9: daemon.DownRequest
(*DownResponse)(nil), // 10: daemon.DownResponse
(*GetConfigRequest)(nil), // 11: daemon.GetConfigRequest
(*GetConfigResponse)(nil), // 12: daemon.GetConfigResponse
(*PeerState)(nil), // 13: daemon.PeerState
(*LocalPeerState)(nil), // 14: daemon.LocalPeerState
(*SignalState)(nil), // 15: daemon.SignalState
(*ManagementState)(nil), // 16: daemon.ManagementState
(*RelayState)(nil), // 17: daemon.RelayState
(*NSGroupState)(nil), // 18: daemon.NSGroupState
(*FullStatus)(nil), // 19: daemon.FullStatus
(*ListRoutesRequest)(nil), // 20: daemon.ListRoutesRequest
(*ListRoutesResponse)(nil), // 21: daemon.ListRoutesResponse
(*SelectRoutesRequest)(nil), // 22: daemon.SelectRoutesRequest
(*SelectRoutesResponse)(nil), // 23: daemon.SelectRoutesResponse
(*IPList)(nil), // 24: daemon.IPList
(*Route)(nil), // 25: daemon.Route
(*DebugBundleRequest)(nil), // 26: daemon.DebugBundleRequest
(*DebugBundleResponse)(nil), // 27: daemon.DebugBundleResponse
(*GetLogLevelRequest)(nil), // 28: daemon.GetLogLevelRequest
(*GetLogLevelResponse)(nil), // 29: daemon.GetLogLevelResponse
(*SetLogLevelRequest)(nil), // 30: daemon.SetLogLevelRequest
(*SetLogLevelResponse)(nil), // 31: daemon.SetLogLevelResponse
nil, // 32: daemon.Route.ResolvedIPsEntry
(*durationpb.Duration)(nil), // 33: google.protobuf.Duration
(*timestamppb.Timestamp)(nil), // 34: google.protobuf.Timestamp
}
var file_daemon_proto_depIdxs = []int32{
42, // 0: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
33, // 0: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
19, // 1: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus
43, // 2: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp
43, // 3: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp
42, // 4: daemon.PeerState.latency:type_name -> google.protobuf.Duration
34, // 2: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp
34, // 3: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp
33, // 4: daemon.PeerState.latency:type_name -> google.protobuf.Duration
16, // 5: daemon.FullStatus.managementState:type_name -> daemon.ManagementState
15, // 6: daemon.FullStatus.signalState:type_name -> daemon.SignalState
14, // 7: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState
@@ -3012,48 +2525,39 @@ var file_daemon_proto_depIdxs = []int32{
17, // 9: daemon.FullStatus.relays:type_name -> daemon.RelayState
18, // 10: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState
25, // 11: daemon.ListRoutesResponse.routes:type_name -> daemon.Route
41, // 12: daemon.Route.resolvedIPs:type_name -> daemon.Route.ResolvedIPsEntry
32, // 12: daemon.Route.resolvedIPs:type_name -> daemon.Route.ResolvedIPsEntry
0, // 13: daemon.GetLogLevelResponse.level:type_name -> daemon.LogLevel
0, // 14: daemon.SetLogLevelRequest.level:type_name -> daemon.LogLevel
32, // 15: daemon.ListStatesResponse.states:type_name -> daemon.State
24, // 16: daemon.Route.ResolvedIPsEntry.value:type_name -> daemon.IPList
1, // 17: daemon.DaemonService.Login:input_type -> daemon.LoginRequest
3, // 18: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest
5, // 19: daemon.DaemonService.Up:input_type -> daemon.UpRequest
7, // 20: daemon.DaemonService.Status:input_type -> daemon.StatusRequest
9, // 21: daemon.DaemonService.Down:input_type -> daemon.DownRequest
11, // 22: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest
20, // 23: daemon.DaemonService.ListRoutes:input_type -> daemon.ListRoutesRequest
22, // 24: daemon.DaemonService.SelectRoutes:input_type -> daemon.SelectRoutesRequest
22, // 25: daemon.DaemonService.DeselectRoutes:input_type -> daemon.SelectRoutesRequest
26, // 26: daemon.DaemonService.DebugBundle:input_type -> daemon.DebugBundleRequest
28, // 27: daemon.DaemonService.GetLogLevel:input_type -> daemon.GetLogLevelRequest
30, // 28: daemon.DaemonService.SetLogLevel:input_type -> daemon.SetLogLevelRequest
33, // 29: daemon.DaemonService.ListStates:input_type -> daemon.ListStatesRequest
35, // 30: daemon.DaemonService.CleanState:input_type -> daemon.CleanStateRequest
37, // 31: daemon.DaemonService.DeleteState:input_type -> daemon.DeleteStateRequest
39, // 32: daemon.DaemonService.SetNetworkMapPersistence:input_type -> daemon.SetNetworkMapPersistenceRequest
2, // 33: daemon.DaemonService.Login:output_type -> daemon.LoginResponse
4, // 34: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse
6, // 35: daemon.DaemonService.Up:output_type -> daemon.UpResponse
8, // 36: daemon.DaemonService.Status:output_type -> daemon.StatusResponse
10, // 37: daemon.DaemonService.Down:output_type -> daemon.DownResponse
12, // 38: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse
21, // 39: daemon.DaemonService.ListRoutes:output_type -> daemon.ListRoutesResponse
23, // 40: daemon.DaemonService.SelectRoutes:output_type -> daemon.SelectRoutesResponse
23, // 41: daemon.DaemonService.DeselectRoutes:output_type -> daemon.SelectRoutesResponse
27, // 42: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse
29, // 43: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse
31, // 44: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse
34, // 45: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse
36, // 46: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse
38, // 47: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse
40, // 48: daemon.DaemonService.SetNetworkMapPersistence:output_type -> daemon.SetNetworkMapPersistenceResponse
33, // [33:49] is the sub-list for method output_type
17, // [17:33] is the sub-list for method input_type
17, // [17:17] is the sub-list for extension type_name
17, // [17:17] is the sub-list for extension extendee
0, // [0:17] is the sub-list for field type_name
24, // 15: daemon.Route.ResolvedIPsEntry.value:type_name -> daemon.IPList
1, // 16: daemon.DaemonService.Login:input_type -> daemon.LoginRequest
3, // 17: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest
5, // 18: daemon.DaemonService.Up:input_type -> daemon.UpRequest
7, // 19: daemon.DaemonService.Status:input_type -> daemon.StatusRequest
9, // 20: daemon.DaemonService.Down:input_type -> daemon.DownRequest
11, // 21: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest
20, // 22: daemon.DaemonService.ListRoutes:input_type -> daemon.ListRoutesRequest
22, // 23: daemon.DaemonService.SelectRoutes:input_type -> daemon.SelectRoutesRequest
22, // 24: daemon.DaemonService.DeselectRoutes:input_type -> daemon.SelectRoutesRequest
26, // 25: daemon.DaemonService.DebugBundle:input_type -> daemon.DebugBundleRequest
28, // 26: daemon.DaemonService.GetLogLevel:input_type -> daemon.GetLogLevelRequest
30, // 27: daemon.DaemonService.SetLogLevel:input_type -> daemon.SetLogLevelRequest
2, // 28: daemon.DaemonService.Login:output_type -> daemon.LoginResponse
4, // 29: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse
6, // 30: daemon.DaemonService.Up:output_type -> daemon.UpResponse
8, // 31: daemon.DaemonService.Status:output_type -> daemon.StatusResponse
10, // 32: daemon.DaemonService.Down:output_type -> daemon.DownResponse
12, // 33: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse
21, // 34: daemon.DaemonService.ListRoutes:output_type -> daemon.ListRoutesResponse
23, // 35: daemon.DaemonService.SelectRoutes:output_type -> daemon.SelectRoutesResponse
23, // 36: daemon.DaemonService.DeselectRoutes:output_type -> daemon.SelectRoutesResponse
27, // 37: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse
29, // 38: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse
31, // 39: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse
28, // [28:40] is the sub-list for method output_type
16, // [16:28] is the sub-list for method input_type
16, // [16:16] is the sub-list for extension type_name
16, // [16:16] is the sub-list for extension extendee
0, // [0:16] is the sub-list for field type_name
}
func init() { file_daemon_proto_init() }
@@ -3434,114 +2938,6 @@ func file_daemon_proto_init() {
return nil
}
}
file_daemon_proto_msgTypes[31].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*State); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_daemon_proto_msgTypes[32].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*ListStatesRequest); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_daemon_proto_msgTypes[33].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*ListStatesResponse); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_daemon_proto_msgTypes[34].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*CleanStateRequest); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_daemon_proto_msgTypes[35].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*CleanStateResponse); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_daemon_proto_msgTypes[36].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*DeleteStateRequest); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_daemon_proto_msgTypes[37].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*DeleteStateResponse); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_daemon_proto_msgTypes[38].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*SetNetworkMapPersistenceRequest); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_daemon_proto_msgTypes[39].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*SetNetworkMapPersistenceResponse); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
}
file_daemon_proto_msgTypes[0].OneofWrappers = []interface{}{}
type x struct{}
@@ -3550,7 +2946,7 @@ func file_daemon_proto_init() {
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_daemon_proto_rawDesc,
NumEnums: 1,
NumMessages: 41,
NumMessages: 32,
NumExtensions: 0,
NumServices: 1,
},

View File

@@ -45,20 +45,7 @@ service DaemonService {
// SetLogLevel sets the log level of the daemon
rpc SetLogLevel(SetLogLevelRequest) returns (SetLogLevelResponse) {}
// List all states
rpc ListStates(ListStatesRequest) returns (ListStatesResponse) {}
// Clean specific state or all states
rpc CleanState(CleanStateRequest) returns (CleanStateResponse) {}
// Delete specific state or all states
rpc DeleteState(DeleteStateRequest) returns (DeleteStateResponse) {}
// SetNetworkMapPersistence enables or disables network map persistence
rpc SetNetworkMapPersistence(SetNetworkMapPersistenceRequest) returns (SetNetworkMapPersistenceResponse) {}
}
};
message LoginRequest {
// setupKey wiretrustee setup key.
@@ -306,46 +293,4 @@ message SetLogLevelRequest {
}
message SetLogLevelResponse {
}
// State represents a daemon state entry
message State {
string name = 1;
}
// ListStatesRequest is empty as it requires no parameters
message ListStatesRequest {}
// ListStatesResponse contains a list of states
message ListStatesResponse {
repeated State states = 1;
}
// CleanStateRequest for cleaning states
message CleanStateRequest {
string state_name = 1;
bool all = 2;
}
// CleanStateResponse contains the result of the clean operation
message CleanStateResponse {
int32 cleaned_states = 1;
}
// DeleteStateRequest for deleting states
message DeleteStateRequest {
string state_name = 1;
bool all = 2;
}
// DeleteStateResponse contains the result of the delete operation
message DeleteStateResponse {
int32 deleted_states = 1;
}
message SetNetworkMapPersistenceRequest {
bool enabled = 1;
}
message SetNetworkMapPersistenceResponse {}
}

View File

@@ -43,14 +43,6 @@ type DaemonServiceClient interface {
GetLogLevel(ctx context.Context, in *GetLogLevelRequest, opts ...grpc.CallOption) (*GetLogLevelResponse, error)
// SetLogLevel sets the log level of the daemon
SetLogLevel(ctx context.Context, in *SetLogLevelRequest, opts ...grpc.CallOption) (*SetLogLevelResponse, error)
// List all states
ListStates(ctx context.Context, in *ListStatesRequest, opts ...grpc.CallOption) (*ListStatesResponse, error)
// Clean specific state or all states
CleanState(ctx context.Context, in *CleanStateRequest, opts ...grpc.CallOption) (*CleanStateResponse, error)
// Delete specific state or all states
DeleteState(ctx context.Context, in *DeleteStateRequest, opts ...grpc.CallOption) (*DeleteStateResponse, error)
// SetNetworkMapPersistence enables or disables network map persistence
SetNetworkMapPersistence(ctx context.Context, in *SetNetworkMapPersistenceRequest, opts ...grpc.CallOption) (*SetNetworkMapPersistenceResponse, error)
}
type daemonServiceClient struct {
@@ -169,42 +161,6 @@ func (c *daemonServiceClient) SetLogLevel(ctx context.Context, in *SetLogLevelRe
return out, nil
}
func (c *daemonServiceClient) ListStates(ctx context.Context, in *ListStatesRequest, opts ...grpc.CallOption) (*ListStatesResponse, error) {
out := new(ListStatesResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/ListStates", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *daemonServiceClient) CleanState(ctx context.Context, in *CleanStateRequest, opts ...grpc.CallOption) (*CleanStateResponse, error) {
out := new(CleanStateResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/CleanState", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *daemonServiceClient) DeleteState(ctx context.Context, in *DeleteStateRequest, opts ...grpc.CallOption) (*DeleteStateResponse, error) {
out := new(DeleteStateResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/DeleteState", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *daemonServiceClient) SetNetworkMapPersistence(ctx context.Context, in *SetNetworkMapPersistenceRequest, opts ...grpc.CallOption) (*SetNetworkMapPersistenceResponse, error) {
out := new(SetNetworkMapPersistenceResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/SetNetworkMapPersistence", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// DaemonServiceServer is the server API for DaemonService service.
// All implementations must embed UnimplementedDaemonServiceServer
// for forward compatibility
@@ -234,14 +190,6 @@ type DaemonServiceServer interface {
GetLogLevel(context.Context, *GetLogLevelRequest) (*GetLogLevelResponse, error)
// SetLogLevel sets the log level of the daemon
SetLogLevel(context.Context, *SetLogLevelRequest) (*SetLogLevelResponse, error)
// List all states
ListStates(context.Context, *ListStatesRequest) (*ListStatesResponse, error)
// Clean specific state or all states
CleanState(context.Context, *CleanStateRequest) (*CleanStateResponse, error)
// Delete specific state or all states
DeleteState(context.Context, *DeleteStateRequest) (*DeleteStateResponse, error)
// SetNetworkMapPersistence enables or disables network map persistence
SetNetworkMapPersistence(context.Context, *SetNetworkMapPersistenceRequest) (*SetNetworkMapPersistenceResponse, error)
mustEmbedUnimplementedDaemonServiceServer()
}
@@ -285,18 +233,6 @@ func (UnimplementedDaemonServiceServer) GetLogLevel(context.Context, *GetLogLeve
func (UnimplementedDaemonServiceServer) SetLogLevel(context.Context, *SetLogLevelRequest) (*SetLogLevelResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method SetLogLevel not implemented")
}
func (UnimplementedDaemonServiceServer) ListStates(context.Context, *ListStatesRequest) (*ListStatesResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method ListStates not implemented")
}
func (UnimplementedDaemonServiceServer) CleanState(context.Context, *CleanStateRequest) (*CleanStateResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method CleanState not implemented")
}
func (UnimplementedDaemonServiceServer) DeleteState(context.Context, *DeleteStateRequest) (*DeleteStateResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method DeleteState not implemented")
}
func (UnimplementedDaemonServiceServer) SetNetworkMapPersistence(context.Context, *SetNetworkMapPersistenceRequest) (*SetNetworkMapPersistenceResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method SetNetworkMapPersistence not implemented")
}
func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {}
// UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service.
@@ -526,78 +462,6 @@ func _DaemonService_SetLogLevel_Handler(srv interface{}, ctx context.Context, de
return interceptor(ctx, in, info, handler)
}
func _DaemonService_ListStates_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(ListStatesRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DaemonServiceServer).ListStates(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/daemon.DaemonService/ListStates",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).ListStates(ctx, req.(*ListStatesRequest))
}
return interceptor(ctx, in, info, handler)
}
func _DaemonService_CleanState_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(CleanStateRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DaemonServiceServer).CleanState(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/daemon.DaemonService/CleanState",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).CleanState(ctx, req.(*CleanStateRequest))
}
return interceptor(ctx, in, info, handler)
}
func _DaemonService_DeleteState_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(DeleteStateRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DaemonServiceServer).DeleteState(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/daemon.DaemonService/DeleteState",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).DeleteState(ctx, req.(*DeleteStateRequest))
}
return interceptor(ctx, in, info, handler)
}
func _DaemonService_SetNetworkMapPersistence_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(SetNetworkMapPersistenceRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DaemonServiceServer).SetNetworkMapPersistence(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/daemon.DaemonService/SetNetworkMapPersistence",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).SetNetworkMapPersistence(ctx, req.(*SetNetworkMapPersistenceRequest))
}
return interceptor(ctx, in, info, handler)
}
// DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
@@ -653,22 +517,6 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{
MethodName: "SetLogLevel",
Handler: _DaemonService_SetLogLevel_Handler,
},
{
MethodName: "ListStates",
Handler: _DaemonService_ListStates_Handler,
},
{
MethodName: "CleanState",
Handler: _DaemonService_CleanState_Handler,
},
{
MethodName: "DeleteState",
Handler: _DaemonService_DeleteState_Handler,
},
{
MethodName: "SetNetworkMapPersistence",
Handler: _DaemonService_SetNetworkMapPersistence_Handler,
},
},
Streams: []grpc.StreamDesc{},
Metadata: "daemon.proto",

View File

@@ -5,44 +5,32 @@ package server
import (
"archive/zip"
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"io/fs"
"net"
"net/netip"
"os"
"path/filepath"
"sort"
"strings"
"time"
log "github.com/sirupsen/logrus"
"google.golang.org/protobuf/encoding/protojson"
"github.com/netbirdio/netbird/client/anonymize"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/proto"
mgmProto "github.com/netbirdio/netbird/management/proto"
)
const readmeContent = `Netbird debug bundle
This debug bundle contains the following files:
status.txt: Anonymized status information of the NetBird client.
client.log: Most recent, anonymized client log file of the NetBird client.
netbird.err: Most recent, anonymized stderr log file of the NetBird client.
netbird.out: Most recent, anonymized stdout log file of the NetBird client.
client.log: Most recent, anonymized log file of the NetBird client.
routes.txt: Anonymized system routes, if --system-info flag was provided.
interfaces.txt: Anonymized network interface information, if --system-info flag was provided.
config.txt: Anonymized configuration information of the NetBird client.
network_map.json: Anonymized network map containing peer configurations, routes, DNS settings, and firewall rules.
state.json: Anonymized client state dump containing netbird states.
Anonymization Process
@@ -62,32 +50,8 @@ Domains
All domain names (except for the netbird domains) are replaced with randomly generated strings ending in ".domain". Anonymized domains are consistent across all files in the bundle.
Reoccuring domain names are replaced with the same anonymized domain.
Network Map
The network_map.json file contains the following anonymized information:
- Peer configurations (addresses, FQDNs, DNS settings)
- Remote and offline peer information (allowed IPs, FQDNs)
- Routes (network ranges, associated domains)
- DNS configuration (nameservers, domains, custom zones)
- Firewall rules (peer IPs, source/destination ranges)
SSH keys in the network map are replaced with a placeholder value. All IP addresses and domains in the network map follow the same anonymization rules as described above.
State File
The state.json file contains anonymized internal state information of the NetBird client, including:
- DNS settings and configuration
- Firewall rules
- Exclusion routes
- Route selection
- Other internal states that may be present
The state file follows the same anonymization rules as other files:
- IP addresses (both individual and CIDR ranges) are anonymized while preserving their structure
- Domain names are consistently anonymized
- Technical identifiers and non-sensitive data remain unchanged
Routes
For anonymized routes, the IP addresses are replaced as described above. The prefix length remains unchanged. Note that for prefixes, the anonymized IP might not be a network address, but the prefix length is still correct.
Network Interfaces
The interfaces.txt file contains information about network interfaces, including:
- Interface name
@@ -108,12 +72,6 @@ The config.txt file contains anonymized configuration information of the NetBird
Other non-sensitive configuration options are included without anonymization.
`
const (
clientLogFile = "client.log"
errorLogFile = "netbird.err"
stdoutLogFile = "netbird.out"
)
// DebugBundle creates a debug bundle and returns the location.
func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (resp *proto.DebugBundleResponse, err error) {
s.mutex.Lock()
@@ -161,27 +119,19 @@ func (s *Server) createArchive(bundlePath *os.File, req *proto.DebugBundleReques
seedFromStatus(anonymizer, &status)
if err := s.addConfig(req, anonymizer, archive); err != nil {
log.Errorf("Failed to add config to debug bundle: %v", err)
return fmt.Errorf("add config: %w", err)
}
if req.GetSystemInfo() {
if err := s.addRoutes(req, anonymizer, archive); err != nil {
log.Errorf("Failed to add routes to debug bundle: %v", err)
return fmt.Errorf("add routes: %w", err)
}
if err := s.addInterfaces(req, anonymizer, archive); err != nil {
log.Errorf("Failed to add interfaces to debug bundle: %v", err)
return fmt.Errorf("add interfaces: %w", err)
}
}
if err := s.addNetworkMap(req, anonymizer, archive); err != nil {
return fmt.Errorf("add network map: %w", err)
}
if err := s.addStateFile(req, anonymizer, archive); err != nil {
log.Errorf("Failed to add state file to debug bundle: %v", err)
}
if err := s.addLogfile(req, anonymizer, archive); err != nil {
return fmt.Errorf("add log file: %w", err)
}
@@ -270,16 +220,15 @@ func (s *Server) addCommonConfigFields(configContent *strings.Builder) {
}
func (s *Server) addRoutes(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error {
routes, err := systemops.GetRoutesFromTable()
if err != nil {
return fmt.Errorf("get routes: %w", err)
}
// TODO: get routes including nexthop
routesContent := formatRoutes(routes, req.GetAnonymize(), anonymizer)
routesReader := strings.NewReader(routesContent)
if err := addFileToZip(archive, routesReader, "routes.txt"); err != nil {
return fmt.Errorf("add routes file to zip: %w", err)
if routes, err := systemops.GetRoutesFromTable(); err != nil {
log.Errorf("Failed to get routes: %v", err)
} else {
// TODO: get routes including nexthop
routesContent := formatRoutes(routes, req.GetAnonymize(), anonymizer)
routesReader := strings.NewReader(routesContent)
if err := addFileToZip(archive, routesReader, "routes.txt"); err != nil {
return fmt.Errorf("add routes file to zip: %w", err)
}
}
return nil
}
@@ -299,106 +248,14 @@ func (s *Server) addInterfaces(req *proto.DebugBundleRequest, anonymizer *anonym
return nil
}
func (s *Server) addNetworkMap(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error {
networkMap, err := s.getLatestNetworkMap()
func (s *Server) addLogfile(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) (err error) {
logFile, err := os.Open(s.logFile)
if err != nil {
// Skip if network map is not available, but log it
log.Debugf("skipping empty network map in debug bundle: %v", err)
return nil
}
if req.GetAnonymize() {
if err := anonymizeNetworkMap(networkMap, anonymizer); err != nil {
return fmt.Errorf("anonymize network map: %w", err)
}
}
options := protojson.MarshalOptions{
EmitUnpopulated: true,
UseProtoNames: true,
Indent: " ",
AllowPartial: true,
}
jsonBytes, err := options.Marshal(networkMap)
if err != nil {
return fmt.Errorf("generate json: %w", err)
}
if err := addFileToZip(archive, bytes.NewReader(jsonBytes), "network_map.json"); err != nil {
return fmt.Errorf("add network map to zip: %w", err)
}
return nil
}
func (s *Server) addStateFile(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error {
path := statemanager.GetDefaultStatePath()
if path == "" {
return nil
}
data, err := os.ReadFile(path)
if err != nil {
if errors.Is(err, fs.ErrNotExist) {
return nil
}
return fmt.Errorf("read state file: %w", err)
}
if req.GetAnonymize() {
var rawStates map[string]json.RawMessage
if err := json.Unmarshal(data, &rawStates); err != nil {
return fmt.Errorf("unmarshal states: %w", err)
}
if err := anonymizeStateFile(&rawStates, anonymizer); err != nil {
return fmt.Errorf("anonymize state file: %w", err)
}
bs, err := json.MarshalIndent(rawStates, "", " ")
if err != nil {
return fmt.Errorf("marshal states: %w", err)
}
data = bs
}
if err := addFileToZip(archive, bytes.NewReader(data), "state.json"); err != nil {
return fmt.Errorf("add state file to zip: %w", err)
}
return nil
}
func (s *Server) addLogfile(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error {
logDir := filepath.Dir(s.logFile)
if err := s.addSingleLogfile(s.logFile, clientLogFile, req, anonymizer, archive); err != nil {
return fmt.Errorf("add client log file to zip: %w", err)
}
errLogPath := filepath.Join(logDir, errorLogFile)
if err := s.addSingleLogfile(errLogPath, errorLogFile, req, anonymizer, archive); err != nil {
log.Warnf("Failed to add %s to zip: %v", errorLogFile, err)
}
stdoutLogPath := filepath.Join(logDir, stdoutLogFile)
if err := s.addSingleLogfile(stdoutLogPath, stdoutLogFile, req, anonymizer, archive); err != nil {
log.Warnf("Failed to add %s to zip: %v", stdoutLogFile, err)
}
return nil
}
// addSingleLogfile adds a single log file to the archive
func (s *Server) addSingleLogfile(logPath, targetName string, req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error {
logFile, err := os.Open(logPath)
if err != nil {
return fmt.Errorf("open log file %s: %w", targetName, err)
return fmt.Errorf("open log file: %w", err)
}
defer func() {
if err := logFile.Close(); err != nil {
log.Errorf("Failed to close log file %s: %v", targetName, err)
log.Errorf("Failed to close original log file: %v", err)
}
}()
@@ -407,55 +264,45 @@ func (s *Server) addSingleLogfile(logPath, targetName string, req *proto.DebugBu
var writer *io.PipeWriter
logReader, writer = io.Pipe()
go anonymizeLog(logFile, writer, anonymizer)
go s.anonymize(logFile, writer, anonymizer)
} else {
logReader = logFile
}
if err := addFileToZip(archive, logReader, targetName); err != nil {
return fmt.Errorf("add %s to zip: %w", targetName, err)
if err := addFileToZip(archive, logReader, "client.log"); err != nil {
return fmt.Errorf("add log file to zip: %w", err)
}
return nil
}
// getLatestNetworkMap returns the latest network map from the engine if network map persistence is enabled
func (s *Server) getLatestNetworkMap() (*mgmProto.NetworkMap, error) {
if s.connectClient == nil {
return nil, errors.New("connect client is not initialized")
}
func (s *Server) anonymize(reader io.Reader, writer *io.PipeWriter, anonymizer *anonymize.Anonymizer) {
defer func() {
// always nil
_ = writer.Close()
}()
engine := s.connectClient.Engine()
if engine == nil {
return nil, errors.New("engine is not initialized")
scanner := bufio.NewScanner(reader)
for scanner.Scan() {
line := anonymizer.AnonymizeString(scanner.Text())
if _, err := writer.Write([]byte(line + "\n")); err != nil {
writer.CloseWithError(fmt.Errorf("anonymize write: %w", err))
return
}
}
networkMap, err := engine.GetLatestNetworkMap()
if err != nil {
return nil, fmt.Errorf("get latest network map: %w", err)
if err := scanner.Err(); err != nil {
writer.CloseWithError(fmt.Errorf("anonymize scan: %w", err))
return
}
if networkMap == nil {
return nil, errors.New("network map is not available")
}
return networkMap, nil
}
// GetLogLevel gets the current logging level for the server.
func (s *Server) GetLogLevel(_ context.Context, _ *proto.GetLogLevelRequest) (*proto.GetLogLevelResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
level := ParseLogLevel(log.GetLevel().String())
return &proto.GetLogLevelResponse{Level: level}, nil
}
// SetLogLevel sets the logging level for the server.
func (s *Server) SetLogLevel(_ context.Context, req *proto.SetLogLevelRequest) (*proto.SetLogLevelResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
level, err := log.ParseLevel(req.Level.String())
if err != nil {
return nil, fmt.Errorf("invalid log level: %w", err)
@@ -466,20 +313,6 @@ func (s *Server) SetLogLevel(_ context.Context, req *proto.SetLogLevelRequest) (
return &proto.SetLogLevelResponse{}, nil
}
// SetNetworkMapPersistence sets the network map persistence for the server.
func (s *Server) SetNetworkMapPersistence(_ context.Context, req *proto.SetNetworkMapPersistenceRequest) (*proto.SetNetworkMapPersistenceResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
enabled := req.GetEnabled()
s.persistNetworkMap = enabled
if s.connectClient != nil {
s.connectClient.SetNetworkMapPersistence(enabled)
}
return &proto.SetNetworkMapPersistenceResponse{}, nil
}
func addFileToZip(archive *zip.Writer, reader io.Reader, filename string) error {
header := &zip.FileHeader{
Name: filename,
@@ -625,26 +458,6 @@ func formatInterfaces(interfaces []net.Interface, anonymize bool, anonymizer *an
return builder.String()
}
func anonymizeLog(reader io.Reader, writer *io.PipeWriter, anonymizer *anonymize.Anonymizer) {
defer func() {
// always nil
_ = writer.Close()
}()
scanner := bufio.NewScanner(reader)
for scanner.Scan() {
line := anonymizer.AnonymizeString(scanner.Text())
if _, err := writer.Write([]byte(line + "\n")); err != nil {
writer.CloseWithError(fmt.Errorf("anonymize write: %w", err))
return
}
}
if err := scanner.Err(); err != nil {
writer.CloseWithError(fmt.Errorf("anonymize scan: %w", err))
return
}
}
func anonymizeNATExternalIPs(ips []string, anonymizer *anonymize.Anonymizer) []string {
anonymizedIPs := make([]string, len(ips))
for i, ip := range ips {
@@ -671,248 +484,3 @@ func anonymizeNATExternalIPs(ips []string, anonymizer *anonymize.Anonymizer) []s
}
return anonymizedIPs
}
func anonymizeNetworkMap(networkMap *mgmProto.NetworkMap, anonymizer *anonymize.Anonymizer) error {
if networkMap.PeerConfig != nil {
anonymizePeerConfig(networkMap.PeerConfig, anonymizer)
}
for _, peer := range networkMap.RemotePeers {
anonymizeRemotePeer(peer, anonymizer)
}
for _, peer := range networkMap.OfflinePeers {
anonymizeRemotePeer(peer, anonymizer)
}
for _, r := range networkMap.Routes {
anonymizeRoute(r, anonymizer)
}
if networkMap.DNSConfig != nil {
anonymizeDNSConfig(networkMap.DNSConfig, anonymizer)
}
for _, rule := range networkMap.FirewallRules {
anonymizeFirewallRule(rule, anonymizer)
}
for _, rule := range networkMap.RoutesFirewallRules {
anonymizeRouteFirewallRule(rule, anonymizer)
}
return nil
}
func anonymizePeerConfig(config *mgmProto.PeerConfig, anonymizer *anonymize.Anonymizer) {
if config == nil {
return
}
if addr, err := netip.ParseAddr(config.Address); err == nil {
config.Address = anonymizer.AnonymizeIP(addr).String()
}
if config.SshConfig != nil && len(config.SshConfig.SshPubKey) > 0 {
config.SshConfig.SshPubKey = []byte("ssh-placeholder-key")
}
config.Dns = anonymizer.AnonymizeString(config.Dns)
config.Fqdn = anonymizer.AnonymizeDomain(config.Fqdn)
}
func anonymizeRemotePeer(peer *mgmProto.RemotePeerConfig, anonymizer *anonymize.Anonymizer) {
if peer == nil {
return
}
for i, ip := range peer.AllowedIps {
// Try to parse as prefix first (CIDR)
if prefix, err := netip.ParsePrefix(ip); err == nil {
anonIP := anonymizer.AnonymizeIP(prefix.Addr())
peer.AllowedIps[i] = fmt.Sprintf("%s/%d", anonIP, prefix.Bits())
} else if addr, err := netip.ParseAddr(ip); err == nil {
peer.AllowedIps[i] = anonymizer.AnonymizeIP(addr).String()
}
}
peer.Fqdn = anonymizer.AnonymizeDomain(peer.Fqdn)
if peer.SshConfig != nil && len(peer.SshConfig.SshPubKey) > 0 {
peer.SshConfig.SshPubKey = []byte("ssh-placeholder-key")
}
}
func anonymizeRoute(route *mgmProto.Route, anonymizer *anonymize.Anonymizer) {
if route == nil {
return
}
if prefix, err := netip.ParsePrefix(route.Network); err == nil {
anonIP := anonymizer.AnonymizeIP(prefix.Addr())
route.Network = fmt.Sprintf("%s/%d", anonIP, prefix.Bits())
}
for i, domain := range route.Domains {
route.Domains[i] = anonymizer.AnonymizeDomain(domain)
}
route.NetID = anonymizer.AnonymizeString(route.NetID)
}
func anonymizeDNSConfig(config *mgmProto.DNSConfig, anonymizer *anonymize.Anonymizer) {
if config == nil {
return
}
anonymizeNameServerGroups(config.NameServerGroups, anonymizer)
anonymizeCustomZones(config.CustomZones, anonymizer)
}
func anonymizeNameServerGroups(groups []*mgmProto.NameServerGroup, anonymizer *anonymize.Anonymizer) {
for _, group := range groups {
anonymizeServers(group.NameServers, anonymizer)
anonymizeDomains(group.Domains, anonymizer)
}
}
func anonymizeServers(servers []*mgmProto.NameServer, anonymizer *anonymize.Anonymizer) {
for _, server := range servers {
if addr, err := netip.ParseAddr(server.IP); err == nil {
server.IP = anonymizer.AnonymizeIP(addr).String()
}
}
}
func anonymizeDomains(domains []string, anonymizer *anonymize.Anonymizer) {
for i, domain := range domains {
domains[i] = anonymizer.AnonymizeDomain(domain)
}
}
func anonymizeCustomZones(zones []*mgmProto.CustomZone, anonymizer *anonymize.Anonymizer) {
for _, zone := range zones {
zone.Domain = anonymizer.AnonymizeDomain(zone.Domain)
anonymizeRecords(zone.Records, anonymizer)
}
}
func anonymizeRecords(records []*mgmProto.SimpleRecord, anonymizer *anonymize.Anonymizer) {
for _, record := range records {
record.Name = anonymizer.AnonymizeDomain(record.Name)
anonymizeRData(record, anonymizer)
}
}
func anonymizeRData(record *mgmProto.SimpleRecord, anonymizer *anonymize.Anonymizer) {
switch record.Type {
case 1, 28: // A or AAAA record
if addr, err := netip.ParseAddr(record.RData); err == nil {
record.RData = anonymizer.AnonymizeIP(addr).String()
}
default:
record.RData = anonymizer.AnonymizeString(record.RData)
}
}
func anonymizeFirewallRule(rule *mgmProto.FirewallRule, anonymizer *anonymize.Anonymizer) {
if rule == nil {
return
}
if addr, err := netip.ParseAddr(rule.PeerIP); err == nil {
rule.PeerIP = anonymizer.AnonymizeIP(addr).String()
}
}
func anonymizeRouteFirewallRule(rule *mgmProto.RouteFirewallRule, anonymizer *anonymize.Anonymizer) {
if rule == nil {
return
}
for i, sourceRange := range rule.SourceRanges {
if prefix, err := netip.ParsePrefix(sourceRange); err == nil {
anonIP := anonymizer.AnonymizeIP(prefix.Addr())
rule.SourceRanges[i] = fmt.Sprintf("%s/%d", anonIP, prefix.Bits())
}
}
if prefix, err := netip.ParsePrefix(rule.Destination); err == nil {
anonIP := anonymizer.AnonymizeIP(prefix.Addr())
rule.Destination = fmt.Sprintf("%s/%d", anonIP, prefix.Bits())
}
}
func anonymizeStateFile(rawStates *map[string]json.RawMessage, anonymizer *anonymize.Anonymizer) error {
for name, rawState := range *rawStates {
if string(rawState) == "null" {
continue
}
var state map[string]any
if err := json.Unmarshal(rawState, &state); err != nil {
return fmt.Errorf("unmarshal state %s: %w", name, err)
}
state = anonymizeValue(state, anonymizer).(map[string]any)
bs, err := json.Marshal(state)
if err != nil {
return fmt.Errorf("marshal state %s: %w", name, err)
}
(*rawStates)[name] = bs
}
return nil
}
func anonymizeValue(value any, anonymizer *anonymize.Anonymizer) any {
switch v := value.(type) {
case string:
return anonymizeString(v, anonymizer)
case map[string]any:
return anonymizeMap(v, anonymizer)
case []any:
return anonymizeSlice(v, anonymizer)
}
return value
}
func anonymizeString(v string, anonymizer *anonymize.Anonymizer) string {
if prefix, err := netip.ParsePrefix(v); err == nil {
anonIP := anonymizer.AnonymizeIP(prefix.Addr())
return fmt.Sprintf("%s/%d", anonIP, prefix.Bits())
}
if ip, err := netip.ParseAddr(v); err == nil {
return anonymizer.AnonymizeIP(ip).String()
}
return anonymizer.AnonymizeString(v)
}
func anonymizeMap(v map[string]any, anonymizer *anonymize.Anonymizer) map[string]any {
result := make(map[string]any, len(v))
for key, val := range v {
newKey := anonymizeMapKey(key, anonymizer)
result[newKey] = anonymizeValue(val, anonymizer)
}
return result
}
func anonymizeMapKey(key string, anonymizer *anonymize.Anonymizer) string {
if prefix, err := netip.ParsePrefix(key); err == nil {
anonIP := anonymizer.AnonymizeIP(prefix.Addr())
return fmt.Sprintf("%s/%d", anonIP, prefix.Bits())
}
if ip, err := netip.ParseAddr(key); err == nil {
return anonymizer.AnonymizeIP(ip).String()
}
return key
}
func anonymizeSlice(v []any, anonymizer *anonymize.Anonymizer) []any {
for i, val := range v {
v[i] = anonymizeValue(val, anonymizer)
}
return v
}

View File

@@ -1,430 +0,0 @@
package server
import (
"encoding/json"
"net"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/anonymize"
mgmProto "github.com/netbirdio/netbird/management/proto"
)
func TestAnonymizeStateFile(t *testing.T) {
testState := map[string]json.RawMessage{
"null_state": json.RawMessage("null"),
"test_state": mustMarshal(map[string]any{
// Test simple fields
"public_ip": "203.0.113.1",
"private_ip": "192.168.1.1",
"protected_ip": "100.64.0.1",
"well_known_ip": "8.8.8.8",
"ipv6_addr": "2001:db8::1",
"private_ipv6": "fd00::1",
"domain": "test.example.com",
"uri": "stun:stun.example.com:3478",
"uri_with_ip": "turn:203.0.113.1:3478",
"netbird_domain": "device.netbird.cloud",
// Test CIDR ranges
"public_cidr": "203.0.113.0/24",
"private_cidr": "192.168.0.0/16",
"protected_cidr": "100.64.0.0/10",
"ipv6_cidr": "2001:db8::/32",
"private_ipv6_cidr": "fd00::/8",
// Test nested structures
"nested": map[string]any{
"ip": "203.0.113.2",
"domain": "nested.example.com",
"more_nest": map[string]any{
"ip": "203.0.113.3",
"domain": "deep.example.com",
},
},
// Test arrays
"string_array": []any{
"203.0.113.4",
"test1.example.com",
"test2.example.com",
},
"object_array": []any{
map[string]any{
"ip": "203.0.113.5",
"domain": "array1.example.com",
},
map[string]any{
"ip": "203.0.113.6",
"domain": "array2.example.com",
},
},
// Test multiple occurrences of same value
"duplicate_ip": "203.0.113.1", // Same as public_ip
"duplicate_domain": "test.example.com", // Same as domain
// Test URIs with various schemes
"stun_uri": "stun:stun.example.com:3478",
"turns_uri": "turns:turns.example.com:5349",
"http_uri": "http://web.example.com:80",
"https_uri": "https://secure.example.com:443",
// Test strings that might look like IPs but aren't
"not_ip": "300.300.300.300",
"partial_ip": "192.168",
"ip_like_string": "1234.5678",
// Test mixed content strings
"mixed_content": "Server at 203.0.113.1 (test.example.com) on port 80",
// Test empty and special values
"empty_string": "",
"null_value": nil,
"numeric_value": 42,
"boolean_value": true,
}),
"route_state": mustMarshal(map[string]any{
"routes": []any{
map[string]any{
"network": "203.0.113.0/24",
"gateway": "203.0.113.1",
"domains": []any{
"route1.example.com",
"route2.example.com",
},
},
map[string]any{
"network": "2001:db8::/32",
"gateway": "2001:db8::1",
"domains": []any{
"route3.example.com",
"route4.example.com",
},
},
},
// Test map with IP/CIDR keys
"refCountMap": map[string]any{
"203.0.113.1/32": map[string]any{
"Count": 1,
"Out": map[string]any{
"IP": "192.168.0.1",
"Intf": map[string]any{
"Name": "eth0",
"Index": 1,
},
},
},
"2001:db8::1/128": map[string]any{
"Count": 1,
"Out": map[string]any{
"IP": "fe80::1",
"Intf": map[string]any{
"Name": "eth0",
"Index": 1,
},
},
},
"10.0.0.1/32": map[string]any{ // private IP should remain unchanged
"Count": 1,
"Out": map[string]any{
"IP": "192.168.0.1",
},
},
},
}),
}
anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
// Pre-seed the domains we need to verify in the test assertions
anonymizer.AnonymizeDomain("test.example.com")
anonymizer.AnonymizeDomain("nested.example.com")
anonymizer.AnonymizeDomain("deep.example.com")
anonymizer.AnonymizeDomain("array1.example.com")
err := anonymizeStateFile(&testState, anonymizer)
require.NoError(t, err)
// Helper function to unmarshal and get nested values
var state map[string]any
err = json.Unmarshal(testState["test_state"], &state)
require.NoError(t, err)
// Test null state remains unchanged
require.Equal(t, "null", string(testState["null_state"]))
// Basic assertions
assert.NotEqual(t, "203.0.113.1", state["public_ip"])
assert.Equal(t, "192.168.1.1", state["private_ip"]) // Private IP unchanged
assert.Equal(t, "100.64.0.1", state["protected_ip"]) // Protected IP unchanged
assert.Equal(t, "8.8.8.8", state["well_known_ip"]) // Well-known IP unchanged
assert.NotEqual(t, "2001:db8::1", state["ipv6_addr"])
assert.Equal(t, "fd00::1", state["private_ipv6"]) // Private IPv6 unchanged
assert.NotEqual(t, "test.example.com", state["domain"])
assert.True(t, strings.HasSuffix(state["domain"].(string), ".domain"))
assert.Equal(t, "device.netbird.cloud", state["netbird_domain"]) // Netbird domain unchanged
// CIDR ranges
assert.NotEqual(t, "203.0.113.0/24", state["public_cidr"])
assert.Contains(t, state["public_cidr"], "/24") // Prefix preserved
assert.Equal(t, "192.168.0.0/16", state["private_cidr"]) // Private CIDR unchanged
assert.Equal(t, "100.64.0.0/10", state["protected_cidr"]) // Protected CIDR unchanged
assert.NotEqual(t, "2001:db8::/32", state["ipv6_cidr"])
assert.Contains(t, state["ipv6_cidr"], "/32") // IPv6 prefix preserved
// Nested structures
nested := state["nested"].(map[string]any)
assert.NotEqual(t, "203.0.113.2", nested["ip"])
assert.NotEqual(t, "nested.example.com", nested["domain"])
moreNest := nested["more_nest"].(map[string]any)
assert.NotEqual(t, "203.0.113.3", moreNest["ip"])
assert.NotEqual(t, "deep.example.com", moreNest["domain"])
// Arrays
strArray := state["string_array"].([]any)
assert.NotEqual(t, "203.0.113.4", strArray[0])
assert.NotEqual(t, "test1.example.com", strArray[1])
assert.True(t, strings.HasSuffix(strArray[1].(string), ".domain"))
objArray := state["object_array"].([]any)
firstObj := objArray[0].(map[string]any)
assert.NotEqual(t, "203.0.113.5", firstObj["ip"])
assert.NotEqual(t, "array1.example.com", firstObj["domain"])
// Duplicate values should be anonymized consistently
assert.Equal(t, state["public_ip"], state["duplicate_ip"])
assert.Equal(t, state["domain"], state["duplicate_domain"])
// URIs
assert.NotContains(t, state["stun_uri"], "stun.example.com")
assert.NotContains(t, state["turns_uri"], "turns.example.com")
assert.NotContains(t, state["http_uri"], "web.example.com")
assert.NotContains(t, state["https_uri"], "secure.example.com")
// Non-IP strings should remain unchanged
assert.Equal(t, "300.300.300.300", state["not_ip"])
assert.Equal(t, "192.168", state["partial_ip"])
assert.Equal(t, "1234.5678", state["ip_like_string"])
// Mixed content should have IPs and domains replaced
mixedContent := state["mixed_content"].(string)
assert.NotContains(t, mixedContent, "203.0.113.1")
assert.NotContains(t, mixedContent, "test.example.com")
assert.Contains(t, mixedContent, "Server at ")
assert.Contains(t, mixedContent, " on port 80")
// Special values should remain unchanged
assert.Equal(t, "", state["empty_string"])
assert.Nil(t, state["null_value"])
assert.Equal(t, float64(42), state["numeric_value"])
assert.Equal(t, true, state["boolean_value"])
// Check route state
var routeState map[string]any
err = json.Unmarshal(testState["route_state"], &routeState)
require.NoError(t, err)
routes := routeState["routes"].([]any)
route1 := routes[0].(map[string]any)
assert.NotEqual(t, "203.0.113.0/24", route1["network"])
assert.Contains(t, route1["network"], "/24")
assert.NotEqual(t, "203.0.113.1", route1["gateway"])
domains := route1["domains"].([]any)
assert.True(t, strings.HasSuffix(domains[0].(string), ".domain"))
assert.True(t, strings.HasSuffix(domains[1].(string), ".domain"))
// Check map keys are anonymized
refCountMap := routeState["refCountMap"].(map[string]any)
hasPublicIPKey := false
hasIPv6Key := false
hasPrivateIPKey := false
for key := range refCountMap {
if strings.Contains(key, "203.0.113.1") {
hasPublicIPKey = true
}
if strings.Contains(key, "2001:db8::1") {
hasIPv6Key = true
}
if key == "10.0.0.1/32" {
hasPrivateIPKey = true
}
}
assert.False(t, hasPublicIPKey, "public IP in key should be anonymized")
assert.False(t, hasIPv6Key, "IPv6 in key should be anonymized")
assert.True(t, hasPrivateIPKey, "private IP in key should remain unchanged")
}
func mustMarshal(v any) json.RawMessage {
data, err := json.Marshal(v)
if err != nil {
panic(err)
}
return data
}
func TestAnonymizeNetworkMap(t *testing.T) {
networkMap := &mgmProto.NetworkMap{
PeerConfig: &mgmProto.PeerConfig{
Address: "203.0.113.5",
Dns: "1.2.3.4",
Fqdn: "peer1.corp.example.com",
SshConfig: &mgmProto.SSHConfig{
SshPubKey: []byte("ssh-rsa AAAAB3NzaC1..."),
},
},
RemotePeers: []*mgmProto.RemotePeerConfig{
{
AllowedIps: []string{
"203.0.113.1/32",
"2001:db8:1234::1/128",
"192.168.1.1/32",
"100.64.0.1/32",
"10.0.0.1/32",
},
Fqdn: "peer2.corp.example.com",
SshConfig: &mgmProto.SSHConfig{
SshPubKey: []byte("ssh-rsa AAAAB3NzaC2..."),
},
},
},
Routes: []*mgmProto.Route{
{
Network: "197.51.100.0/24",
Domains: []string{"prod.example.com", "staging.example.com"},
NetID: "net-123abc",
},
},
DNSConfig: &mgmProto.DNSConfig{
NameServerGroups: []*mgmProto.NameServerGroup{
{
NameServers: []*mgmProto.NameServer{
{IP: "8.8.8.8"},
{IP: "1.1.1.1"},
{IP: "203.0.113.53"},
},
Domains: []string{"example.com", "internal.example.com"},
},
},
CustomZones: []*mgmProto.CustomZone{
{
Domain: "custom.example.com",
Records: []*mgmProto.SimpleRecord{
{
Name: "www.custom.example.com",
Type: 1,
RData: "203.0.113.10",
},
{
Name: "internal.custom.example.com",
Type: 1,
RData: "192.168.1.10",
},
},
},
},
},
}
// Create anonymizer with test addresses
anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
// Anonymize the network map
err := anonymizeNetworkMap(networkMap, anonymizer)
require.NoError(t, err)
// Test PeerConfig anonymization
peerCfg := networkMap.PeerConfig
require.NotEqual(t, "203.0.113.5", peerCfg.Address)
// Verify DNS and FQDN are properly anonymized
require.NotEqual(t, "1.2.3.4", peerCfg.Dns)
require.NotEqual(t, "peer1.corp.example.com", peerCfg.Fqdn)
require.True(t, strings.HasSuffix(peerCfg.Fqdn, ".domain"))
// Verify SSH key is replaced
require.Equal(t, []byte("ssh-placeholder-key"), peerCfg.SshConfig.SshPubKey)
// Test RemotePeers anonymization
remotePeer := networkMap.RemotePeers[0]
// Verify FQDN is anonymized
require.NotEqual(t, "peer2.corp.example.com", remotePeer.Fqdn)
require.True(t, strings.HasSuffix(remotePeer.Fqdn, ".domain"))
// Check that public IPs are anonymized but private IPs are preserved
for _, allowedIP := range remotePeer.AllowedIps {
ip, _, err := net.ParseCIDR(allowedIP)
require.NoError(t, err)
if ip.IsPrivate() || isInCGNATRange(ip) {
require.Contains(t, []string{
"192.168.1.1/32",
"100.64.0.1/32",
"10.0.0.1/32",
}, allowedIP)
} else {
require.NotContains(t, []string{
"203.0.113.1/32",
"2001:db8:1234::1/128",
}, allowedIP)
}
}
// Test Routes anonymization
route := networkMap.Routes[0]
require.NotEqual(t, "197.51.100.0/24", route.Network)
for _, domain := range route.Domains {
require.True(t, strings.HasSuffix(domain, ".domain"))
require.NotContains(t, domain, "example.com")
}
// Test DNS config anonymization
dnsConfig := networkMap.DNSConfig
nameServerGroup := dnsConfig.NameServerGroups[0]
// Verify well-known DNS servers are preserved
require.Equal(t, "8.8.8.8", nameServerGroup.NameServers[0].IP)
require.Equal(t, "1.1.1.1", nameServerGroup.NameServers[1].IP)
// Verify public DNS server is anonymized
require.NotEqual(t, "203.0.113.53", nameServerGroup.NameServers[2].IP)
// Verify domains are anonymized
for _, domain := range nameServerGroup.Domains {
require.True(t, strings.HasSuffix(domain, ".domain"))
require.NotContains(t, domain, "example.com")
}
// Test CustomZones anonymization
customZone := dnsConfig.CustomZones[0]
require.True(t, strings.HasSuffix(customZone.Domain, ".domain"))
require.NotContains(t, customZone.Domain, "example.com")
// Verify records are properly anonymized
for _, record := range customZone.Records {
require.True(t, strings.HasSuffix(record.Name, ".domain"))
require.NotContains(t, record.Name, "example.com")
ip := net.ParseIP(record.RData)
if ip != nil {
if !ip.IsPrivate() {
require.NotEqual(t, "203.0.113.10", record.RData)
} else {
require.Equal(t, "192.168.1.10", record.RData)
}
}
}
}
// Helper function to check if IP is in CGNAT range
func isInCGNATRange(ip net.IP) bool {
cgnat := net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
return cgnat.Contains(ip)
}

View File

@@ -1,7 +0,0 @@
//go:build !windows
package server
func handlePanicLog() error {
return nil
}

View File

@@ -1,83 +0,0 @@
package server
import (
"fmt"
"os"
"path/filepath"
"syscall"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/util"
)
const (
windowsPanicLogEnvVar = "NB_WINDOWS_PANIC_LOG"
// STD_ERROR_HANDLE ((DWORD)-12) = 4294967284
stdErrorHandle = ^uintptr(11)
)
var (
kernel32 = syscall.NewLazyDLL("kernel32.dll")
// https://learn.microsoft.com/en-us/windows/console/setstdhandle
setStdHandleFn = kernel32.NewProc("SetStdHandle")
)
func handlePanicLog() error {
logPath := os.Getenv(windowsPanicLogEnvVar)
if logPath == "" {
return nil
}
// Ensure the directory exists
logDir := filepath.Dir(logPath)
if err := os.MkdirAll(logDir, 0750); err != nil {
return fmt.Errorf("create panic log directory: %w", err)
}
if err := util.EnforcePermission(logPath); err != nil {
return fmt.Errorf("enforce permission on panic log file: %w", err)
}
// Open log file with append mode
f, err := os.OpenFile(logPath, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0644)
if err != nil {
return fmt.Errorf("open panic log file: %w", err)
}
// Redirect stderr to the file
if err = redirectStderr(f); err != nil {
if closeErr := f.Close(); closeErr != nil {
log.Warnf("failed to close file after redirect error: %v", closeErr)
}
return fmt.Errorf("redirect stderr: %w", err)
}
log.Infof("successfully configured panic logging to: %s", logPath)
return nil
}
// redirectStderr redirects stderr to the provided file
func redirectStderr(f *os.File) error {
// Get the current process's stderr handle
if err := setStdHandle(f); err != nil {
return fmt.Errorf("failed to set stderr handle: %w", err)
}
// Also set os.Stderr for Go's standard library
os.Stderr = f
return nil
}
func setStdHandle(f *os.File) error {
handle := f.Fd()
r0, _, e1 := setStdHandleFn.Call(stdErrorHandle, handle)
if r0 == 0 {
if e1 != nil {
return e1
}
return syscall.EINVAL
}
return nil
}

View File

@@ -68,8 +68,6 @@ type Server struct {
relayProbe *internal.Probe
wgProbe *internal.Probe
lastProbe time.Time
persistNetworkMap bool
}
type oauthAuthFlow struct {
@@ -99,10 +97,6 @@ func (s *Server) Start() error {
defer s.mutex.Unlock()
state := internal.CtxGetState(s.rootCtx)
if err := handlePanicLog(); err != nil {
log.Warnf("failed to redirect stderr: %v", err)
}
if err := restoreResidualState(s.rootCtx); err != nil {
log.Warnf(errRestoreResidualState, err)
}
@@ -198,7 +192,6 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, config *internal.Conf
runOperation := func() error {
log.Tracef("running client connection")
s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder)
s.connectClient.SetNetworkMapPersistence(s.persistNetworkMap)
probes := internal.ProbeHolder{
MgmProbe: s.mgmProbe,
@@ -629,8 +622,6 @@ func (s *Server) Down(ctx context.Context, _ *proto.DownRequest) (*proto.DownRes
s.mutex.Lock()
defer s.mutex.Unlock()
s.oauthAuthFlow = oauthAuthFlow{}
if s.actCancel == nil {
return nil, fmt.Errorf("service is not up")
}

View File

@@ -5,112 +5,12 @@ import (
"fmt"
"github.com/hashicorp/go-multierror"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/proto"
)
// ListStates returns a list of all saved states
func (s *Server) ListStates(_ context.Context, _ *proto.ListStatesRequest) (*proto.ListStatesResponse, error) {
mgr := statemanager.New(statemanager.GetDefaultStatePath())
stateNames, err := mgr.GetSavedStateNames()
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get saved state names: %v", err)
}
states := make([]*proto.State, 0, len(stateNames))
for _, name := range stateNames {
states = append(states, &proto.State{
Name: name,
})
}
return &proto.ListStatesResponse{
States: states,
}, nil
}
// CleanState handles cleaning of states (performing cleanup operations)
func (s *Server) CleanState(ctx context.Context, req *proto.CleanStateRequest) (*proto.CleanStateResponse, error) {
if s.connectClient.Status() == internal.StatusConnected || s.connectClient.Status() == internal.StatusConnecting {
return nil, status.Errorf(codes.FailedPrecondition, "cannot clean state while connecting or connected, run 'netbird down' first.")
}
if req.All {
// Reuse existing cleanup logic for all states
if err := restoreResidualState(ctx); err != nil {
return nil, status.Errorf(codes.Internal, "failed to clean all states: %v", err)
}
// Get count of cleaned states
mgr := statemanager.New(statemanager.GetDefaultStatePath())
stateNames, err := mgr.GetSavedStateNames()
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get state count: %v", err)
}
return &proto.CleanStateResponse{
CleanedStates: int32(len(stateNames)),
}, nil
}
// Handle single state cleanup
mgr := statemanager.New(statemanager.GetDefaultStatePath())
registerStates(mgr)
if err := mgr.CleanupStateByName(req.StateName); err != nil {
return nil, status.Errorf(codes.Internal, "failed to clean state %s: %v", req.StateName, err)
}
if err := mgr.PersistState(ctx); err != nil {
return nil, status.Errorf(codes.Internal, "failed to persist state changes: %v", err)
}
return &proto.CleanStateResponse{
CleanedStates: 1,
}, nil
}
// DeleteState handles deletion of states without cleanup
func (s *Server) DeleteState(ctx context.Context, req *proto.DeleteStateRequest) (*proto.DeleteStateResponse, error) {
if s.connectClient.Status() == internal.StatusConnected || s.connectClient.Status() == internal.StatusConnecting {
return nil, status.Errorf(codes.FailedPrecondition, "cannot clean state while connecting or connected, run 'netbird down' first.")
}
mgr := statemanager.New(statemanager.GetDefaultStatePath())
var count int
var err error
if req.All {
count, err = mgr.DeleteAllStates()
} else {
err = mgr.DeleteStateByName(req.StateName)
if err == nil {
count = 1
}
}
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to delete state: %v", err)
}
// Persist the changes
if err := mgr.PersistState(ctx); err != nil {
return nil, status.Errorf(codes.Internal, "failed to persist state changes: %v", err)
}
return &proto.DeleteStateResponse{
DeletedStates: int32(count),
}, nil
}
// restoreResidualState checks if the client was not shut down in a clean way and restores residual if required.
// restoreResidualConfig checks if the client was not shut down in a clean way and restores residual state if required.
// Otherwise, we might not be able to connect to the management server to retrieve new config.
func restoreResidualState(ctx context.Context) error {
path := statemanager.GetDefaultStatePath()
@@ -124,7 +24,6 @@ func restoreResidualState(ctx context.Context) error {
registerStates(mgr)
var merr *multierror.Error
if err := mgr.PerformCleanup(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("perform cleanup: %w", err))
}

View File

@@ -61,14 +61,6 @@ type Info struct {
Files []File // for posture checks
}
// StaticInfo is an object that contains machine information that does not change
type StaticInfo struct {
SystemSerialNumber string
SystemProductName string
SystemManufacturer string
Environment Environment
}
// extractUserAgent extracts Netbird's agent (client) name and version from the outgoing context
func extractUserAgent(ctx context.Context) string {
md, hasMeta := metadata.FromOutgoingContext(ctx)

View File

@@ -10,12 +10,13 @@ import (
"os/exec"
"runtime"
"strings"
"time"
"golang.org/x/sys/unix"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/system/detect_cloud"
"github.com/netbirdio/netbird/client/system/detect_platform"
"github.com/netbirdio/netbird/version"
)
@@ -40,10 +41,11 @@ func GetInfo(ctx context.Context) *Info {
log.Warnf("failed to discover network addresses: %s", err)
}
start := time.Now()
si := updateStaticInfo()
if time.Since(start) > 1*time.Second {
log.Warnf("updateStaticInfo took %s", time.Since(start))
serialNum, prodName, manufacturer := sysInfo()
env := Environment{
Cloud: detect_cloud.Detect(ctx),
Platform: detect_platform.Detect(ctx),
}
gio := &Info{
@@ -55,10 +57,10 @@ func GetInfo(ctx context.Context) *Info {
CPUs: runtime.NumCPU(),
KernelVersion: release,
NetworkAddresses: addrs,
SystemSerialNumber: si.SystemSerialNumber,
SystemProductName: si.SystemProductName,
SystemManufacturer: si.SystemManufacturer,
Environment: si.Environment,
SystemSerialNumber: serialNum,
SystemProductName: prodName,
SystemManufacturer: manufacturer,
Environment: env,
}
systemHostname, _ := os.Hostname()

View File

@@ -1,4 +1,5 @@
//go:build !android
// +build !android
package system
@@ -15,13 +16,30 @@ import (
log "github.com/sirupsen/logrus"
"github.com/zcalusic/sysinfo"
"github.com/netbirdio/netbird/client/system/detect_cloud"
"github.com/netbirdio/netbird/client/system/detect_platform"
"github.com/netbirdio/netbird/version"
)
var (
// it is override in tests
getSystemInfo = defaultSysInfoImplementation
)
type SysInfoGetter interface {
GetSysInfo() SysInfo
}
type SysInfoWrapper struct {
si sysinfo.SysInfo
}
func (s SysInfoWrapper) GetSysInfo() SysInfo {
s.si.GetSysInfo()
return SysInfo{
ChassisSerial: s.si.Chassis.Serial,
ProductSerial: s.si.Product.Serial,
BoardSerial: s.si.Board.Serial,
ProductName: s.si.Product.Name,
BoardName: s.si.Board.Name,
ProductVendor: s.si.Product.Vendor,
}
}
// GetInfo retrieves and parses the system information
func GetInfo(ctx context.Context) *Info {
@@ -47,10 +65,12 @@ func GetInfo(ctx context.Context) *Info {
log.Warnf("failed to discover network addresses: %s", err)
}
start := time.Now()
si := updateStaticInfo()
if time.Since(start) > 1*time.Second {
log.Warnf("updateStaticInfo took %s", time.Since(start))
si := SysInfoWrapper{}
serialNum, prodName, manufacturer := sysInfo(si.GetSysInfo())
env := Environment{
Cloud: detect_cloud.Detect(ctx),
Platform: detect_platform.Detect(ctx),
}
gio := &Info{
@@ -65,10 +85,10 @@ func GetInfo(ctx context.Context) *Info {
UIVersion: extractUserAgent(ctx),
KernelVersion: osInfo[1],
NetworkAddresses: addrs,
SystemSerialNumber: si.SystemSerialNumber,
SystemProductName: si.SystemProductName,
SystemManufacturer: si.SystemManufacturer,
Environment: si.Environment,
SystemSerialNumber: serialNum,
SystemProductName: prodName,
SystemManufacturer: manufacturer,
Environment: env,
}
return gio
@@ -88,9 +108,9 @@ func _getInfo() string {
return out.String()
}
func sysInfo() (string, string, string) {
func sysInfo(si SysInfo) (string, string, string) {
isascii := regexp.MustCompile("^[[:ascii:]]+$")
si := getSystemInfo()
serials := []string{si.ChassisSerial, si.ProductSerial}
serial := ""
@@ -121,16 +141,3 @@ func sysInfo() (string, string, string) {
}
return serial, name, manufacturer
}
func defaultSysInfoImplementation() SysInfo {
si := sysinfo.SysInfo{}
si.GetSysInfo()
return SysInfo{
ChassisSerial: si.Chassis.Serial,
ProductSerial: si.Product.Serial,
BoardSerial: si.Board.Serial,
ProductName: si.Product.Name,
BoardName: si.Board.Name,
ProductVendor: si.Product.Vendor,
}
}

View File

@@ -6,12 +6,13 @@ import (
"os"
"runtime"
"strings"
"time"
log "github.com/sirupsen/logrus"
"github.com/yusufpapurcu/wmi"
"golang.org/x/sys/windows/registry"
"github.com/netbirdio/netbird/client/system/detect_cloud"
"github.com/netbirdio/netbird/client/system/detect_platform"
"github.com/netbirdio/netbird/version"
)
@@ -41,10 +42,24 @@ func GetInfo(ctx context.Context) *Info {
log.Warnf("failed to discover network addresses: %s", err)
}
start := time.Now()
si := updateStaticInfo()
if time.Since(start) > 1*time.Second {
log.Warnf("updateStaticInfo took %s", time.Since(start))
serialNum, err := sysNumber()
if err != nil {
log.Warnf("failed to get system serial number: %s", err)
}
prodName, err := sysProductName()
if err != nil {
log.Warnf("failed to get system product name: %s", err)
}
manufacturer, err := sysManufacturer()
if err != nil {
log.Warnf("failed to get system manufacturer: %s", err)
}
env := Environment{
Cloud: detect_cloud.Detect(ctx),
Platform: detect_platform.Detect(ctx),
}
gio := &Info{
@@ -56,10 +71,10 @@ func GetInfo(ctx context.Context) *Info {
CPUs: runtime.NumCPU(),
KernelVersion: buildVersion,
NetworkAddresses: addrs,
SystemSerialNumber: si.SystemSerialNumber,
SystemProductName: si.SystemProductName,
SystemManufacturer: si.SystemManufacturer,
Environment: si.Environment,
SystemSerialNumber: serialNum,
SystemProductName: prodName,
SystemManufacturer: manufacturer,
Environment: env,
}
systemHostname, _ := os.Hostname()
@@ -70,26 +85,6 @@ func GetInfo(ctx context.Context) *Info {
return gio
}
func sysInfo() (serialNumber string, productName string, manufacturer string) {
var err error
serialNumber, err = sysNumber()
if err != nil {
log.Warnf("failed to get system serial number: %s", err)
}
productName, err = sysProductName()
if err != nil {
log.Warnf("failed to get system product name: %s", err)
}
manufacturer, err = sysManufacturer()
if err != nil {
log.Warnf("failed to get system manufacturer: %s", err)
}
return serialNumber, productName, manufacturer
}
func getOSNameAndVersion() (string, string) {
var dst []Win32_OperatingSystem
query := wmi.CreateQuery(&dst, "")

View File

@@ -1,46 +0,0 @@
//go:build (linux && !android) || windows || (darwin && !ios)
package system
import (
"context"
"sync"
"time"
"github.com/netbirdio/netbird/client/system/detect_cloud"
"github.com/netbirdio/netbird/client/system/detect_platform"
)
var (
staticInfo StaticInfo
once sync.Once
)
func init() {
go func() {
_ = updateStaticInfo()
}()
}
func updateStaticInfo() StaticInfo {
once.Do(func() {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
wg := sync.WaitGroup{}
wg.Add(3)
go func() {
staticInfo.SystemSerialNumber, staticInfo.SystemProductName, staticInfo.SystemManufacturer = sysInfo()
wg.Done()
}()
go func() {
staticInfo.Environment.Cloud = detect_cloud.Detect(ctx)
wg.Done()
}()
go func() {
staticInfo.Environment.Platform = detect_platform.Detect(ctx)
wg.Done()
}()
wg.Wait()
})
return staticInfo
}

View File

@@ -183,10 +183,7 @@ func Test_sysInfo(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
getSystemInfo = func() SysInfo {
return tt.sysInfo
}
gotSerialNum, gotProdName, gotManufacturer := sysInfo()
gotSerialNum, gotProdName, gotManufacturer := sysInfo(tt.sysInfo)
if gotSerialNum != tt.wantSerialNum {
t.Errorf("sysInfo() gotSerialNum = %v, want %v", gotSerialNum, tt.wantSerialNum)
}

View File

@@ -1,126 +0,0 @@
{
"version": "v1.0.0",
"entity": {
"type": "organisation",
"role": "owner",
"name": "NetBird GmbH",
"email": "hello@netbird.io",
"phone": "",
"description": "NetBird GmbH is a Berlin-based software company specializing in the development of open-source network security solutions. Network security is utterly complex and expensive, accessible only to companies with multi-million dollar IT budgets. In contrast, there are millions of companies left behind. Our mission is to create an advanced network and cybersecurity platform that is both easy-to-use and affordable for teams of all sizes and budgets. By leveraging the open-source strategy and technological advancements, NetBird aims to set the industry standard for connecting and securing IT infrastructure.",
"webpageUrl": {
"url": "https://github.com/netbirdio"
}
},
"projects": [
{
"guid": "netbird",
"name": "NetBird",
"description": "NetBird is a configuration-free peer-to-peer private network and a centralized access control system combined in a single open-source platform. It makes it easy to create secure WireGuard-based private networks for your organization or home.",
"webpageUrl": {
"url": "https://github.com/netbirdio/netbird"
},
"repositoryUrl": {
"url": "https://github.com/netbirdio/netbird"
},
"licenses": [
"BSD-3"
],
"tags": [
"network-security",
"vpn",
"developer-tools",
"ztna",
"zero-trust",
"remote-access",
"wireguard",
"peer-to-peer",
"private-networking",
"software-defined-networking"
]
}
],
"funding": {
"channels": [
{
"guid": "github-sponsors",
"type": "payment-provider",
"address": "https://github.com/sponsors/netbirdio",
"description": ""
},
{
"guid": "bank-transfer",
"type": "bank",
"address": "",
"description": "Contact us at hello@netbird.io for bank transfer details."
}
],
"plans": [
{
"guid": "support-yearly",
"status": "active",
"name": "Support Open Source Development and Maintenance - Yearly",
"description": "This will help us partially cover the yearly cost of maintaining the open-source NetBird project.",
"amount": 100000,
"currency": "USD",
"frequency": "yearly",
"channels": [
"github-sponsors",
"bank-transfer"
]
},
{
"guid": "support-one-time-year",
"status": "active",
"name": "Support Open Source Development and Maintenance - One Year",
"description": "This will help us partially cover the yearly cost of maintaining the open-source NetBird project.",
"amount": 100000,
"currency": "USD",
"frequency": "one-time",
"channels": [
"github-sponsors",
"bank-transfer"
]
},
{
"guid": "support-one-time-monthly",
"status": "active",
"name": "Support Open Source Development and Maintenance - Monthly",
"description": "This will help us partially cover the monthly cost of maintaining the open-source NetBird project.",
"amount": 10000,
"currency": "USD",
"frequency": "monthly",
"channels": [
"github-sponsors",
"bank-transfer"
]
},
{
"guid": "support-monthly",
"status": "active",
"name": "Support Open Source Development and Maintenance - One Month",
"description": "This will help us partially cover the monthly cost of maintaining the open-source NetBird project.",
"amount": 10000,
"currency": "USD",
"frequency": "monthly",
"channels": [
"github-sponsors",
"bank-transfer"
]
},
{
"guid": "goodwill",
"status": "active",
"name": "Goodwill Plan",
"description": "Pay anything you wish to show your goodwill for the project.",
"amount": 0,
"currency": "USD",
"frequency": "monthly",
"channels": [
"github-sponsors",
"bank-transfer"
]
}
],
"history": null
}
}

15
go.mod
View File

@@ -25,7 +25,7 @@ require (
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
golang.zx2c4.com/wireguard/windows v0.5.3
google.golang.org/grpc v1.64.1
google.golang.org/protobuf v1.34.2
google.golang.org/protobuf v1.34.1
gopkg.in/natefinch/lumberjack.v2 v2.0.0
)
@@ -60,7 +60,7 @@ require (
github.com/miekg/dns v1.1.59
github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/nadoo/ipset v0.5.0
github.com/netbirdio/management-integrations/integrations v0.0.0-20241106153857-de8e2beb5254
github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d
github.com/okta/okta-sdk-golang/v2 v2.18.0
github.com/oschwald/maxminddb-golang v1.12.0
@@ -71,6 +71,7 @@ require (
github.com/pion/transport/v3 v3.0.1
github.com/pion/turn/v3 v3.0.1
github.com/prometheus/client_golang v1.19.1
github.com/r3labs/diff/v3 v3.0.1
github.com/rs/xid v1.3.0
github.com/shirou/gopsutil/v3 v3.24.4
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
@@ -155,7 +156,7 @@ require (
github.com/go-text/typesetting v0.1.0 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/google/btree v1.1.2 // indirect
github.com/google/btree v1.0.1 // indirect
github.com/google/s2a-go v0.1.7 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect
github.com/googleapis/gax-go/v2 v2.12.3 // indirect
@@ -210,6 +211,8 @@ require (
github.com/tklauser/go-sysconf v0.3.14 // indirect
github.com/tklauser/numcpus v0.8.0 // indirect
github.com/vishvananda/netns v0.0.4 // indirect
github.com/vmihailenco/msgpack/v5 v5.3.5 // indirect
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
github.com/yuin/goldmark v1.7.1 // indirect
github.com/zeebo/blake3 v0.2.3 // indirect
go.opencensus.io v0.24.0 // indirect
@@ -224,11 +227,11 @@ require (
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20240509183442-62759503f434 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240515191416-fc5f0ca64291 // indirect
gopkg.in/square/go-jose.v2 v2.6.0 // indirect
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect
gopkg.in/tomb.v2 v2.0.0-20161208151619-d5d1b5820637 // indirect
gvisor.dev/gvisor v0.0.0-20231020174304-b8a429915ff1 // indirect
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 // indirect
k8s.io/apimachinery v0.26.2 // indirect
)
@@ -236,7 +239,7 @@ replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-2024
replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20241125150134-f9cdce5e32e9
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed
replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6

30
go.sum
View File

@@ -297,8 +297,8 @@ github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
github.com/google/btree v1.1.2 h1:xf4v41cLI2Z6FxbKm+8Bu+m8ifhj15JuZ9sa0jZCMUU=
github.com/google/btree v1.1.2/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4=
github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4=
github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA=
github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
@@ -521,14 +521,14 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c=
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q=
github.com/netbirdio/management-integrations/integrations v0.0.0-20241106153857-de8e2beb5254 h1:L8mNd3tBxMdnQNxMNJ+/EiwHwizNOMy8/nHLVGNfjpg=
github.com/netbirdio/management-integrations/integrations v0.0.0-20241106153857-de8e2beb5254/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y=
github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd h1:phKq1S1Y/lnqEhP5Qknta733+rPX16dRDHM7hKkot9c=
github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
github.com/netbirdio/wireguard-go v0.0.0-20241125150134-f9cdce5e32e9 h1:Pu/7EukijT09ynHUOzQYW7cC3M/BKU8O4qyN/TvTGoY=
github.com/netbirdio/wireguard-go v0.0.0-20241125150134-f9cdce5e32e9/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed h1:t0UADZUJDaaZgfKrt8JUPrOLL9Mg/ryjP85RAH53qgs=
github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM=
github.com/nicksnyder/go-i18n/v2 v2.4.0/go.mod h1:nxYSZE9M0bf3Y70gPQjN9ha7XNHX7gMc814+6wVyEI4=
github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A=
@@ -605,6 +605,8 @@ github.com/prometheus/common v0.53.0 h1:U2pL9w9nmJwJDa4qqLQ3ZaePJ6ZTwt7cMD3AG3+a
github.com/prometheus/common v0.53.0/go.mod h1:BrxBKv3FWBIGXw89Mg1AeBq7FSyRzXWI3l3e7W3RN5U=
github.com/prometheus/procfs v0.15.0 h1:A82kmvXJq2jTu5YUhSGNlYoxh85zLnKgPz4bMZgI5Ek=
github.com/prometheus/procfs v0.15.0/go.mod h1:Y0RJ/Y5g5wJpkTisOtqwDSo4HwhGmLB4VQSw2sQJLHk=
github.com/r3labs/diff/v3 v3.0.1 h1:CBKqf3XmNRHXKmdU7mZP1w7TV0pDyVCis1AUHtA4Xtg=
github.com/r3labs/diff/v3 v3.0.1/go.mod h1:f1S9bourRbiM66NskseyUdo0fTmEE0qKrikYJX63dgo=
github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ=
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M=
@@ -697,6 +699,10 @@ github.com/vishvananda/netlink v1.2.1-beta.2/go.mod h1:twkDnbuQxJYemMlGd4JFIcuhg
github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0=
github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
github.com/vmihailenco/msgpack/v5 v5.3.5 h1:5gO0H1iULLWGhs2H5tbAHIZTV8/cYafcFOr9znI5mJU=
github.com/vmihailenco/msgpack/v5 v5.3.5/go.mod h1:7xyJ9e+0+9SaZT0Wt1RGleJXzli6Q/V5KbhBonMG9jc=
github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
@@ -1151,8 +1157,8 @@ google.golang.org/genproto v0.0.0-20210402141018-6c239bbf2bb1/go.mod h1:9lPAdzaE
google.golang.org/genproto v0.0.0-20210602131652-f16073e35f0c/go.mod h1:UODoCrxHCcBojKKwX1terBiRUaqAsFqJiF615XL43r0=
google.golang.org/genproto/googleapis/api v0.0.0-20240509183442-62759503f434 h1:OpXbo8JnN8+jZGPrL4SSfaDjSCjupr8lXyBAbexEm/U=
google.golang.org/genproto/googleapis/api v0.0.0-20240509183442-62759503f434/go.mod h1:FfiGhwUm6CJviekPrc0oJ+7h29e+DmWU6UtjX0ZvI7Y=
google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 h1:pPJltXNxVzT4pK9yD8vR9X75DaWYYmLGMsEvBfFQZzQ=
google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1/go.mod h1:UqMtugtsSgubUsoxbuAoiCXvqvErP7Gf0so0mK9tHxU=
google.golang.org/genproto/googleapis/rpc v0.0.0-20240515191416-fc5f0ca64291 h1:AgADTJarZTBqgjiUzRgfaBchgYB3/WFTC80GPwsMcRI=
google.golang.org/genproto/googleapis/rpc v0.0.0-20240515191416-fc5f0ca64291/go.mod h1:EfXuqaE1J41VCDicxHzUDm+8rk+7ZdXzHV0IhO/I6s0=
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38=
google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM=
@@ -1189,8 +1195,8 @@ google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp0
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg=
google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw=
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=
google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
@@ -1232,8 +1238,8 @@ gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde h1:9DShaph9qhkIYw7QF91I/ynrr4
gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
gotest.tools/v3 v3.5.0 h1:Ljk6PdHdOhAb5aDMWXjDLMMhph+BpztA4v1QdqEW2eY=
gotest.tools/v3 v3.5.0/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU=
gvisor.dev/gvisor v0.0.0-20231020174304-b8a429915ff1 h1:qDCwdCWECGnwQSQC01Dpnp09fRHxJs9PbktotUqG+hs=
gvisor.dev/gvisor v0.0.0-20231020174304-b8a429915ff1/go.mod h1:8hmigyCdYtw5xJGfQDJzSH5Ju8XEIDBnpyi8+O6GRt8=
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 h1:TbRPT0HtzFP3Cno1zZo7yPzEEnfu8EjLfl6IU9VfqkQ=
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259/go.mod h1:AVgIgHMwK63XvmAzWG9vLQ41YnVHN0du0tEC46fI7yY=
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=

View File

@@ -530,7 +530,7 @@ renderCaddyfile() {
{
debug
servers :80,:443 {
protocols h1 h2c h3
protocols h1 h2c
}
}
@@ -788,7 +788,6 @@ services:
networks: [ netbird ]
ports:
- '443:443'
- '443:443/udp'
- '80:80'
- '8080:8080'
volumes:

View File

@@ -110,10 +110,11 @@ type AccountManager interface {
SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error
DeleteGroup(ctx context.Context, accountId, userId, groupID string) error
DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error
ListGroups(ctx context.Context, accountId string) ([]*nbgroup.Group, error)
GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error
GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error
GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error)
SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) (*Policy, error)
SavePolicy(ctx context.Context, accountID, userID string, policy *Policy, isUpdate bool) error
DeletePolicy(ctx context.Context, accountID, policyID, userID string) error
ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error)
GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error)
@@ -139,7 +140,7 @@ type AccountManager interface {
HasConnectedChannel(peerID string) bool
GetExternalCacheManager() ExternalCacheManager
GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error)
SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error)
SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error
DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error
ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error)
GetIdpManager() idp.Manager
@@ -965,9 +966,7 @@ func (am *DefaultAccountManager) getJWTGroupsChanges(user *User, groups []*nbgro
}
// UserGroupsAddToPeers adds groups to all peers of user
func (a *Account) UserGroupsAddToPeers(userID string, groups ...string) map[string][]string {
groupUpdates := make(map[string][]string)
func (a *Account) UserGroupsAddToPeers(userID string, groups ...string) {
userPeers := make(map[string]struct{})
for pid, peer := range a.Peers {
if peer.UserID == userID {
@@ -981,8 +980,6 @@ func (a *Account) UserGroupsAddToPeers(userID string, groups ...string) map[stri
continue
}
oldPeers := group.Peers
groupPeers := make(map[string]struct{})
for _, pid := range group.Peers {
groupPeers[pid] = struct{}{}
@@ -996,25 +993,16 @@ func (a *Account) UserGroupsAddToPeers(userID string, groups ...string) map[stri
for pid := range groupPeers {
group.Peers = append(group.Peers, pid)
}
groupUpdates[gid] = difference(group.Peers, oldPeers)
}
return groupUpdates
}
// UserGroupsRemoveFromPeers removes groups from all peers of user
func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) map[string][]string {
groupUpdates := make(map[string][]string)
func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) {
for _, gid := range groups {
group, ok := a.Groups[gid]
if !ok || group.Name == "All" {
continue
}
oldPeers := group.Peers
update := make([]string, 0, len(group.Peers))
for _, pid := range group.Peers {
peer, ok := a.Peers[pid]
@@ -1026,10 +1014,7 @@ func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) map
}
}
group.Peers = update
groupUpdates[gid] = difference(oldPeers, group.Peers)
}
return groupUpdates
}
// BuildManager creates a new DefaultAccountManager with a provided Store
@@ -1191,11 +1176,6 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
return nil, err
}
err = am.handleGroupsPropagationSettings(ctx, oldSettings, newSettings, userID, accountID)
if err != nil {
return nil, fmt.Errorf("groups propagation failed: %w", err)
}
updatedAccount := account.UpdateSettings(newSettings)
err = am.Store.SaveAccount(ctx, account)
@@ -1206,39 +1186,21 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
return updatedAccount, nil
}
func (am *DefaultAccountManager) handleGroupsPropagationSettings(ctx context.Context, oldSettings, newSettings *Settings, userID, accountID string) error {
if oldSettings.GroupsPropagationEnabled != newSettings.GroupsPropagationEnabled {
if newSettings.GroupsPropagationEnabled {
am.StoreEvent(ctx, userID, accountID, accountID, activity.UserGroupPropagationEnabled, nil)
// Todo: retroactively add user groups to all peers
} else {
am.StoreEvent(ctx, userID, accountID, accountID, activity.UserGroupPropagationDisabled, nil)
}
}
return nil
}
func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, account *Account, oldSettings, newSettings *Settings, userID, accountID string) error {
if newSettings.PeerInactivityExpirationEnabled {
if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration {
oldSettings.PeerInactivityExpiration = newSettings.PeerInactivityExpiration
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerInactivityExpirationDurationUpdated, nil)
if oldSettings.PeerInactivityExpirationEnabled != newSettings.PeerInactivityExpirationEnabled {
event := activity.AccountPeerInactivityExpirationEnabled
if !newSettings.PeerInactivityExpirationEnabled {
event = activity.AccountPeerInactivityExpirationDisabled
am.peerInactivityExpiry.Cancel(ctx, []string{accountID})
} else {
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
}
} else {
if oldSettings.PeerInactivityExpirationEnabled != newSettings.PeerInactivityExpirationEnabled {
event := activity.AccountPeerInactivityExpirationEnabled
if !newSettings.PeerInactivityExpirationEnabled {
event = activity.AccountPeerInactivityExpirationDisabled
am.peerInactivityExpiry.Cancel(ctx, []string{accountID})
} else {
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
}
am.StoreEvent(ctx, userID, accountID, accountID, event, nil)
}
am.StoreEvent(ctx, userID, accountID, accountID, event, nil)
}
if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration {
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerInactivityExpirationDurationUpdated, nil)
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
}
return nil
@@ -1287,7 +1249,7 @@ func (am *DefaultAccountManager) peerInactivityExpirationJob(ctx context.Context
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
log.Errorf("failed getting account %s expiring peers", accountID)
log.Errorf("failed getting account %s expiring peers", account.Id)
return account.GetNextInactivePeerExpiration()
}
@@ -1473,7 +1435,7 @@ func isNil(i idp.Manager) bool {
// addAccountIDToIDPAppMeta update user's app metadata in idp manager
func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, accountID string) error {
if !isNil(am.idpManager) {
accountUsers, err := am.Store.GetAccountUsers(ctx, LockingStrengthShare, accountID)
accountUsers, err := am.Store.GetAccountUsers(ctx, accountID)
if err != nil {
return err
}
@@ -2067,7 +2029,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
return fmt.Errorf("error getting user: %w", err)
}
groups, err := transaction.GetAccountGroups(ctx, LockingStrengthShare, accountID)
groups, err := transaction.GetAccountGroups(ctx, accountID)
if err != nil {
return fmt.Errorf("error getting account groups: %w", err)
}
@@ -2097,7 +2059,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
// Propagate changes to peers if group propagation is enabled
if settings.GroupsPropagationEnabled {
groups, err = transaction.GetAccountGroups(ctx, LockingStrengthShare, accountID)
groups, err = transaction.GetAccountGroups(ctx, accountID)
if err != nil {
return fmt.Errorf("error getting account groups: %w", err)
}
@@ -2121,7 +2083,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
return fmt.Errorf("error saving groups: %w", err)
}
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return fmt.Errorf("error incrementing network serial: %w", err)
}
}
@@ -2139,7 +2101,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
}
for _, g := range addNewGroups {
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, g)
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID)
if err != nil {
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID)
} else {
@@ -2152,7 +2114,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
}
for _, g := range removeOldGroups {
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, g)
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID)
if err != nil {
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID)
} else {
@@ -2165,19 +2127,14 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
}
if settings.GroupsPropagationEnabled {
removedGroupAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, accountID, removeOldGroups)
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return err
return fmt.Errorf("error getting account: %w", err)
}
newGroupsAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, accountID, addNewGroups)
if err != nil {
return err
}
if removedGroupAffectsPeers || newGroupsAffectsPeers {
if areGroupChangesAffectPeers(account, addNewGroups) || areGroupChangesAffectPeers(account, removeOldGroups) {
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId)
am.updateAccountPeers(ctx, accountID)
am.updateAccountPeers(ctx, account)
}
}
@@ -2333,12 +2290,12 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return nil, nil, nil, status.NewGetAccountError(err)
return nil, nil, nil, err
}
peer, netMap, postureChecks, err := am.SyncPeer(ctx, PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, account)
if err != nil {
return nil, nil, nil, fmt.Errorf("error syncing peer: %w", err)
return nil, nil, nil, err
}
err = am.MarkPeerConnected(ctx, peerPubKey, true, realIP, account)
@@ -2357,12 +2314,12 @@ func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, account
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return status.NewGetAccountError(err)
return err
}
err = am.MarkPeerConnected(ctx, peerPubKey, false, nil, account)
if err != nil {
log.WithContext(ctx).Warnf("failed marking peer as disconnected %s %v", peerPubKey, err)
log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err)
}
return nil
@@ -2378,9 +2335,6 @@ func (am *DefaultAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey st
unlock := am.Store.AcquireReadLockByUID(ctx, accountID)
defer unlock()
unlockPeer := am.Store.AcquireWriteLockByUID(ctx, peerPubKey)
defer unlockPeer()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
@@ -2444,7 +2398,12 @@ func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(ctx context.Context,
func (am *DefaultAccountManager) onPeersInvalidated(ctx context.Context, accountID string) {
log.WithContext(ctx).Debugf("validated peers has been invalidated for account %s", accountID)
am.updateAccountPeers(ctx, accountID)
updatedAccount, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get account %s: %v", accountID, err)
return
}
am.updateAccountPeers(ctx, updatedAccount)
}
func (am *DefaultAccountManager) FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) {

View File

@@ -6,17 +6,13 @@ import (
b64 "encoding/base64"
"encoding/json"
"fmt"
"io"
"net"
"os"
"reflect"
"strconv"
"sync"
"testing"
"time"
"github.com/golang-jwt/jwt"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
@@ -33,18 +29,14 @@ import (
)
type MocIntegratedValidator struct {
ValidatePeerFunc func(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error)
}
func (a MocIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error {
return nil
}
func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error) {
if a.ValidatePeerFunc != nil {
return a.ValidatePeerFunc(context.Background(), update, peer, userID, accountID, dnsDomain, peersGroup, extraSettings)
}
return update, false, nil
func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, error) {
return update, nil
}
func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*group.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) {
validatedPeers := make(map[string]struct{})
@@ -986,110 +978,6 @@ func TestAccountManager_DeleteAccount(t *testing.T) {
}
}
func BenchmarkTest_GetAccountWithclaims(b *testing.B) {
claims := jwtclaims.AuthorizationClaims{
Domain: "example.com",
UserId: "pvt-domain-user",
DomainCategory: PrivateCategory,
}
publicClaims := jwtclaims.AuthorizationClaims{
Domain: "test.com",
UserId: "public-domain-user",
DomainCategory: PublicCategory,
}
am, err := createManager(b)
if err != nil {
b.Fatal(err)
return
}
id, err := am.getAccountIDWithAuthorizationClaims(context.Background(), claims)
if err != nil {
b.Fatal(err)
}
pid, err := am.getAccountIDWithAuthorizationClaims(context.Background(), publicClaims)
if err != nil {
b.Fatal(err)
}
users := genUsers("priv", 100)
acc, err := am.Store.GetAccount(context.Background(), id)
if err != nil {
b.Fatal(err)
}
acc.Users = users
err = am.Store.SaveAccount(context.Background(), acc)
if err != nil {
b.Fatal(err)
}
userP := genUsers("pub", 100)
pacc, err := am.Store.GetAccount(context.Background(), pid)
if err != nil {
b.Fatal(err)
}
pacc.Users = userP
err = am.Store.SaveAccount(context.Background(), pacc)
if err != nil {
b.Fatal(err)
}
b.Run("public without account ID", func(b *testing.B) {
// b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := am.getAccountIDWithAuthorizationClaims(context.Background(), publicClaims)
if err != nil {
b.Fatal(err)
}
}
})
b.Run("private without account ID", func(b *testing.B) {
// b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := am.getAccountIDWithAuthorizationClaims(context.Background(), claims)
if err != nil {
b.Fatal(err)
}
}
})
b.Run("private with account ID", func(b *testing.B) {
claims.AccountId = id
// b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := am.getAccountIDWithAuthorizationClaims(context.Background(), claims)
if err != nil {
b.Fatal(err)
}
}
})
}
func genUsers(p string, n int) map[string]*User {
users := map[string]*User{}
now := time.Now()
for i := 0; i < n; i++ {
users[fmt.Sprintf("%s-%d", p, i)] = &User{
Id: fmt.Sprintf("%s-%d", p, i),
Role: UserRoleAdmin,
LastLogin: now,
CreatedAt: now,
Issued: "api",
AutoGroups: []string{"one", "two", "three", "four", "five", "six", "seven", "eight", "nine", "ten"},
}
}
return users
}
func TestAccountManager_AddPeer(t *testing.T) {
manager, err := createManager(t)
if err != nil {
@@ -1242,7 +1130,8 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
return
}
_, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
policy := Policy{
ID: "policy",
Enabled: true,
Rules: []*PolicyRule{
{
@@ -1253,7 +1142,8 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
Action: PolicyTrafficActionAccept,
},
},
})
}
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
require.NoError(t, err)
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
@@ -1322,6 +1212,19 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
policy := Policy{
Enabled: true,
Rules: []*PolicyRule{
{
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"groupA"},
Bidirectional: true,
Action: PolicyTrafficActionAccept,
},
},
}
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
@@ -1334,19 +1237,7 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
}
}()
_, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
Enabled: true,
Rules: []*PolicyRule{
{
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"groupA"},
Bidirectional: true,
Action: PolicyTrafficActionAccept,
},
},
})
if err != nil {
if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil {
t.Errorf("delete default rule: %v", err)
return
}
@@ -1367,7 +1258,7 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
return
}
_, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
policy := Policy{
Enabled: true,
Rules: []*PolicyRule{
{
@@ -1378,8 +1269,9 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
Action: PolicyTrafficActionAccept,
},
},
})
if err != nil {
}
if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil {
t.Errorf("save policy: %v", err)
return
}
@@ -1413,20 +1305,13 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
err := manager.SaveGroup(context.Background(), account.Id, userID, &group.Group{
group := group.Group{
ID: "groupA",
Name: "GroupA",
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
})
require.NoError(t, err, "failed to save group")
if err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID); err != nil {
t.Errorf("delete default rule: %v", err)
return
}
policy, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
policy := Policy{
Enabled: true,
Rules: []*PolicyRule{
{
@@ -1437,8 +1322,14 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
Action: PolicyTrafficActionAccept,
},
},
})
if err != nil {
}
if err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID); err != nil {
t.Errorf("delete default rule: %v", err)
return
}
if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil {
t.Errorf("save policy: %v", err)
return
}
@@ -1461,7 +1352,7 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
return
}
if err := manager.DeleteGroup(context.Background(), account.Id, userID, "groupA"); err != nil {
if err := manager.DeleteGroup(context.Background(), account.Id, "", group.ID); err != nil {
t.Errorf("delete group: %v", err)
return
}
@@ -2715,7 +2606,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 0)
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "accountID", "group1")
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "group1", "accountID")
assert.NoError(t, err, "unable to get group")
assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued")
})
@@ -2735,7 +2626,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 1)
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "accountID", "group1")
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "group1", "accountID")
assert.NoError(t, err, "unable to get group")
assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued")
})
@@ -2774,7 +2665,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err = manager.syncJWTGroups(context.Background(), "accountID", claims)
assert.NoError(t, err, "unable to sync jwt groups")
groups, err := manager.Store.GetAccountGroups(context.Background(), LockingStrengthShare, "accountID")
groups, err := manager.Store.GetAccountGroups(context.Background(), "accountID")
assert.NoError(t, err)
assert.Len(t, groups, 3, "new group3 should be added")
@@ -2986,218 +2877,3 @@ func peerShouldReceiveUpdate(t *testing.T, updateMessage <-chan *UpdateMessage)
t.Error("Timed out waiting for update message")
}
}
func BenchmarkSyncAndMarkPeer(b *testing.B) {
benchCases := []struct {
name string
peers int
groups int
// We need different expectations for CI/CD and local runs because of the different performance characteristics
minMsPerOpLocal float64
maxMsPerOpLocal float64
minMsPerOpCICD float64
maxMsPerOpCICD float64
}{
{"Small", 50, 5, 1, 3, 4, 10},
{"Medium", 500, 100, 7, 13, 10, 60},
{"Large", 5000, 200, 65, 80, 60, 170},
{"Small single", 50, 10, 1, 3, 4, 60},
{"Medium single", 500, 10, 7, 13, 10, 26},
{"Large 5", 5000, 15, 65, 80, 60, 170},
}
log.SetOutput(io.Discard)
defer log.SetOutput(os.Stderr)
for _, bc := range benchCases {
b.Run(bc.name, func(b *testing.B) {
manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups)
if err != nil {
b.Fatalf("Failed to setup test account manager: %v", err)
}
ctx := context.Background()
account, err := manager.Store.GetAccount(ctx, accountID)
if err != nil {
b.Fatalf("Failed to get account: %v", err)
}
peerChannels := make(map[string]chan *UpdateMessage)
for peerID := range account.Peers {
peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize)
}
manager.peersUpdateManager.peerChannels = peerChannels
b.ResetTimer()
start := time.Now()
for i := 0; i < b.N; i++ {
_, _, _, err := manager.SyncAndMarkPeer(context.Background(), account.Id, account.Peers["peer-1"].Key, nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)}, net.IP{1, 1, 1, 1})
assert.NoError(b, err)
}
duration := time.Since(start)
msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6
b.ReportMetric(msPerOp, "ms/op")
minExpected := bc.minMsPerOpLocal
maxExpected := bc.maxMsPerOpLocal
if os.Getenv("CI") == "true" {
minExpected = bc.minMsPerOpCICD
maxExpected = bc.maxMsPerOpCICD
}
if msPerOp < minExpected {
b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, minExpected)
}
if msPerOp > maxExpected {
b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, maxExpected)
}
})
}
}
func BenchmarkLoginPeer_ExistingPeer(b *testing.B) {
benchCases := []struct {
name string
peers int
groups int
// We need different expectations for CI/CD and local runs because of the different performance characteristics
minMsPerOpLocal float64
maxMsPerOpLocal float64
minMsPerOpCICD float64
maxMsPerOpCICD float64
}{
{"Small", 50, 5, 102, 110, 102, 120},
{"Medium", 500, 100, 105, 140, 105, 170},
{"Large", 5000, 200, 160, 200, 160, 270},
{"Small single", 50, 10, 102, 110, 102, 120},
{"Medium single", 500, 10, 105, 140, 105, 170},
{"Large 5", 5000, 15, 160, 200, 160, 270},
}
log.SetOutput(io.Discard)
defer log.SetOutput(os.Stderr)
for _, bc := range benchCases {
b.Run(bc.name, func(b *testing.B) {
manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups)
if err != nil {
b.Fatalf("Failed to setup test account manager: %v", err)
}
ctx := context.Background()
account, err := manager.Store.GetAccount(ctx, accountID)
if err != nil {
b.Fatalf("Failed to get account: %v", err)
}
peerChannels := make(map[string]chan *UpdateMessage)
for peerID := range account.Peers {
peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize)
}
manager.peersUpdateManager.peerChannels = peerChannels
b.ResetTimer()
start := time.Now()
for i := 0; i < b.N; i++ {
_, _, _, err := manager.LoginPeer(context.Background(), PeerLogin{
WireGuardPubKey: account.Peers["peer-1"].Key,
SSHKey: "someKey",
Meta: nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)},
UserID: "regular_user",
SetupKey: "",
ConnectionIP: net.IP{1, 1, 1, 1},
})
assert.NoError(b, err)
}
duration := time.Since(start)
msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6
b.ReportMetric(msPerOp, "ms/op")
minExpected := bc.minMsPerOpLocal
maxExpected := bc.maxMsPerOpLocal
if os.Getenv("CI") == "true" {
minExpected = bc.minMsPerOpCICD
maxExpected = bc.maxMsPerOpCICD
}
if msPerOp < minExpected {
b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, minExpected)
}
if msPerOp > maxExpected {
b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, maxExpected)
}
})
}
}
func BenchmarkLoginPeer_NewPeer(b *testing.B) {
benchCases := []struct {
name string
peers int
groups int
// We need different expectations for CI/CD and local runs because of the different performance characteristics
minMsPerOpLocal float64
maxMsPerOpLocal float64
minMsPerOpCICD float64
maxMsPerOpCICD float64
}{
{"Small", 50, 5, 107, 120, 107, 140},
{"Medium", 500, 100, 105, 140, 105, 170},
{"Large", 5000, 200, 180, 220, 180, 320},
{"Small single", 50, 10, 107, 120, 105, 140},
{"Medium single", 500, 10, 105, 140, 105, 170},
{"Large 5", 5000, 15, 180, 220, 180, 320},
}
log.SetOutput(io.Discard)
defer log.SetOutput(os.Stderr)
for _, bc := range benchCases {
b.Run(bc.name, func(b *testing.B) {
manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups)
if err != nil {
b.Fatalf("Failed to setup test account manager: %v", err)
}
ctx := context.Background()
account, err := manager.Store.GetAccount(ctx, accountID)
if err != nil {
b.Fatalf("Failed to get account: %v", err)
}
peerChannels := make(map[string]chan *UpdateMessage)
for peerID := range account.Peers {
peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize)
}
manager.peersUpdateManager.peerChannels = peerChannels
b.ResetTimer()
start := time.Now()
for i := 0; i < b.N; i++ {
_, _, _, err := manager.LoginPeer(context.Background(), PeerLogin{
WireGuardPubKey: "some-new-key" + strconv.Itoa(i),
SSHKey: "someKey",
Meta: nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)},
UserID: "regular_user",
SetupKey: "",
ConnectionIP: net.IP{1, 1, 1, 1},
})
assert.NoError(b, err)
}
duration := time.Since(start)
msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6
b.ReportMetric(msPerOp, "ms/op")
minExpected := bc.minMsPerOpLocal
maxExpected := bc.maxMsPerOpLocal
if os.Getenv("CI") == "true" {
minExpected = bc.minMsPerOpCICD
maxExpected = bc.maxMsPerOpCICD
}
if msPerOp < minExpected {
b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, minExpected)
}
if msPerOp > maxExpected {
b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, maxExpected)
}
})
}
}

View File

@@ -148,9 +148,6 @@ const (
AccountPeerInactivityExpirationDurationUpdated Activity = 67
SetupKeyDeleted Activity = 68
UserGroupPropagationEnabled Activity = 69
UserGroupPropagationDisabled Activity = 70
)
var activityMap = map[Activity]Code{
@@ -225,9 +222,6 @@ var activityMap = map[Activity]Code{
AccountPeerInactivityExpirationDisabled: {"Account peer inactivity expiration disabled", "account.peer.inactivity.expiration.disable"},
AccountPeerInactivityExpirationDurationUpdated: {"Account peer inactivity expiration duration updated", "account.peer.inactivity.expiration.update"},
SetupKeyDeleted: {"Setup key deleted", "setupkey.delete"},
UserGroupPropagationEnabled: {"User group propagation enabled", "account.setting.group.propagation.enable"},
UserGroupPropagationDisabled: {"User group propagation disabled", "account.setting.group.propagation.disable"},
}
// StringCode returns a string code of the activity

View File

@@ -0,0 +1,82 @@
package differs
import (
"fmt"
"net/netip"
"reflect"
"github.com/r3labs/diff/v3"
)
// NetIPAddr is a custom differ for netip.Addr
type NetIPAddr struct {
DiffFunc func(path []string, a, b reflect.Value, p interface{}) error
}
func (differ NetIPAddr) Match(a, b reflect.Value) bool {
return diff.AreType(a, b, reflect.TypeOf(netip.Addr{}))
}
func (differ NetIPAddr) Diff(_ diff.DiffType, _ diff.DiffFunc, cl *diff.Changelog, path []string, a, b reflect.Value, _ interface{}) error {
if a.Kind() == reflect.Invalid {
cl.Add(diff.CREATE, path, nil, b.Interface())
return nil
}
if b.Kind() == reflect.Invalid {
cl.Add(diff.DELETE, path, a.Interface(), nil)
return nil
}
fromAddr, ok1 := a.Interface().(netip.Addr)
toAddr, ok2 := b.Interface().(netip.Addr)
if !ok1 || !ok2 {
return fmt.Errorf("invalid type for netip.Addr")
}
if fromAddr.String() != toAddr.String() {
cl.Add(diff.UPDATE, path, fromAddr.String(), toAddr.String())
}
return nil
}
func (differ NetIPAddr) InsertParentDiffer(dfunc func(path []string, a, b reflect.Value, p interface{}) error) {
differ.DiffFunc = dfunc //nolint
}
// NetIPPrefix is a custom differ for netip.Prefix
type NetIPPrefix struct {
DiffFunc func(path []string, a, b reflect.Value, p interface{}) error
}
func (differ NetIPPrefix) Match(a, b reflect.Value) bool {
return diff.AreType(a, b, reflect.TypeOf(netip.Prefix{}))
}
func (differ NetIPPrefix) Diff(_ diff.DiffType, _ diff.DiffFunc, cl *diff.Changelog, path []string, a, b reflect.Value, _ interface{}) error {
if a.Kind() == reflect.Invalid {
cl.Add(diff.CREATE, path, nil, b.Interface())
return nil
}
if b.Kind() == reflect.Invalid {
cl.Add(diff.DELETE, path, a.Interface(), nil)
return nil
}
fromPrefix, ok1 := a.Interface().(netip.Prefix)
toPrefix, ok2 := b.Interface().(netip.Prefix)
if !ok1 || !ok2 {
return fmt.Errorf("invalid type for netip.Addr")
}
if fromPrefix.String() != toPrefix.String() {
cl.Add(diff.UPDATE, path, fromPrefix.String(), toPrefix.String())
}
return nil
}
func (differ NetIPPrefix) InsertParentDiffer(dfunc func(path []string, a, b reflect.Value, p interface{}) error) {
differ.DiffFunc = dfunc //nolint
}

View File

@@ -3,7 +3,6 @@ package server
import (
"context"
"fmt"
"slices"
"strconv"
"sync"
@@ -86,12 +85,8 @@ func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID s
return nil, err
}
if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
}
if user.IsRegularUser() {
return nil, status.NewAdminPermissionError()
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view DNS settings")
}
return am.Store.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID)
@@ -99,137 +94,64 @@ func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID s
// SaveDNSSettings validates a user role and updates the account's DNS settings
func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *DNSSettings) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
}
user, err := account.FindUser(userID)
if err != nil {
return err
}
if !user.HasAdminPower() {
return status.Errorf(status.PermissionDenied, "only users with admin power are allowed to update DNS settings")
}
if dnsSettingsToSave == nil {
return status.Errorf(status.InvalidArgument, "the dns settings provided are nil")
}
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return err
}
if user.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
}
if !user.HasAdminPower() {
return status.NewAdminPermissionError()
}
var updateAccountPeers bool
var eventsToStore []func()
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if err = validateDNSSettings(ctx, transaction, accountID, dnsSettingsToSave); err != nil {
return err
}
oldSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthUpdate, accountID)
if len(dnsSettingsToSave.DisabledManagementGroups) != 0 {
err = validateGroups(dnsSettingsToSave.DisabledManagementGroups, account.Groups)
if err != nil {
return err
}
}
addedGroups := difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups)
removedGroups := difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups)
oldSettings := account.DNSSettings.Copy()
account.DNSSettings = dnsSettingsToSave.Copy()
updateAccountPeers, err = areDNSSettingChangesAffectPeers(ctx, transaction, accountID, addedGroups, removedGroups)
if err != nil {
return err
}
addedGroups := difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups)
removedGroups := difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups)
events := am.prepareDNSSettingsEvents(ctx, transaction, accountID, userID, addedGroups, removedGroups)
eventsToStore = append(eventsToStore, events...)
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return err
}
return transaction.SaveDNSSettings(ctx, LockingStrengthUpdate, accountID, dnsSettingsToSave)
})
if err != nil {
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err
}
for _, storeEvent := range eventsToStore {
storeEvent()
for _, id := range addedGroups {
group := account.GetGroup(id)
meta := map[string]any{"group": group.Name, "group_id": group.ID}
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupAddedToDisabledManagementGroups, meta)
}
if updateAccountPeers {
am.updateAccountPeers(ctx, accountID)
for _, id := range removedGroups {
group := account.GetGroup(id)
meta := map[string]any{"group": group.Name, "group_id": group.ID}
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta)
}
if anyGroupHasPeers(account, addedGroups) || anyGroupHasPeers(account, removedGroups) {
am.updateAccountPeers(ctx, account)
}
return nil
}
// prepareDNSSettingsEvents prepares a list of event functions to be stored.
func (am *DefaultAccountManager) prepareDNSSettingsEvents(ctx context.Context, transaction Store, accountID, userID string, addedGroups, removedGroups []string) []func() {
var eventsToStore []func()
modifiedGroups := slices.Concat(addedGroups, removedGroups)
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, modifiedGroups)
if err != nil {
log.WithContext(ctx).Debugf("failed to get groups for dns settings events: %v", err)
return nil
}
for _, groupID := range addedGroups {
group, ok := groups[groupID]
if !ok {
log.WithContext(ctx).Debugf("skipped adding group: %s GroupAddedToDisabledManagementGroups activity", groupID)
continue
}
eventsToStore = append(eventsToStore, func() {
meta := map[string]any{"group": group.Name, "group_id": group.ID}
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupAddedToDisabledManagementGroups, meta)
})
}
for _, groupID := range removedGroups {
group, ok := groups[groupID]
if !ok {
log.WithContext(ctx).Debugf("skipped adding group: %s GroupRemovedFromDisabledManagementGroups activity", groupID)
continue
}
eventsToStore = append(eventsToStore, func() {
meta := map[string]any{"group": group.Name, "group_id": group.ID}
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta)
})
}
return eventsToStore
}
// areDNSSettingChangesAffectPeers checks if the DNS settings changes affect any peers.
func areDNSSettingChangesAffectPeers(ctx context.Context, transaction Store, accountID string, addedGroups, removedGroups []string) (bool, error) {
hasPeers, err := anyGroupHasPeers(ctx, transaction, accountID, addedGroups)
if err != nil {
return false, err
}
if hasPeers {
return true, nil
}
return anyGroupHasPeers(ctx, transaction, accountID, removedGroups)
}
// validateDNSSettings validates the DNS settings.
func validateDNSSettings(ctx context.Context, transaction Store, accountID string, settings *DNSSettings) error {
if len(settings.DisabledManagementGroups) == 0 {
return nil
}
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, settings.DisabledManagementGroups)
if err != nil {
return err
}
return validateGroups(settings.DisabledManagementGroups, groups)
}
// toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache
func toProtocolDNSConfig(update nbdns.Config, cache *DNSConfigCache) *proto.DNSConfig {
protoUpdate := &proto.DNSConfig{

View File

@@ -8,10 +8,9 @@ import (
"testing"
"time"
"github.com/stretchr/testify/assert"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -522,64 +521,23 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
}
})
// Creating DNS settings with groups that have no peers should not update account peers or send peer update
t.Run("creating dns setting with unused groups", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
close(done)
}()
_, err = manager.CreateNameServerGroup(
context.Background(), account.Id, "ns-group", "ns-group", []dns.NameServer{{
IP: netip.MustParseAddr(peer1.IP.String()),
NSType: dns.UDPNameServerType,
Port: dns.DefaultDNSPort,
}},
[]string{"groupB"},
true, []string{}, true, userID, false,
)
assert.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
err = manager.SaveGroup(context.Background(), account.Id, userID, &group.Group{
ID: "groupA",
Name: "GroupA",
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
})
assert.NoError(t, err)
// Creating DNS settings with groups that have peers should update account peers and send peer update
t.Run("creating dns setting with used groups", func(t *testing.T) {
err = manager.SaveGroup(context.Background(), account.Id, userID, &group.Group{
ID: "groupA",
Name: "GroupA",
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
})
assert.NoError(t, err)
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
close(done)
}()
_, err = manager.CreateNameServerGroup(
context.Background(), account.Id, "ns-group-1", "ns-group-1", []dns.NameServer{{
IP: netip.MustParseAddr(peer1.IP.String()),
NSType: dns.UDPNameServerType,
Port: dns.DefaultDNSPort,
}},
[]string{"groupA"},
true, []string{}, true, userID, false,
)
assert.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
_, err = manager.CreateNameServerGroup(
context.Background(), account.Id, "ns-group-1", "ns-group-1", []dns.NameServer{{
IP: netip.MustParseAddr(peer1.IP.String()),
NSType: dns.UDPNameServerType,
Port: dns.DefaultDNSPort,
}},
[]string{"groupA"},
true, []string{}, true, userID, false,
)
assert.NoError(t, err)
// Saving DNS settings with groups that have peers should update account peers and send peer update
t.Run("saving dns setting with used groups", func(t *testing.T) {
@@ -601,6 +559,27 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
}
})
// Saving unchanged DNS settings with used groups should update account peers and not send peer update
// since there is no change in the network map
t.Run("saving unchanged dns setting with used groups", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
close(done)
}()
err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{
DisabledManagementGroups: []string{"groupA", "groupB"},
})
assert.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
// Removing group with no peers from DNS settings should not trigger updates to account peers or send peer updates
t.Run("removing group with no peers from dns settings", func(t *testing.T) {
done := make(chan struct{})

View File

@@ -223,7 +223,7 @@ func restore(ctx context.Context, file string) (*FileStore, error) {
// It is recommended to call it with locking FileStore.mux
func (s *FileStore) persist(ctx context.Context, file string) error {
start := time.Now()
err := util.WriteJson(context.Background(), file, s)
err := util.WriteJson(file, s)
if err != nil {
return err
}

View File

@@ -6,11 +6,10 @@ import (
"fmt"
"slices"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/route"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/activity"
nbgroup "github.com/netbirdio/netbird/management/server/group"
@@ -28,17 +27,18 @@ func (e *GroupLinkError) Error() string {
// CheckGroupPermissions validates if a user has the necessary permissions to view groups
func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, accountID, userID string) error {
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return err
}
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return err
}
if user.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
}
if user.IsRegularUser() {
return status.NewAdminPermissionError()
if (!user.IsAdminOrServiceUser() && settings.RegularUsersViewBlocked) || user.AccountID != accountID {
return status.Errorf(status.PermissionDenied, "groups are blocked for users")
}
return nil
@@ -49,7 +49,8 @@ func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupI
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
return nil, err
}
return am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID)
return am.Store.GetGroupByID(ctx, LockingStrengthShare, groupID, accountID)
}
// GetAllGroups returns all groups in an account
@@ -57,12 +58,13 @@ func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, us
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
return nil, err
}
return am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
return am.Store.GetAccountGroups(ctx, accountID)
}
// GetGroupByName filters all groups in an account by name and returns the one with the most peers
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) {
return am.Store.GetGroupByName(ctx, LockingStrengthShare, accountID, groupName)
return am.Store.GetGroupByName(ctx, LockingStrengthShare, groupName, accountID)
}
// SaveGroup object of the peers
@@ -75,74 +77,79 @@ func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userI
// SaveGroups adds new groups to the account.
// Note: This function does not acquire the global lock.
// It is the caller's responsibility to ensure proper locking is in place before invoking this method.
func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, groups []*nbgroup.Group) error {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error {
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
}
if user.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
}
if user.IsRegularUser() {
return status.NewAdminPermissionError()
}
var eventsToStore []func()
var groupsToSave []*nbgroup.Group
var updateAccountPeers bool
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
groupIDs := make([]string, 0, len(groups))
for _, newGroup := range groups {
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
return err
for _, newGroup := range newGroups {
if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI {
return status.Errorf(status.InvalidArgument, "%s group without ID set", newGroup.Issued)
}
if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI {
existingGroup, err := account.FindGroupByName(newGroup.Name)
if err != nil {
s, ok := status.FromError(err)
if !ok || s.ErrorType != status.NotFound {
return err
}
}
newGroup.AccountID = accountID
groupsToSave = append(groupsToSave, newGroup)
groupIDs = append(groupIDs, newGroup.ID)
// Avoid duplicate groups only for the API issued groups.
// Integration or JWT groups can be duplicated as they are coming from the IdP that we don't have control of.
if existingGroup != nil {
return status.Errorf(status.AlreadyExists, "group with name %s already exists", newGroup.Name)
}
events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup)
eventsToStore = append(eventsToStore, events...)
newGroup.ID = xid.New().String()
}
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, groupIDs)
if err != nil {
return err
for _, peerID := range newGroup.Peers {
if account.Peers[peerID] == nil {
return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID)
}
}
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return err
}
oldGroup := account.Groups[newGroup.ID]
account.Groups[newGroup.ID] = newGroup
return transaction.SaveGroups(ctx, LockingStrengthUpdate, groupsToSave)
})
if err != nil {
events := am.prepareGroupEvents(ctx, userID, accountID, newGroup, oldGroup, account)
eventsToStore = append(eventsToStore, events...)
}
newGroupIDs := make([]string, 0, len(newGroups))
for _, newGroup := range newGroups {
newGroupIDs = append(newGroupIDs, newGroup.ID)
}
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err
}
if areGroupChangesAffectPeers(account, newGroupIDs) {
am.updateAccountPeers(ctx, account)
}
for _, storeEvent := range eventsToStore {
storeEvent()
}
if updateAccountPeers {
am.updateAccountPeers(ctx, accountID)
}
return nil
}
// prepareGroupEvents prepares a list of event functions to be stored.
func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transaction Store, accountID, userID string, newGroup *nbgroup.Group) []func() {
func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID string, accountID string, newGroup, oldGroup *nbgroup.Group, account *Account) []func() {
var eventsToStore []func()
addedPeers := make([]string, 0)
removedPeers := make([]string, 0)
oldGroup, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, newGroup.ID)
if err == nil && oldGroup != nil {
if oldGroup != nil {
addedPeers = difference(newGroup.Peers, oldGroup.Peers)
removedPeers = difference(oldGroup.Peers, newGroup.Peers)
} else {
@@ -152,42 +159,35 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transac
})
}
modifiedPeers := slices.Concat(addedPeers, removedPeers)
peers, err := transaction.GetPeersByIDs(ctx, LockingStrengthShare, accountID, modifiedPeers)
if err != nil {
log.WithContext(ctx).Debugf("failed to get peers for group events: %v", err)
return nil
}
for _, peerID := range addedPeers {
peer, ok := peers[peerID]
if !ok {
log.WithContext(ctx).Debugf("skipped adding peer: %s GroupAddedToPeer activity: peer not found in store", peerID)
for _, p := range addedPeers {
peer := account.Peers[p]
if peer == nil {
log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID)
continue
}
peerCopy := peer // copy to avoid closure issues
eventsToStore = append(eventsToStore, func() {
meta := map[string]any{
"group": newGroup.Name, "group_id": newGroup.ID,
"peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()),
}
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupAddedToPeer, meta)
am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupAddedToPeer,
map[string]any{
"group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peerCopy.IP.String(),
"peer_fqdn": peerCopy.FQDN(am.GetDNSDomain()),
})
})
}
for _, peerID := range removedPeers {
peer, ok := peers[peerID]
if !ok {
log.WithContext(ctx).Debugf("skipped adding peer: %s GroupRemovedFromPeer activity: peer not found in store", peerID)
for _, p := range removedPeers {
peer := account.Peers[p]
if peer == nil {
log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID)
continue
}
peerCopy := peer // copy to avoid closure issues
eventsToStore = append(eventsToStore, func() {
meta := map[string]any{
"group": newGroup.Name, "group_id": newGroup.ID,
"peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()),
}
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupRemovedFromPeer, meta)
am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupRemovedFromPeer,
map[string]any{
"group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peerCopy.IP.String(),
"peer_fqdn": peerCopy.FQDN(am.GetDNSDomain()),
})
})
}
@@ -210,10 +210,42 @@ func difference(a, b []string) []string {
}
// DeleteGroup object of the peers.
func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, userID, groupID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, userId, groupID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountId)
defer unlock()
return am.DeleteGroups(ctx, accountID, userID, []string{groupID})
account, err := am.Store.GetAccount(ctx, accountId)
if err != nil {
return err
}
group, ok := account.Groups[groupID]
if !ok {
return nil
}
allGroup, err := account.GetGroupAll()
if err != nil {
return err
}
if allGroup.ID == groupID {
return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed")
}
if err = validateDeleteGroup(account, group, userId); err != nil {
return err
}
delete(account.Groups, groupID)
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err
}
am.StoreEvent(ctx, userId, groupID, accountId, activity.GroupDeleted, group.EventMeta())
return nil
}
// DeleteGroups deletes groups from an account.
@@ -222,94 +254,93 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, use
//
// If an error occurs while deleting a group, the function skips it and continues deleting other groups.
// Errors are collected and returned at the end.
func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, userID string, groupIDs []string) error {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error {
account, err := am.Store.GetAccount(ctx, accountId)
if err != nil {
return err
}
if user.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
}
if user.IsRegularUser() {
return status.NewAdminPermissionError()
}
var allErrors error
var groupIDsToDelete []string
var deletedGroups []*nbgroup.Group
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
for _, groupID := range groupIDs {
group, err := transaction.GetGroupByID(ctx, LockingStrengthUpdate, accountID, groupID)
if err != nil {
allErrors = errors.Join(allErrors, err)
continue
}
if err := validateDeleteGroup(ctx, transaction, group, userID); err != nil {
allErrors = errors.Join(allErrors, err)
continue
}
groupIDsToDelete = append(groupIDsToDelete, groupID)
deletedGroups = append(deletedGroups, group)
deletedGroups := make([]*nbgroup.Group, 0, len(groupIDs))
for _, groupID := range groupIDs {
group, ok := account.Groups[groupID]
if !ok {
continue
}
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return err
if err := validateDeleteGroup(account, group, userId); err != nil {
allErrors = errors.Join(allErrors, fmt.Errorf("failed to delete group %s: %w", groupID, err))
continue
}
return transaction.DeleteGroups(ctx, LockingStrengthUpdate, accountID, groupIDsToDelete)
})
if err != nil {
delete(account.Groups, groupID)
deletedGroups = append(deletedGroups, group)
}
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err
}
for _, group := range deletedGroups {
am.StoreEvent(ctx, userID, group.ID, accountID, activity.GroupDeleted, group.EventMeta())
for _, g := range deletedGroups {
am.StoreEvent(ctx, userId, g.ID, accountId, activity.GroupDeleted, g.EventMeta())
}
return allErrors
}
// ListGroups objects of the peers
func (am *DefaultAccountManager) ListGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return nil, err
}
groups := make([]*nbgroup.Group, 0, len(account.Groups))
for _, item := range account.Groups {
groups = append(groups, item)
}
return groups, nil
}
// GroupAddPeer appends peer to the group
func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
var group *nbgroup.Group
var updateAccountPeers bool
var err error
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, accountID, groupID)
if err != nil {
return err
}
if updated := group.AddPeer(peerID); !updated {
return nil
}
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
if err != nil {
return err
}
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return err
}
return transaction.SaveGroup(ctx, LockingStrengthUpdate, group)
})
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
}
if updateAccountPeers {
am.updateAccountPeers(ctx, accountID)
group, ok := account.Groups[groupID]
if !ok {
return status.Errorf(status.NotFound, "group with ID %s not found", groupID)
}
add := true
for _, itemID := range group.Peers {
if itemID == peerID {
add = false
break
}
}
if add {
group.Peers = append(group.Peers, peerID)
}
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err
}
if areGroupChangesAffectPeers(account, []string{group.ID}) {
am.updateAccountPeers(ctx, account)
}
return nil
@@ -320,162 +351,90 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
var group *nbgroup.Group
var updateAccountPeers bool
var err error
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, accountID, groupID)
if err != nil {
return err
}
if updated := group.RemovePeer(peerID); !updated {
return nil
}
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
if err != nil {
return err
}
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return err
}
return transaction.SaveGroup(ctx, LockingStrengthUpdate, group)
})
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
}
if updateAccountPeers {
am.updateAccountPeers(ctx, accountID)
group, ok := account.Groups[groupID]
if !ok {
return status.Errorf(status.NotFound, "group with ID %s not found", groupID)
}
return nil
}
// validateNewGroup validates the new group for existence and required fields.
func validateNewGroup(ctx context.Context, transaction Store, accountID string, newGroup *nbgroup.Group) error {
if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI {
return status.Errorf(status.InvalidArgument, "%s group without ID set", newGroup.Issued)
}
if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI {
existingGroup, err := transaction.GetGroupByName(ctx, LockingStrengthShare, accountID, newGroup.Name)
if err != nil {
if s, ok := status.FromError(err); !ok || s.Type() != status.NotFound {
account.Network.IncSerial()
for i, itemID := range group.Peers {
if itemID == peerID {
group.Peers = append(group.Peers[:i], group.Peers[i+1:]...)
if err := am.Store.SaveAccount(ctx, account); err != nil {
return err
}
}
// Prevent duplicate groups for API-issued groups.
// Integration or JWT groups can be duplicated as they are coming from the IdP that we don't have control of.
if existingGroup != nil {
return status.Errorf(status.AlreadyExists, "group with name %s already exists", newGroup.Name)
}
newGroup.ID = xid.New().String()
}
for _, peerID := range newGroup.Peers {
_, err := transaction.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID)
if err != nil {
return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID)
}
if areGroupChangesAffectPeers(account, []string{group.ID}) {
am.updateAccountPeers(ctx, account)
}
return nil
}
func validateDeleteGroup(ctx context.Context, transaction Store, group *nbgroup.Group, userID string) error {
func validateDeleteGroup(account *Account, group *nbgroup.Group, userID string) error {
// disable a deleting integration group if the initiator is not an admin service user
if group.Issued == nbgroup.GroupIssuedIntegration {
executingUser, err := transaction.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return err
executingUser := account.Users[userID]
if executingUser == nil {
return status.Errorf(status.NotFound, "user not found")
}
if executingUser.Role != UserRoleAdmin || !executingUser.IsServiceUser {
return status.Errorf(status.PermissionDenied, "only service users with admin power can delete integration group")
}
}
if group.IsGroupAll() {
return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed")
}
if isLinked, linkedRoute := isGroupLinkedToRoute(ctx, transaction, group.AccountID, group.ID); isLinked {
if isLinked, linkedRoute := isGroupLinkedToRoute(account.Routes, group.ID); isLinked {
return &GroupLinkError{"route", string(linkedRoute.NetID)}
}
if isLinked, linkedDns := isGroupLinkedToDns(ctx, transaction, group.AccountID, group.ID); isLinked {
if isLinked, linkedDns := isGroupLinkedToDns(account.NameServerGroups, group.ID); isLinked {
return &GroupLinkError{"name server groups", linkedDns.Name}
}
if isLinked, linkedPolicy := isGroupLinkedToPolicy(ctx, transaction, group.AccountID, group.ID); isLinked {
if isLinked, linkedPolicy := isGroupLinkedToPolicy(account.Policies, group.ID); isLinked {
return &GroupLinkError{"policy", linkedPolicy.Name}
}
if isLinked, linkedSetupKey := isGroupLinkedToSetupKey(ctx, transaction, group.AccountID, group.ID); isLinked {
if isLinked, linkedSetupKey := isGroupLinkedToSetupKey(account.SetupKeys, group.ID); isLinked {
return &GroupLinkError{"setup key", linkedSetupKey.Name}
}
if isLinked, linkedUser := isGroupLinkedToUser(ctx, transaction, group.AccountID, group.ID); isLinked {
if isLinked, linkedUser := isGroupLinkedToUser(account.Users, group.ID); isLinked {
return &GroupLinkError{"user", linkedUser.Id}
}
return checkGroupLinkedToSettings(ctx, transaction, group)
}
// checkGroupLinkedToSettings verifies if a group is linked to any settings in the account.
func checkGroupLinkedToSettings(ctx context.Context, transaction Store, group *nbgroup.Group) error {
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthShare, group.AccountID)
if err != nil {
return err
}
if slices.Contains(dnsSettings.DisabledManagementGroups, group.ID) {
if slices.Contains(account.DNSSettings.DisabledManagementGroups, group.ID) {
return &GroupLinkError{"disabled DNS management groups", group.Name}
}
settings, err := transaction.GetAccountSettings(ctx, LockingStrengthShare, group.AccountID)
if err != nil {
return err
}
if settings.Extra != nil && slices.Contains(settings.Extra.IntegratedValidatorGroups, group.ID) {
return &GroupLinkError{"integrated validator", group.Name}
if account.Settings.Extra != nil {
if slices.Contains(account.Settings.Extra.IntegratedValidatorGroups, group.ID) {
return &GroupLinkError{"integrated validator", group.Name}
}
}
return nil
}
// isGroupLinkedToRoute checks if a group is linked to any route in the account.
func isGroupLinkedToRoute(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *route.Route) {
routes, err := transaction.GetAccountRoutes(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving routes while checking group linkage: %v", err)
return false, nil
}
func isGroupLinkedToRoute(routes map[route.ID]*route.Route, groupID string) (bool, *route.Route) {
for _, r := range routes {
if slices.Contains(r.Groups, groupID) || slices.Contains(r.PeerGroups, groupID) {
return true, r
}
}
return false, nil
}
// isGroupLinkedToPolicy checks if a group is linked to any policy in the account.
func isGroupLinkedToPolicy(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *Policy) {
policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving policies while checking group linkage: %v", err)
return false, nil
}
func isGroupLinkedToPolicy(policies []*Policy, groupID string) (bool, *Policy) {
for _, policy := range policies {
for _, rule := range policy.Rules {
if slices.Contains(rule.Sources, groupID) || slices.Contains(rule.Destinations, groupID) {
@@ -487,13 +446,7 @@ func isGroupLinkedToPolicy(ctx context.Context, transaction Store, accountID str
}
// isGroupLinkedToDns checks if a group is linked to any nameserver group in the account.
func isGroupLinkedToDns(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *nbdns.NameServerGroup) {
nameServerGroups, err := transaction.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving name server groups while checking group linkage: %v", err)
return false, nil
}
func isGroupLinkedToDns(nameServerGroups map[string]*nbdns.NameServerGroup, groupID string) (bool, *nbdns.NameServerGroup) {
for _, dns := range nameServerGroups {
for _, g := range dns.Groups {
if g == groupID {
@@ -501,18 +454,11 @@ func isGroupLinkedToDns(ctx context.Context, transaction Store, accountID string
}
}
}
return false, nil
}
// isGroupLinkedToSetupKey checks if a group is linked to any setup key in the account.
func isGroupLinkedToSetupKey(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *SetupKey) {
setupKeys, err := transaction.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving setup keys while checking group linkage: %v", err)
return false, nil
}
func isGroupLinkedToSetupKey(setupKeys map[string]*SetupKey, groupID string) (bool, *SetupKey) {
for _, setupKey := range setupKeys {
if slices.Contains(setupKey.AutoGroups, groupID) {
return true, setupKey
@@ -522,13 +468,7 @@ func isGroupLinkedToSetupKey(ctx context.Context, transaction Store, accountID s
}
// isGroupLinkedToUser checks if a group is linked to any user in the account.
func isGroupLinkedToUser(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *User) {
users, err := transaction.GetAccountUsers(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving users while checking group linkage: %v", err)
return false, nil
}
func isGroupLinkedToUser(users map[string]*User, groupID string) (bool, *User) {
for _, user := range users {
if slices.Contains(user.AutoGroups, groupID) {
return true, user
@@ -537,36 +477,8 @@ func isGroupLinkedToUser(ctx context.Context, transaction Store, accountID strin
return false, nil
}
// areGroupChangesAffectPeers checks if any changes to the specified groups will affect peers.
func areGroupChangesAffectPeers(ctx context.Context, transaction Store, accountID string, groupIDs []string) (bool, error) {
if len(groupIDs) == 0 {
return false, nil
}
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return false, err
}
for _, groupID := range groupIDs {
if slices.Contains(dnsSettings.DisabledManagementGroups, groupID) {
return true, nil
}
if linked, _ := isGroupLinkedToDns(ctx, transaction, accountID, groupID); linked {
return true, nil
}
if linked, _ := isGroupLinkedToPolicy(ctx, transaction, accountID, groupID); linked {
return true, nil
}
if linked, _ := isGroupLinkedToRoute(ctx, transaction, accountID, groupID); linked {
return true, nil
}
}
return false, nil
}
func (am *DefaultAccountManager) anyGroupHasPeers(account *Account, groupIDs []string) bool {
// anyGroupHasPeers checks if any of the given groups in the account have peers.
func anyGroupHasPeers(account *Account, groupIDs []string) bool {
for _, groupID := range groupIDs {
if group, exists := account.Groups[groupID]; exists && group.HasPeers() {
return true
@@ -575,18 +487,21 @@ func (am *DefaultAccountManager) anyGroupHasPeers(account *Account, groupIDs []s
return false
}
// anyGroupHasPeers checks if any of the given groups in the account have peers.
func anyGroupHasPeers(ctx context.Context, transaction Store, accountID string, groupIDs []string) (bool, error) {
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, groupIDs)
if err != nil {
return false, err
}
for _, group := range groups {
if group.HasPeers() {
return true, nil
func areGroupChangesAffectPeers(account *Account, groupIDs []string) bool {
for _, groupID := range groupIDs {
if slices.Contains(account.DNSSettings.DisabledManagementGroups, groupID) {
return true
}
if linked, _ := isGroupLinkedToDns(account.NameServerGroups, groupID); linked {
return true
}
if linked, _ := isGroupLinkedToPolicy(account.Policies, groupID); linked {
return true
}
if linked, _ := isGroupLinkedToRoute(account.Routes, groupID); linked {
return true
}
}
return false, nil
return false
}

View File

@@ -49,35 +49,3 @@ func (g *Group) Copy() *Group {
func (g *Group) HasPeers() bool {
return len(g.Peers) > 0
}
// IsGroupAll checks if the group is a default "All" group.
func (g *Group) IsGroupAll() bool {
return g.Name == "All"
}
// AddPeer adds peerID to Peers if not present, returning true if added.
func (g *Group) AddPeer(peerID string) bool {
if peerID == "" {
return false
}
for _, itemID := range g.Peers {
if itemID == peerID {
return false
}
}
g.Peers = append(g.Peers, peerID)
return true
}
// RemovePeer removes peerID from Peers if present, returning true if removed.
func (g *Group) RemovePeer(peerID string) bool {
for i, itemID := range g.Peers {
if itemID == peerID {
g.Peers = append(g.Peers[:i], g.Peers[i+1:]...)
return true
}
}
return false
}

View File

@@ -1,90 +0,0 @@
package group
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestAddPeer(t *testing.T) {
t.Run("add new peer to empty slice", func(t *testing.T) {
group := &Group{Peers: []string{}}
peerID := "peer1"
assert.True(t, group.AddPeer(peerID))
assert.Contains(t, group.Peers, peerID)
})
t.Run("add new peer to nil slice", func(t *testing.T) {
group := &Group{Peers: nil}
peerID := "peer1"
assert.True(t, group.AddPeer(peerID))
assert.Contains(t, group.Peers, peerID)
})
t.Run("add new peer to non-empty slice", func(t *testing.T) {
group := &Group{Peers: []string{"peer1", "peer2"}}
peerID := "peer3"
assert.True(t, group.AddPeer(peerID))
assert.Contains(t, group.Peers, peerID)
})
t.Run("add duplicate peer", func(t *testing.T) {
group := &Group{Peers: []string{"peer1", "peer2"}}
peerID := "peer1"
assert.False(t, group.AddPeer(peerID))
assert.Equal(t, 2, len(group.Peers))
})
t.Run("add empty peer", func(t *testing.T) {
group := &Group{Peers: []string{"peer1", "peer2"}}
peerID := ""
assert.False(t, group.AddPeer(peerID))
assert.Equal(t, 2, len(group.Peers))
})
}
func TestRemovePeer(t *testing.T) {
t.Run("remove existing peer from slice", func(t *testing.T) {
group := &Group{Peers: []string{"peer1", "peer2", "peer3"}}
peerID := "peer2"
assert.True(t, group.RemovePeer(peerID))
assert.NotContains(t, group.Peers, peerID)
assert.Equal(t, 2, len(group.Peers))
})
t.Run("remove peer from empty slice", func(t *testing.T) {
group := &Group{Peers: []string{}}
peerID := "peer1"
assert.False(t, group.RemovePeer(peerID))
assert.Equal(t, 0, len(group.Peers))
})
t.Run("remove peer from nil slice", func(t *testing.T) {
group := &Group{Peers: nil}
peerID := "peer1"
assert.False(t, group.RemovePeer(peerID))
assert.Nil(t, group.Peers)
})
t.Run("remove non-existent peer", func(t *testing.T) {
group := &Group{Peers: []string{"peer1", "peer2"}}
peerID := "peer3"
assert.False(t, group.RemovePeer(peerID))
assert.Equal(t, 2, len(group.Peers))
})
t.Run("remove peer from single-item slice", func(t *testing.T) {
group := &Group{Peers: []string{"peer1"}}
peerID := "peer1"
assert.True(t, group.RemovePeer(peerID))
assert.Equal(t, 0, len(group.Peers))
assert.NotContains(t, group.Peers, peerID)
})
t.Run("remove empty peer", func(t *testing.T) {
group := &Group{Peers: []string{"peer1", "peer2"}}
peerID := ""
assert.False(t, group.RemovePeer(peerID))
assert.Equal(t, 2, len(group.Peers))
})
}

View File

@@ -8,13 +8,12 @@ import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
nbdns "github.com/netbirdio/netbird/dns"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/route"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
const (
@@ -208,7 +207,7 @@ func TestDefaultAccountManager_DeleteGroups(t *testing.T) {
{
name: "delete non-existent group",
groupIDs: []string{"non-existent-group"},
expectedReasons: []string{"group: non-existent-group not found"},
expectedDeleted: []string{"non-existent-group"},
},
{
name: "delete multiple groups with mixed results",
@@ -500,7 +499,8 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
})
// adding a group to policy
_, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
ID: "policy",
Enabled: true,
Rules: []*PolicyRule{
{
@@ -511,7 +511,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
Action: PolicyTrafficActionAccept,
},
},
})
}, false)
assert.NoError(t, err)
// Saving a group linked to policy should update account peers and send peer update
@@ -536,6 +536,29 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
}
})
// Saving an unchanged group should trigger account peers update and not send peer update
// since there is no change in the network map
t.Run("saving unchanged group", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
close(done)
}()
err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
ID: "groupA",
Name: "GroupA",
Peers: []string{peer1.ID, peer2.ID},
})
assert.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
// adding peer to a used group should update account peers and send peer update
t.Run("adding peer to linked group", func(t *testing.T) {
done := make(chan struct{})

View File

@@ -6,7 +6,6 @@ import (
"net"
"net/netip"
"strings"
"sync"
"time"
pb "github.com/golang/protobuf/proto" // nolint
@@ -39,7 +38,6 @@ type GRPCServer struct {
jwtClaimsExtractor *jwtclaims.ClaimsExtractor
appMetrics telemetry.AppMetrics
ephemeralManager *EphemeralManager
peerLocks sync.Map
}
// NewServer creates a new Management server
@@ -150,13 +148,6 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
// nolint:staticcheck
ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String())
unlock := s.acquirePeerLockByUID(ctx, peerKey.String())
defer func() {
if unlock != nil {
unlock()
}
}()
accountID, err := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String())
if err != nil {
// nolint:staticcheck
@@ -180,7 +171,6 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), extractPeerMeta(ctx, syncReq.GetMeta()), realIP)
if err != nil {
log.WithContext(ctx).Debugf("error while syncing peer %s: %v", peerKey.String(), err)
return mapError(ctx, err)
}
@@ -200,15 +190,11 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart))
}
unlock()
unlock = nil
return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv)
}
// handleUpdates sends updates to the connected peer until the updates channel is closed.
func (s *GRPCServer) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *UpdateMessage, srv proto.ManagementService_SyncServer) error {
log.WithContext(ctx).Tracef("starting to handle updates for peer %s", peerKey.String())
for {
select {
// condition when there are some updates
@@ -259,18 +245,10 @@ func (s *GRPCServer) sendUpdate(ctx context.Context, accountID string, peerKey w
}
func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer) {
unlock := s.acquirePeerLockByUID(ctx, peer.Key)
defer unlock()
err := s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key)
if err != nil {
log.WithContext(ctx).Errorf("failed to disconnect peer %s properly: %v", peer.Key, err)
}
s.peersUpdateManager.CloseChannel(ctx, peer.ID)
s.secretsManager.CancelRefresh(peer.ID)
_ = s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key)
s.ephemeralManager.OnPeerDisconnected(ctx, peer)
log.WithContext(ctx).Tracef("peer %s has been disconnected", peer.Key)
}
func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string, error) {
@@ -296,24 +274,6 @@ func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string
return claims.UserId, nil
}
func (s *GRPCServer) acquirePeerLockByUID(ctx context.Context, uniqueID string) (unlock func()) {
log.WithContext(ctx).Tracef("acquiring peer lock for ID %s", uniqueID)
start := time.Now()
value, _ := s.peerLocks.LoadOrStore(uniqueID, &sync.RWMutex{})
mtx := value.(*sync.RWMutex)
mtx.Lock()
log.WithContext(ctx).Tracef("acquired peer lock for ID %s in %v", uniqueID, time.Since(start))
start = time.Now()
unlock = func() {
mtx.Unlock()
log.WithContext(ctx).Tracef("released peer lock for ID %s in %v", uniqueID, time.Since(start))
}
return unlock
}
// maps internal internalStatus.Error to gRPC status.Error
func mapError(ctx context.Context, err error) error {
if e, ok := internalStatus.FromError(err); ok {

View File

@@ -439,13 +439,17 @@ components:
example: 5
required:
- accessible_peers_count
SetupKeyBase:
SetupKey:
type: object
properties:
id:
description: Setup Key ID
type: string
example: 2531583362
key:
description: Setup Key value
type: string
example: A616097E-FCF0-48FA-9354-CA4A61142761
name:
description: Setup key name identifier
type: string
@@ -514,31 +518,22 @@ components:
- updated_at
- usage_limit
- ephemeral
SetupKeyClear:
allOf:
- $ref: '#/components/schemas/SetupKeyBase'
- type: object
properties:
key:
description: Setup Key as plain text
type: string
example: A616097E-FCF0-48FA-9354-CA4A61142761
required:
- key
SetupKey:
allOf:
- $ref: '#/components/schemas/SetupKeyBase'
- type: object
properties:
key:
description: Setup Key as secret
type: string
example: A6160****
required:
- key
SetupKeyRequest:
type: object
properties:
name:
description: Setup Key name
type: string
example: Default key
type:
description: Setup key type, one-off for single time usage and reusable
type: string
example: reusable
expires_in:
description: Expiration time in seconds, 0 will mean the key never expires
type: integer
minimum: 0
example: 86400
revoked:
description: Setup key revocation status
type: boolean
@@ -549,9 +544,21 @@ components:
items:
type: string
example: "ch8i4ug6lnn4g9hqv7m0"
usage_limit:
description: A number of times this key can be used. The value of 0 indicates the unlimited usage.
type: integer
example: 0
ephemeral:
description: Indicate that the peer will be ephemeral or not
type: boolean
example: true
required:
- name
- type
- expires_in
- revoked
- auto_groups
- usage_limit
CreateSetupKeyRequest:
type: object
properties:
@@ -1936,7 +1943,7 @@ paths:
content:
application/json:
schema:
$ref: '#/components/schemas/SetupKeyClear'
$ref: '#/components/schemas/SetupKey'
'400':
"$ref": "#/components/responses/bad_request"
'401':

View File

@@ -1062,94 +1062,7 @@ type SetupKey struct {
// Id Setup Key ID
Id string `json:"id"`
// Key Setup Key as secret
Key string `json:"key"`
// LastUsed Setup key last usage date
LastUsed time.Time `json:"last_used"`
// Name Setup key name identifier
Name string `json:"name"`
// Revoked Setup key revocation status
Revoked bool `json:"revoked"`
// State Setup key status, "valid", "overused","expired" or "revoked"
State string `json:"state"`
// Type Setup key type, one-off for single time usage and reusable
Type string `json:"type"`
// UpdatedAt Setup key last update date
UpdatedAt time.Time `json:"updated_at"`
// UsageLimit A number of times this key can be used. The value of 0 indicates the unlimited usage.
UsageLimit int `json:"usage_limit"`
// UsedTimes Usage count of setup key
UsedTimes int `json:"used_times"`
// Valid Setup key validity status
Valid bool `json:"valid"`
}
// SetupKeyBase defines model for SetupKeyBase.
type SetupKeyBase struct {
// AutoGroups List of group IDs to auto-assign to peers registered with this key
AutoGroups []string `json:"auto_groups"`
// Ephemeral Indicate that the peer will be ephemeral or not
Ephemeral bool `json:"ephemeral"`
// Expires Setup Key expiration date
Expires time.Time `json:"expires"`
// Id Setup Key ID
Id string `json:"id"`
// LastUsed Setup key last usage date
LastUsed time.Time `json:"last_used"`
// Name Setup key name identifier
Name string `json:"name"`
// Revoked Setup key revocation status
Revoked bool `json:"revoked"`
// State Setup key status, "valid", "overused","expired" or "revoked"
State string `json:"state"`
// Type Setup key type, one-off for single time usage and reusable
Type string `json:"type"`
// UpdatedAt Setup key last update date
UpdatedAt time.Time `json:"updated_at"`
// UsageLimit A number of times this key can be used. The value of 0 indicates the unlimited usage.
UsageLimit int `json:"usage_limit"`
// UsedTimes Usage count of setup key
UsedTimes int `json:"used_times"`
// Valid Setup key validity status
Valid bool `json:"valid"`
}
// SetupKeyClear defines model for SetupKeyClear.
type SetupKeyClear struct {
// AutoGroups List of group IDs to auto-assign to peers registered with this key
AutoGroups []string `json:"auto_groups"`
// Ephemeral Indicate that the peer will be ephemeral or not
Ephemeral bool `json:"ephemeral"`
// Expires Setup Key expiration date
Expires time.Time `json:"expires"`
// Id Setup Key ID
Id string `json:"id"`
// Key Setup Key as plain text
// Key Setup Key value
Key string `json:"key"`
// LastUsed Setup key last usage date
@@ -1185,8 +1098,23 @@ type SetupKeyRequest struct {
// AutoGroups List of group IDs to auto-assign to peers registered with this key
AutoGroups []string `json:"auto_groups"`
// Ephemeral Indicate that the peer will be ephemeral or not
Ephemeral *bool `json:"ephemeral,omitempty"`
// ExpiresIn Expiration time in seconds, 0 will mean the key never expires
ExpiresIn int `json:"expires_in"`
// Name Setup Key name
Name string `json:"name"`
// Revoked Setup key revocation status
Revoked bool `json:"revoked"`
// Type Setup key type, one-off for single time usage and reusable
Type string `json:"type"`
// UsageLimit A number of times this key can be used. The value of 0 indicates the unlimited usage.
UsageLimit int `json:"usage_limit"`
}
// User defines model for User.

View File

@@ -184,26 +184,14 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
dnsDomain := h.accountManager.GetDNSDomain()
peers, err := h.accountManager.GetPeers(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
groupsMap := map[string]*nbgroup.Group{}
groups, _ := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
for _, group := range groups {
groupsMap[group.ID] = group
}
respBody := make([]*api.PeerBatch, 0, len(peers))
for _, peer := range peers {
respBody := make([]*api.PeerBatch, 0, len(account.Peers))
for _, peer := range account.Peers {
peerToReturn, err := h.checkPeerStatus(peer)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
groupMinimumInfo := toGroupsInfo(groupsMap, peer.ID)
groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID)
respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, 0))
}
@@ -316,7 +304,7 @@ func peerToAccessiblePeer(peer *nbpeer.Peer, dnsDomain string) api.AccessiblePee
}
func toGroupsInfo(groups map[string]*nbgroup.Group, peerID string) []api.GroupMinimum {
groupsInfo := []api.GroupMinimum{}
var groupsInfo []api.GroupMinimum
groupsChecked := make(map[string]struct{})
for _, group := range groups {
_, ok := groupsChecked[group.ID]

View File

@@ -6,8 +6,10 @@ import (
"strconv"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/rs/xid"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims"
@@ -120,22 +122,21 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID
return
}
policy := &server.Policy{
isUpdate := policyID != ""
if policyID == "" {
policyID = xid.New().String()
}
policy := server.Policy{
ID: policyID,
AccountID: accountID,
Name: req.Name,
Enabled: req.Enabled,
Description: req.Description,
}
for _, rule := range req.Rules {
var ruleID string
if rule.Id != nil {
ruleID = *rule.Id
}
pr := server.PolicyRule{
ID: ruleID,
PolicyID: policyID,
ID: policyID, // TODO: when policy can contain multiple rules, need refactor
Name: rule.Name,
Destinations: rule.Destinations,
Sources: rule.Sources,
@@ -224,8 +225,7 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID
policy.SourcePostureChecks = *req.SourcePostureChecks
}
policy, err := h.accountManager.SavePolicy(r.Context(), accountID, userID, policy)
if err != nil {
if err := h.accountManager.SavePolicy(r.Context(), accountID, userID, &policy, isUpdate); err != nil {
util.WriteError(r.Context(), err, w)
return
}
@@ -236,7 +236,7 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID
return
}
resp := toPolicyResponse(allGroups, policy)
resp := toPolicyResponse(allGroups, &policy)
if len(resp.Rules) == 0 {
util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w)
return

View File

@@ -38,12 +38,12 @@ func initPoliciesTestData(policies ...*server.Policy) *Policies {
}
return policy, nil
},
SavePolicyFunc: func(_ context.Context, _, _ string, policy *server.Policy) (*server.Policy, error) {
SavePolicyFunc: func(_ context.Context, _, _ string, policy *server.Policy, _ bool) error {
if !strings.HasPrefix(policy.ID, "id-") {
policy.ID = "id-was-set"
policy.Rules[0].ID = "id-was-set"
}
return policy, nil
return nil
},
GetAllGroupsFunc: func(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error) {
return []*nbgroup.Group{{ID: "F"}, {ID: "G"}}, nil

View File

@@ -169,8 +169,7 @@ func (p *PostureChecksHandler) savePostureChecks(w http.ResponseWriter, r *http.
return
}
postureChecks, err = p.accountManager.SavePostureChecks(r.Context(), accountID, userID, postureChecks)
if err != nil {
if err := p.accountManager.SavePostureChecks(r.Context(), accountID, userID, postureChecks); err != nil {
util.WriteError(r.Context(), err, w)
return
}

View File

@@ -40,15 +40,15 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH
}
return p, nil
},
SavePostureChecksFunc: func(_ context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) {
SavePostureChecksFunc: func(_ context.Context, accountID, userID string, postureChecks *posture.Checks) error {
postureChecks.ID = "postureCheck"
testPostureChecks[postureChecks.ID] = postureChecks
if err := postureChecks.Validate(); err != nil {
return nil, status.Errorf(status.InvalidArgument, err.Error()) //nolint
return status.Errorf(status.InvalidArgument, err.Error()) //nolint
}
return postureChecks, nil
return nil
},
DeletePostureChecksFunc: func(_ context.Context, accountID, postureChecksID, userID string) error {
_, ok := testPostureChecks[postureChecksID]

View File

@@ -149,7 +149,7 @@ func (h *RoutesHandler) validateRoute(req api.PostApiRoutesJSONRequestBody) erro
}
if req.Peer == nil && req.PeerGroups == nil {
return status.Errorf(status.InvalidArgument, "either 'peer' or 'peer_groups' should be provided")
return status.Errorf(status.InvalidArgument, "either 'peer' or 'peers_group' should be provided")
}
if req.Peer != nil && req.PeerGroups != nil {

View File

@@ -137,6 +137,11 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request
return
}
if req.Name == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "setup key name field is invalid: %s", req.Name), w)
return
}
if req.AutoGroups == nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "setup key AutoGroups field is invalid"), w)
return
@@ -145,6 +150,7 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request
newKey := &server.SetupKey{}
newKey.AutoGroups = req.AutoGroups
newKey.Revoked = req.Revoked
newKey.Name = req.Name
newKey.Id = keyID
newKey, err = h.accountManager.SaveSetupKey(r.Context(), accountID, newKey, userID)

View File

@@ -52,23 +52,26 @@ func (am *DefaultAccountManager) UpdateIntegratedValidatorGroups(ctx context.Con
return am.Store.SaveAccount(ctx, a)
}
func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID string, groupIDs []string) (bool, error) {
if len(groupIDs) == 0 {
func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error) {
if len(groups) == 0 {
return true, nil
}
err := am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
for _, groupID := range groupIDs {
_, err := transaction.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID)
if err != nil {
return err
}
}
return nil
})
accountsGroups, err := am.ListGroups(ctx, accountId)
if err != nil {
return false, err
}
for _, group := range groups {
var found bool
for _, accountGroup := range accountsGroups {
if accountGroup.ID == group {
found = true
break
}
}
if !found {
return false, nil
}
}
return true, nil
}

View File

@@ -11,7 +11,7 @@ import (
// IntegratedValidator interface exists to avoid the circle dependencies
type IntegratedValidator interface {
ValidateExtraSettings(ctx context.Context, newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error
ValidatePeer(ctx context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error)
ValidatePeer(ctx context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, error)
PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer
IsNotValidPeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error)
GetValidatedPeers(accountID string, groups map[string]*nbgroup.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error)

View File

@@ -77,8 +77,6 @@ type JWTValidator struct {
options Options
}
var keyNotFound = errors.New("unable to find appropriate key")
// NewJWTValidator constructor
func NewJWTValidator(ctx context.Context, issuer string, audienceList []string, keysLocation string, idpSignkeyRefreshEnabled bool) (*JWTValidator, error) {
keys, err := getPemKeys(ctx, keysLocation)
@@ -126,18 +124,12 @@ func NewJWTValidator(ctx context.Context, issuer string, audienceList []string,
}
publicKey, err := getPublicKey(ctx, token, keys)
if err == nil {
return publicKey, nil
if err != nil {
log.WithContext(ctx).Errorf("getPublicKey error: %s", err)
return nil, err
}
msg := fmt.Sprintf("getPublicKey error: %s", err)
if errors.Is(err, keyNotFound) && !idpSignkeyRefreshEnabled {
msg = fmt.Sprintf("getPublicKey error: %s. You can enable key refresh by setting HttpServerConfig.IdpSignKeyRefreshEnabled to true in your management.json file and restart the service", err)
}
log.WithContext(ctx).Error(msg)
return nil, err
return publicKey, nil
},
EnableAuthOnOptions: false,
}
@@ -237,7 +229,7 @@ func getPublicKey(ctx context.Context, token *jwt.Token, jwks *Jwks) (interface{
log.WithContext(ctx).Debugf("Key Type: %s not yet supported, please raise ticket!", jwks.Keys[k].Kty)
}
return nil, keyNotFound
return nil, errors.New("unable to find appropriate key")
}
func getPublicKeyFromECDSA(jwk JSONWebKey) (publicKey *ecdsa.PublicKey, err error) {
@@ -318,3 +310,4 @@ func getMaxAgeFromCacheHeader(ctx context.Context, cacheControl string) int {
return 0
}

Some files were not shown because too many files have changed in this diff Show More