mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-02 07:33:52 -04:00
Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8ac5e9d866 | ||
|
|
954e038da0 | ||
|
|
9ccc6c6547 | ||
|
|
2a3262f5a8 |
3
.github/FUNDING.yml
vendored
3
.github/FUNDING.yml
vendored
@@ -1,3 +0,0 @@
|
||||
# These are supported funding model platforms
|
||||
|
||||
github: [netbirdio]
|
||||
42
.github/workflows/golang-test-linux.yml
vendored
42
.github/workflows/golang-test-linux.yml
vendored
@@ -13,7 +13,6 @@ concurrency:
|
||||
jobs:
|
||||
test:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ '386','amd64' ]
|
||||
store: [ 'sqlite', 'postgres']
|
||||
@@ -52,47 +51,6 @@ jobs:
|
||||
- name: Test
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m -p 1 ./...
|
||||
|
||||
benchmark:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ '386','amd64' ]
|
||||
store: [ 'sqlite', 'postgres' ]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.23.x"
|
||||
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/go/pkg/mod
|
||||
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-go-
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
||||
|
||||
- name: Install 32-bit libpcap
|
||||
if: matrix.arch == '386'
|
||||
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
|
||||
|
||||
- name: Install modules
|
||||
run: go mod tidy
|
||||
|
||||
- name: check git status
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Test
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -run=^$ -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m -p 1 ./...
|
||||
|
||||
test_client_on_docker:
|
||||
runs-on: ubuntu-20.04
|
||||
steps:
|
||||
|
||||
2
.github/workflows/release.yml
vendored
2
.github/workflows/release.yml
vendored
@@ -9,7 +9,7 @@ on:
|
||||
pull_request:
|
||||
|
||||
env:
|
||||
SIGN_PIPE_VER: "v0.0.17"
|
||||
SIGN_PIPE_VER: "v0.0.16"
|
||||
GORELEASER_VER: "v2.3.2"
|
||||
PRODUCT_NAME: "NetBird"
|
||||
COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)"
|
||||
|
||||
@@ -17,12 +17,8 @@
|
||||
<img src="https://img.shields.io/badge/license-BSD--3-blue" />
|
||||
</a>
|
||||
<br>
|
||||
<a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2utg2ncdz-W7LEB6toRBLE1Jca37dYpg">
|
||||
<a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2p5zwhm4g-8fHollzrQa5y4PZF5AEpvQ">
|
||||
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
|
||||
</a>
|
||||
<br>
|
||||
<a href="https://gurubase.io/g/netbird">
|
||||
<img src="https://img.shields.io/badge/Gurubase-Ask%20NetBird%20Guru-006BFF"/>
|
||||
</a>
|
||||
</p>
|
||||
</div>
|
||||
@@ -34,7 +30,7 @@
|
||||
<br/>
|
||||
See <a href="https://netbird.io/docs/">Documentation</a>
|
||||
<br/>
|
||||
Join our <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2utg2ncdz-W7LEB6toRBLE1Jca37dYpg">Slack channel</a>
|
||||
Join our <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2p5zwhm4g-8fHollzrQa5y4PZF5AEpvQ">Slack channel</a>
|
||||
<br/>
|
||||
|
||||
</strong>
|
||||
|
||||
@@ -12,8 +12,6 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
const anonTLD = ".domain"
|
||||
|
||||
type Anonymizer struct {
|
||||
ipAnonymizer map[netip.Addr]netip.Addr
|
||||
domainAnonymizer map[string]string
|
||||
@@ -85,39 +83,29 @@ func (a *Anonymizer) AnonymizeIPString(ip string) string {
|
||||
}
|
||||
|
||||
func (a *Anonymizer) AnonymizeDomain(domain string) string {
|
||||
baseDomain := domain
|
||||
hasDot := strings.HasSuffix(domain, ".")
|
||||
if hasDot {
|
||||
baseDomain = domain[:len(domain)-1]
|
||||
}
|
||||
|
||||
if strings.HasSuffix(baseDomain, "netbird.io") ||
|
||||
strings.HasSuffix(baseDomain, "netbird.selfhosted") ||
|
||||
strings.HasSuffix(baseDomain, "netbird.cloud") ||
|
||||
strings.HasSuffix(baseDomain, "netbird.stage") ||
|
||||
strings.HasSuffix(baseDomain, anonTLD) {
|
||||
if strings.HasSuffix(domain, "netbird.io") ||
|
||||
strings.HasSuffix(domain, "netbird.selfhosted") ||
|
||||
strings.HasSuffix(domain, "netbird.cloud") ||
|
||||
strings.HasSuffix(domain, "netbird.stage") ||
|
||||
strings.HasSuffix(domain, ".domain") {
|
||||
return domain
|
||||
}
|
||||
|
||||
parts := strings.Split(baseDomain, ".")
|
||||
parts := strings.Split(domain, ".")
|
||||
if len(parts) < 2 {
|
||||
return domain
|
||||
}
|
||||
|
||||
baseForLookup := parts[len(parts)-2] + "." + parts[len(parts)-1]
|
||||
baseDomain := parts[len(parts)-2] + "." + parts[len(parts)-1]
|
||||
|
||||
anonymized, ok := a.domainAnonymizer[baseForLookup]
|
||||
anonymized, ok := a.domainAnonymizer[baseDomain]
|
||||
if !ok {
|
||||
anonymizedBase := "anon-" + generateRandomString(5) + anonTLD
|
||||
a.domainAnonymizer[baseForLookup] = anonymizedBase
|
||||
anonymizedBase := "anon-" + generateRandomString(5) + ".domain"
|
||||
a.domainAnonymizer[baseDomain] = anonymizedBase
|
||||
anonymized = anonymizedBase
|
||||
}
|
||||
|
||||
result := strings.Replace(baseDomain, baseForLookup, anonymized, 1)
|
||||
if hasDot {
|
||||
result += "."
|
||||
}
|
||||
return result
|
||||
return strings.Replace(domain, baseDomain, anonymized, 1)
|
||||
}
|
||||
|
||||
func (a *Anonymizer) AnonymizeURI(uri string) string {
|
||||
@@ -164,9 +152,9 @@ func (a *Anonymizer) AnonymizeString(str string) string {
|
||||
return str
|
||||
}
|
||||
|
||||
// AnonymizeSchemeURI finds and anonymizes URIs with ws, wss, rel, rels, stun, stuns, turn, and turns schemes.
|
||||
// AnonymizeSchemeURI finds and anonymizes URIs with stun, stuns, turn, and turns schemes.
|
||||
func (a *Anonymizer) AnonymizeSchemeURI(text string) string {
|
||||
re := regexp.MustCompile(`(?i)\b(wss?://|rels?://|stuns?:|turns?:|https?://)\S+\b`)
|
||||
re := regexp.MustCompile(`(?i)\b(stuns?:|turns?:|https?://)\S+\b`)
|
||||
|
||||
return re.ReplaceAllStringFunc(text, a.AnonymizeURI)
|
||||
}
|
||||
@@ -180,10 +168,10 @@ func (a *Anonymizer) AnonymizeDNSLogLine(logEntry string) string {
|
||||
parts := strings.Split(match, `"`)
|
||||
if len(parts) >= 2 {
|
||||
domain := parts[1]
|
||||
if strings.HasSuffix(domain, anonTLD) {
|
||||
if strings.HasSuffix(domain, ".domain") {
|
||||
return match
|
||||
}
|
||||
randomDomain := generateRandomString(10) + anonTLD
|
||||
randomDomain := generateRandomString(10) + ".domain"
|
||||
return strings.Replace(match, domain, randomDomain, 1)
|
||||
}
|
||||
return match
|
||||
@@ -213,8 +201,6 @@ func isWellKnown(addr netip.Addr) bool {
|
||||
"2606:4700:4700::1111", "2606:4700:4700::1001", // Cloudflare DNS IPv6
|
||||
"9.9.9.9", "149.112.112.112", // Quad9 DNS IPv4
|
||||
"2620:fe::fe", "2620:fe::9", // Quad9 DNS IPv6
|
||||
|
||||
"128.0.0.0", "8000::", // 2nd split subnet for default routes
|
||||
}
|
||||
|
||||
if slices.Contains(wellKnown, addr.String()) {
|
||||
|
||||
@@ -67,36 +67,18 @@ func TestAnonymizeDomain(t *testing.T) {
|
||||
`^anon-[a-zA-Z0-9]+\.domain$`,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"Domain with Trailing Dot",
|
||||
"example.com.",
|
||||
`^anon-[a-zA-Z0-9]+\.domain.$`,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"Subdomain",
|
||||
"sub.example.com",
|
||||
`^sub\.anon-[a-zA-Z0-9]+\.domain$`,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"Subdomain with Trailing Dot",
|
||||
"sub.example.com.",
|
||||
`^sub\.anon-[a-zA-Z0-9]+\.domain.$`,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"Protected Domain",
|
||||
"netbird.io",
|
||||
`^netbird\.io$`,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"Protected Domain with Trailing Dot",
|
||||
"netbird.io.",
|
||||
`^netbird\.io.$`,
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
@@ -158,16 +140,8 @@ func TestAnonymizeSchemeURI(t *testing.T) {
|
||||
expect string
|
||||
}{
|
||||
{"STUN URI in text", "Connection made via stun:example.com", `Connection made via stun:anon-[a-zA-Z0-9]+\.domain`},
|
||||
{"STUNS URI in message", "Secure connection to stuns:example.com:443", `Secure connection to stuns:anon-[a-zA-Z0-9]+\.domain:443`},
|
||||
{"TURN URI in log", "Failed attempt turn:some.example.com:3478?transport=tcp: retrying", `Failed attempt turn:some.anon-[a-zA-Z0-9]+\.domain:3478\?transport=tcp: retrying`},
|
||||
{"TURNS URI in message", "Secure connection to turns:example.com:5349", `Secure connection to turns:anon-[a-zA-Z0-9]+\.domain:5349`},
|
||||
{"HTTP URI in text", "Visit http://example.com for more", `Visit http://anon-[a-zA-Z0-9]+\.domain for more`},
|
||||
{"HTTPS URI in CAPS", "Visit HTTPS://example.com for more", `Visit https://anon-[a-zA-Z0-9]+\.domain for more`},
|
||||
{"HTTPS URI in message", "Visit https://example.com for more", `Visit https://anon-[a-zA-Z0-9]+\.domain for more`},
|
||||
{"WS URI in log", "Connection established to ws://example.com:8080", `Connection established to ws://anon-[a-zA-Z0-9]+\.domain:8080`},
|
||||
{"WSS URI in message", "Secure connection to wss://example.com", `Secure connection to wss://anon-[a-zA-Z0-9]+\.domain`},
|
||||
{"Rel URI in text", "Relaying to rel://example.com", `Relaying to rel://anon-[a-zA-Z0-9]+\.domain`},
|
||||
{"Rels URI in message", "Relaying to rels://example.com", `Relaying to rels://anon-[a-zA-Z0-9]+\.domain`},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
|
||||
@@ -3,7 +3,6 @@ package cmd
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -62,15 +61,6 @@ var forCmd = &cobra.Command{
|
||||
RunE: runForDuration,
|
||||
}
|
||||
|
||||
var persistenceCmd = &cobra.Command{
|
||||
Use: "persistence [on|off]",
|
||||
Short: "Set network map memory persistence",
|
||||
Long: `Configure whether the latest network map should persist in memory. When enabled, the last known network map will be kept in memory.`,
|
||||
Example: " netbird debug persistence on",
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: setNetworkMapPersistence,
|
||||
}
|
||||
|
||||
func debugBundle(cmd *cobra.Command, _ []string) error {
|
||||
conn, err := getClient(cmd)
|
||||
if err != nil {
|
||||
@@ -181,13 +171,6 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
// Enable network map persistence before bringing the service up
|
||||
if _, err := client.SetNetworkMapPersistence(cmd.Context(), &proto.SetNetworkMapPersistenceRequest{
|
||||
Enabled: true,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("failed to enable network map persistence: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
|
||||
return fmt.Errorf("failed to up: %v", status.Convert(err).Message())
|
||||
}
|
||||
@@ -217,13 +200,6 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
// Disable network map persistence after creating the debug bundle
|
||||
if _, err := client.SetNetworkMapPersistence(cmd.Context(), &proto.SetNetworkMapPersistenceRequest{
|
||||
Enabled: false,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("failed to disable network map persistence: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
if stateWasDown {
|
||||
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
|
||||
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
|
||||
@@ -243,34 +219,6 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func setNetworkMapPersistence(cmd *cobra.Command, args []string) error {
|
||||
conn, err := getClient(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Errorf(errCloseConnection, err)
|
||||
}
|
||||
}()
|
||||
|
||||
persistence := strings.ToLower(args[0])
|
||||
if persistence != "on" && persistence != "off" {
|
||||
return fmt.Errorf("invalid persistence value: %s. Use 'on' or 'off'", args[0])
|
||||
}
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
_, err = client.SetNetworkMapPersistence(cmd.Context(), &proto.SetNetworkMapPersistenceRequest{
|
||||
Enabled: persistence == "on",
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set network map persistence: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
cmd.Printf("Network map persistence set to: %s\n", persistence)
|
||||
return nil
|
||||
}
|
||||
|
||||
func getStatusOutput(cmd *cobra.Command) string {
|
||||
var statusOutputString string
|
||||
statusResp, err := getStatus(cmd.Context())
|
||||
|
||||
@@ -1,33 +0,0 @@
|
||||
//go:build pprof
|
||||
// +build pprof
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
_ "net/http/pprof"
|
||||
"os"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func init() {
|
||||
addr := pprofAddr()
|
||||
go pprof(addr)
|
||||
}
|
||||
|
||||
func pprofAddr() string {
|
||||
listenAddr := os.Getenv("NB_PPROF_ADDR")
|
||||
if listenAddr == "" {
|
||||
return "localhost:6969"
|
||||
}
|
||||
|
||||
return listenAddr
|
||||
}
|
||||
|
||||
func pprof(listenAddr string) {
|
||||
log.Infof("listening pprof on: %s\n", listenAddr)
|
||||
if err := http.ListenAndServe(listenAddr, nil); err != nil {
|
||||
log.Fatalf("Failed to start pprof: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -155,7 +155,6 @@ func init() {
|
||||
debugCmd.AddCommand(logCmd)
|
||||
logCmd.AddCommand(logLevelCmd)
|
||||
debugCmd.AddCommand(forCmd)
|
||||
debugCmd.AddCommand(persistenceCmd)
|
||||
|
||||
upCmd.PersistentFlags().StringSliceVar(&natExternalIPs, externalIPMapFlag, nil,
|
||||
`Sets external IPs maps between local addresses and interfaces.`+
|
||||
|
||||
@@ -2,7 +2,6 @@ package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -14,11 +13,10 @@ import (
|
||||
)
|
||||
|
||||
type program struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
serv *grpc.Server
|
||||
serverInstance *server.Server
|
||||
serverInstanceMu sync.Mutex
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
serv *grpc.Server
|
||||
serverInstance *server.Server
|
||||
}
|
||||
|
||||
func newProgram(ctx context.Context, cancel context.CancelFunc) *program {
|
||||
|
||||
@@ -61,9 +61,7 @@ func (p *program) Start(svc service.Service) error {
|
||||
}
|
||||
proto.RegisterDaemonServiceServer(p.serv, serverInstance)
|
||||
|
||||
p.serverInstanceMu.Lock()
|
||||
p.serverInstance = serverInstance
|
||||
p.serverInstanceMu.Unlock()
|
||||
|
||||
log.Printf("started daemon server: %v", split[1])
|
||||
if err := p.serv.Serve(listen); err != nil {
|
||||
@@ -74,7 +72,6 @@ func (p *program) Start(svc service.Service) error {
|
||||
}
|
||||
|
||||
func (p *program) Stop(srv service.Service) error {
|
||||
p.serverInstanceMu.Lock()
|
||||
if p.serverInstance != nil {
|
||||
in := new(proto.DownRequest)
|
||||
_, err := p.serverInstance.Down(p.ctx, in)
|
||||
@@ -82,7 +79,6 @@ func (p *program) Stop(srv service.Service) error {
|
||||
log.Errorf("failed to stop daemon: %v", err)
|
||||
}
|
||||
}
|
||||
p.serverInstanceMu.Unlock()
|
||||
|
||||
p.cancel()
|
||||
|
||||
|
||||
@@ -1,181 +0,0 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
var (
|
||||
allFlag bool
|
||||
)
|
||||
|
||||
var stateCmd = &cobra.Command{
|
||||
Use: "state",
|
||||
Short: "Manage daemon state",
|
||||
Long: "Provides commands for managing and inspecting the Netbird daemon state.",
|
||||
}
|
||||
|
||||
var stateListCmd = &cobra.Command{
|
||||
Use: "list",
|
||||
Aliases: []string{"ls"},
|
||||
Short: "List all stored states",
|
||||
Long: "Lists all registered states with their status and basic information.",
|
||||
Example: " netbird state list",
|
||||
RunE: stateList,
|
||||
}
|
||||
|
||||
var stateCleanCmd = &cobra.Command{
|
||||
Use: "clean [state-name]",
|
||||
Short: "Clean stored states",
|
||||
Long: `Clean specific state or all states. The daemon must not be running.
|
||||
This will perform cleanup operations and remove the state.`,
|
||||
Example: ` netbird state clean dns_state
|
||||
netbird state clean --all`,
|
||||
RunE: stateClean,
|
||||
PreRunE: func(cmd *cobra.Command, args []string) error {
|
||||
// Check mutual exclusivity between --all flag and state-name argument
|
||||
if allFlag && len(args) > 0 {
|
||||
return fmt.Errorf("cannot specify both --all flag and state name")
|
||||
}
|
||||
if !allFlag && len(args) != 1 {
|
||||
return fmt.Errorf("requires a state name argument or --all flag")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
var stateDeleteCmd = &cobra.Command{
|
||||
Use: "delete [state-name]",
|
||||
Short: "Delete stored states",
|
||||
Long: `Delete specific state or all states from storage. The daemon must not be running.
|
||||
This will remove the state without performing any cleanup operations.`,
|
||||
Example: ` netbird state delete dns_state
|
||||
netbird state delete --all`,
|
||||
RunE: stateDelete,
|
||||
PreRunE: func(cmd *cobra.Command, args []string) error {
|
||||
// Check mutual exclusivity between --all flag and state-name argument
|
||||
if allFlag && len(args) > 0 {
|
||||
return fmt.Errorf("cannot specify both --all flag and state name")
|
||||
}
|
||||
if !allFlag && len(args) != 1 {
|
||||
return fmt.Errorf("requires a state name argument or --all flag")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
func init() {
|
||||
rootCmd.AddCommand(stateCmd)
|
||||
stateCmd.AddCommand(stateListCmd, stateCleanCmd, stateDeleteCmd)
|
||||
|
||||
stateCleanCmd.Flags().BoolVarP(&allFlag, "all", "a", false, "Clean all states")
|
||||
stateDeleteCmd.Flags().BoolVarP(&allFlag, "all", "a", false, "Delete all states")
|
||||
}
|
||||
|
||||
func stateList(cmd *cobra.Command, _ []string) error {
|
||||
conn, err := getClient(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Errorf(errCloseConnection, err)
|
||||
}
|
||||
}()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
resp, err := client.ListStates(cmd.Context(), &proto.ListStatesRequest{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to list states: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
cmd.Printf("\nStored states:\n\n")
|
||||
for _, state := range resp.States {
|
||||
cmd.Printf("- %s\n", state.Name)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func stateClean(cmd *cobra.Command, args []string) error {
|
||||
var stateName string
|
||||
if !allFlag {
|
||||
stateName = args[0]
|
||||
}
|
||||
|
||||
conn, err := getClient(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Errorf(errCloseConnection, err)
|
||||
}
|
||||
}()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
resp, err := client.CleanState(cmd.Context(), &proto.CleanStateRequest{
|
||||
StateName: stateName,
|
||||
All: allFlag,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to clean state: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
if resp.CleanedStates == 0 {
|
||||
cmd.Println("No states were cleaned")
|
||||
return nil
|
||||
}
|
||||
|
||||
if allFlag {
|
||||
cmd.Printf("Successfully cleaned %d states\n", resp.CleanedStates)
|
||||
} else {
|
||||
cmd.Printf("Successfully cleaned state %q\n", stateName)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func stateDelete(cmd *cobra.Command, args []string) error {
|
||||
var stateName string
|
||||
if !allFlag {
|
||||
stateName = args[0]
|
||||
}
|
||||
|
||||
conn, err := getClient(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Errorf(errCloseConnection, err)
|
||||
}
|
||||
}()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
resp, err := client.DeleteState(cmd.Context(), &proto.DeleteStateRequest{
|
||||
StateName: stateName,
|
||||
All: allFlag,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete state: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
if resp.DeletedStates == 0 {
|
||||
cmd.Println("No states were deleted")
|
||||
return nil
|
||||
}
|
||||
|
||||
if allFlag {
|
||||
cmd.Printf("Successfully deleted %d states\n", resp.DeletedStates)
|
||||
} else {
|
||||
cmd.Printf("Successfully deleted state %q\n", stateName)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -680,7 +680,7 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo
|
||||
func skipDetailByFilters(peerState *proto.PeerState, isConnected bool) bool {
|
||||
statusEval := false
|
||||
ipEval := false
|
||||
nameEval := true
|
||||
nameEval := false
|
||||
|
||||
if statusFilter != "" {
|
||||
lowerStatusFilter := strings.ToLower(statusFilter)
|
||||
@@ -700,13 +700,11 @@ func skipDetailByFilters(peerState *proto.PeerState, isConnected bool) bool {
|
||||
|
||||
if len(prefixNamesFilter) > 0 {
|
||||
for prefixNameFilter := range prefixNamesFilterMap {
|
||||
if strings.HasPrefix(peerState.Fqdn, prefixNameFilter) {
|
||||
nameEval = false
|
||||
if !strings.HasPrefix(peerState.Fqdn, prefixNameFilter) {
|
||||
nameEval = true
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
nameEval = false
|
||||
}
|
||||
|
||||
return statusEval || ipEval || nameEval
|
||||
|
||||
@@ -352,14 +352,14 @@ func (m *aclManager) seedInitialEntries() {
|
||||
func (m *aclManager) seedInitialOptionalEntries() {
|
||||
m.optionalEntries["FORWARD"] = []entry{
|
||||
{
|
||||
spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected), "-j", chainNameInputRules},
|
||||
spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmark), "-j", chainNameInputRules},
|
||||
position: 2,
|
||||
},
|
||||
}
|
||||
|
||||
m.optionalEntries["PREROUTING"] = []entry{
|
||||
{
|
||||
spec: []string{"-t", "mangle", "-i", m.wgIface.Name(), "-m", "addrtype", "--dst-type", "LOCAL", "-j", "MARK", "--set-mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected)},
|
||||
spec: []string{"-t", "mangle", "-i", m.wgIface.Name(), "-m", "addrtype", "--dst-type", "LOCAL", "-j", "MARK", "--set-mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmark)},
|
||||
position: 1,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -83,11 +83,9 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||
}
|
||||
|
||||
// persist early to ensure cleanup of chains
|
||||
go func() {
|
||||
if err := stateManager.PersistState(context.Background()); err != nil {
|
||||
log.Errorf("failed to persist state: %v", err)
|
||||
}
|
||||
}()
|
||||
if err := stateManager.PersistState(context.Background()); err != nil {
|
||||
log.Errorf("failed to persist state: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -18,24 +18,22 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/acl/id"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
const (
|
||||
ipv4Nat = "netbird-rt-nat"
|
||||
)
|
||||
|
||||
// constants needed to manage and create iptable rules
|
||||
const (
|
||||
tableFilter = "filter"
|
||||
tableNat = "nat"
|
||||
tableMangle = "mangle"
|
||||
chainPOSTROUTING = "POSTROUTING"
|
||||
chainPREROUTING = "PREROUTING"
|
||||
chainRTNAT = "NETBIRD-RT-NAT"
|
||||
chainRTFWD = "NETBIRD-RT-FWD"
|
||||
chainRTPRE = "NETBIRD-RT-PRE"
|
||||
routingFinalForwardJump = "ACCEPT"
|
||||
routingFinalNatJump = "MASQUERADE"
|
||||
|
||||
jumpPre = "jump-pre"
|
||||
jumpNat = "jump-nat"
|
||||
matchSet = "--match-set"
|
||||
)
|
||||
|
||||
@@ -325,25 +323,24 @@ func (r *router) Reset() error {
|
||||
}
|
||||
|
||||
func (r *router) cleanUpDefaultForwardRules() error {
|
||||
if err := r.cleanJumpRules(); err != nil {
|
||||
return fmt.Errorf("clean jump rules: %w", err)
|
||||
err := r.cleanJumpRules()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debug("flushing routing related tables")
|
||||
for _, chainInfo := range []struct {
|
||||
chain string
|
||||
table string
|
||||
}{
|
||||
{chainRTFWD, tableFilter},
|
||||
{chainRTNAT, tableNat},
|
||||
{chainRTPRE, tableMangle},
|
||||
} {
|
||||
ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain)
|
||||
for _, chain := range []string{chainRTFWD, chainRTNAT} {
|
||||
table := r.getTableForChain(chain)
|
||||
|
||||
ok, err := r.iptablesClient.ChainExists(table, chain)
|
||||
if err != nil {
|
||||
return fmt.Errorf("check chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
|
||||
log.Errorf("failed check chain %s, error: %v", chain, err)
|
||||
return err
|
||||
} else if ok {
|
||||
if err = r.iptablesClient.ClearAndDeleteChain(chainInfo.table, chainInfo.chain); err != nil {
|
||||
return fmt.Errorf("clear and delete chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
|
||||
err = r.iptablesClient.ClearAndDeleteChain(table, chain)
|
||||
if err != nil {
|
||||
log.Errorf("failed cleaning chain %s, error: %v", chain, err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -352,16 +349,9 @@ func (r *router) cleanUpDefaultForwardRules() error {
|
||||
}
|
||||
|
||||
func (r *router) createContainers() error {
|
||||
for _, chainInfo := range []struct {
|
||||
chain string
|
||||
table string
|
||||
}{
|
||||
{chainRTFWD, tableFilter},
|
||||
{chainRTPRE, tableMangle},
|
||||
{chainRTNAT, tableNat},
|
||||
} {
|
||||
if err := r.createAndSetupChain(chainInfo.chain); err != nil {
|
||||
return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
|
||||
for _, chain := range []string{chainRTFWD, chainRTNAT} {
|
||||
if err := r.createAndSetupChain(chain); err != nil {
|
||||
return fmt.Errorf("create chain %s: %w", chain, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -369,10 +359,6 @@ func (r *router) createContainers() error {
|
||||
return fmt.Errorf("insert established rule: %w", err)
|
||||
}
|
||||
|
||||
if err := r.addPostroutingRules(); err != nil {
|
||||
return fmt.Errorf("add static nat rules: %w", err)
|
||||
}
|
||||
|
||||
if err := r.addJumpRules(); err != nil {
|
||||
return fmt.Errorf("add jump rules: %w", err)
|
||||
}
|
||||
@@ -380,32 +366,6 @@ func (r *router) createContainers() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) addPostroutingRules() error {
|
||||
// First rule for outbound masquerade
|
||||
rule1 := []string{
|
||||
"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasquerade),
|
||||
"!", "-o", "lo",
|
||||
"-j", routingFinalNatJump,
|
||||
}
|
||||
if err := r.iptablesClient.Append(tableNat, chainRTNAT, rule1...); err != nil {
|
||||
return fmt.Errorf("add outbound masquerade rule: %v", err)
|
||||
}
|
||||
r.rules["static-nat-outbound"] = rule1
|
||||
|
||||
// Second rule for return traffic masquerade
|
||||
rule2 := []string{
|
||||
"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasqueradeReturn),
|
||||
"-o", r.wgIface.Name(),
|
||||
"-j", routingFinalNatJump,
|
||||
}
|
||||
if err := r.iptablesClient.Append(tableNat, chainRTNAT, rule2...); err != nil {
|
||||
return fmt.Errorf("add return masquerade rule: %v", err)
|
||||
}
|
||||
r.rules["static-nat-return"] = rule2
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) createAndSetupChain(chain string) error {
|
||||
table := r.getTableForChain(chain)
|
||||
|
||||
@@ -417,14 +377,10 @@ func (r *router) createAndSetupChain(chain string) error {
|
||||
}
|
||||
|
||||
func (r *router) getTableForChain(chain string) string {
|
||||
switch chain {
|
||||
case chainRTNAT:
|
||||
if chain == chainRTNAT {
|
||||
return tableNat
|
||||
case chainRTPRE:
|
||||
return tableMangle
|
||||
default:
|
||||
return tableFilter
|
||||
}
|
||||
return tableFilter
|
||||
}
|
||||
|
||||
func (r *router) insertEstablishedRule(chain string) error {
|
||||
@@ -442,39 +398,25 @@ func (r *router) insertEstablishedRule(chain string) error {
|
||||
}
|
||||
|
||||
func (r *router) addJumpRules() error {
|
||||
// Jump to NAT chain
|
||||
natRule := []string{"-j", chainRTNAT}
|
||||
if err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, natRule...); err != nil {
|
||||
return fmt.Errorf("add nat jump rule: %v", err)
|
||||
rule := []string{"-j", chainRTNAT}
|
||||
err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, rule...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
r.rules[jumpNat] = natRule
|
||||
|
||||
// Jump to prerouting chain
|
||||
preRule := []string{"-j", chainRTPRE}
|
||||
if err := r.iptablesClient.Insert(tableMangle, chainPREROUTING, 1, preRule...); err != nil {
|
||||
return fmt.Errorf("add prerouting jump rule: %v", err)
|
||||
}
|
||||
r.rules[jumpPre] = preRule
|
||||
r.rules[ipv4Nat] = rule
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) cleanJumpRules() error {
|
||||
for _, ruleKey := range []string{jumpNat, jumpPre} {
|
||||
if rule, exists := r.rules[ruleKey]; exists {
|
||||
table := tableNat
|
||||
chain := chainPOSTROUTING
|
||||
if ruleKey == jumpPre {
|
||||
table = tableMangle
|
||||
chain = chainPREROUTING
|
||||
}
|
||||
|
||||
if err := r.iptablesClient.DeleteIfExists(table, chain, rule...); err != nil {
|
||||
return fmt.Errorf("delete rule from chain %s in table %s, err: %v", chain, table, err)
|
||||
}
|
||||
delete(r.rules, ruleKey)
|
||||
rule, found := r.rules[ipv4Nat]
|
||||
if found {
|
||||
err := r.iptablesClient.DeleteIfExists(tableNat, chainPOSTROUTING, rule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed cleaning rule from chain %s, err: %v", chainPOSTROUTING, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -482,35 +424,19 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
|
||||
ruleKey := firewall.GenKey(firewall.NatFormat, pair)
|
||||
|
||||
if rule, exists := r.rules[ruleKey]; exists {
|
||||
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainRTPRE, rule...); err != nil {
|
||||
return fmt.Errorf("error while removing existing marking rule for %s: %v", pair.Destination, err)
|
||||
if err := r.iptablesClient.DeleteIfExists(tableNat, chainRTNAT, rule...); err != nil {
|
||||
return fmt.Errorf("error while removing existing NAT rule for %s: %v", pair.Destination, err)
|
||||
}
|
||||
delete(r.rules, ruleKey)
|
||||
}
|
||||
|
||||
markValue := nbnet.PreroutingFwmarkMasquerade
|
||||
if pair.Inverse {
|
||||
markValue = nbnet.PreroutingFwmarkMasqueradeReturn
|
||||
}
|
||||
|
||||
rule := []string{"-i", r.wgIface.Name()}
|
||||
if pair.Inverse {
|
||||
rule = []string{"!", "-i", r.wgIface.Name()}
|
||||
}
|
||||
|
||||
rule = append(rule,
|
||||
"-m", "conntrack",
|
||||
"--ctstate", "NEW",
|
||||
"-s", pair.Source.String(),
|
||||
"-d", pair.Destination.String(),
|
||||
"-j", "MARK", "--set-mark", fmt.Sprintf("%#x", markValue),
|
||||
)
|
||||
|
||||
if err := r.iptablesClient.Append(tableMangle, chainRTPRE, rule...); err != nil {
|
||||
return fmt.Errorf("error while adding marking rule for %s: %v", pair.Destination, err)
|
||||
rule := genRuleSpec(routingFinalNatJump, pair.Source, pair.Destination, r.wgIface.Name(), pair.Inverse)
|
||||
if err := r.iptablesClient.Append(tableNat, chainRTNAT, rule...); err != nil {
|
||||
return fmt.Errorf("error while appending new NAT rule for %s: %v", pair.Destination, err)
|
||||
}
|
||||
|
||||
r.rules[ruleKey] = rule
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -518,12 +444,13 @@ func (r *router) removeNatRule(pair firewall.RouterPair) error {
|
||||
ruleKey := firewall.GenKey(firewall.NatFormat, pair)
|
||||
|
||||
if rule, exists := r.rules[ruleKey]; exists {
|
||||
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainRTPRE, rule...); err != nil {
|
||||
return fmt.Errorf("error while removing marking rule for %s: %v", pair.Destination, err)
|
||||
if err := r.iptablesClient.DeleteIfExists(tableNat, chainRTNAT, rule...); err != nil {
|
||||
return fmt.Errorf("error while removing existing nat rule for %s: %v", pair.Destination, err)
|
||||
}
|
||||
|
||||
delete(r.rules, ruleKey)
|
||||
} else {
|
||||
log.Debugf("marking rule %s not found", ruleKey)
|
||||
log.Debugf("nat rule %s not found", ruleKey)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -555,6 +482,16 @@ func (r *router) updateState() {
|
||||
}
|
||||
}
|
||||
|
||||
func genRuleSpec(jump string, source, destination netip.Prefix, intf string, inverse bool) []string {
|
||||
intdir := "-i"
|
||||
lointdir := "-o"
|
||||
if inverse {
|
||||
intdir = "-o"
|
||||
lointdir = "-i"
|
||||
}
|
||||
return []string{intdir, intf, "!", lointdir, "lo", "-s", source.String(), "-d", destination.String(), "-j", jump}
|
||||
}
|
||||
|
||||
func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string {
|
||||
var rule []string
|
||||
|
||||
|
||||
@@ -3,18 +3,17 @@
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"testing"
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/test"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
func isIptablesSupported() bool {
|
||||
@@ -35,24 +34,14 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
||||
require.NoError(t, manager.init(nil))
|
||||
|
||||
defer func() {
|
||||
assert.NoError(t, manager.Reset(), "shouldn't return error")
|
||||
_ = manager.Reset()
|
||||
}()
|
||||
|
||||
// Now 5 rules:
|
||||
// 1. established rule in forward chain
|
||||
// 2. jump rule to NAT chain
|
||||
// 3. jump rule to PRE chain
|
||||
// 4. static outbound masquerade rule
|
||||
// 5. static return masquerade rule
|
||||
require.Len(t, manager.rules, 5, "should have created rules map")
|
||||
require.Len(t, manager.rules, 2, "should have created rules map")
|
||||
|
||||
exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, "-j", chainRTNAT)
|
||||
exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, manager.rules[ipv4Nat]...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING)
|
||||
require.True(t, exists, "postrouting jump rule should exist")
|
||||
|
||||
exists, err = manager.iptablesClient.Exists(tableMangle, chainPREROUTING, "-j", chainRTPRE)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainPREROUTING)
|
||||
require.True(t, exists, "prerouting jump rule should exist")
|
||||
require.True(t, exists, "postrouting rule should exist")
|
||||
|
||||
pair := firewall.RouterPair{
|
||||
ID: "abc",
|
||||
@@ -60,15 +49,22 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
||||
Destination: netip.MustParsePrefix("100.100.100.0/24"),
|
||||
Masquerade: true,
|
||||
}
|
||||
forward4Rule := []string{"-s", pair.Source.String(), "-d", pair.Destination.String(), "-j", routingFinalForwardJump}
|
||||
|
||||
err = manager.AddNatRule(pair)
|
||||
require.NoError(t, err, "adding NAT rule should not return error")
|
||||
err = manager.iptablesClient.Insert(tableFilter, chainRTFWD, 1, forward4Rule...)
|
||||
require.NoError(t, err, "inserting rule should not return error")
|
||||
|
||||
nat4Rule := genRuleSpec(routingFinalNatJump, pair.Source, pair.Destination, ifaceMock.Name(), false)
|
||||
|
||||
err = manager.iptablesClient.Insert(tableNat, chainRTNAT, 1, nat4Rule...)
|
||||
require.NoError(t, err, "inserting rule should not return error")
|
||||
|
||||
err = manager.Reset()
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
}
|
||||
|
||||
func TestIptablesManager_AddNatRule(t *testing.T) {
|
||||
|
||||
if !isIptablesSupported() {
|
||||
t.SkipNow()
|
||||
}
|
||||
@@ -83,66 +79,52 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
|
||||
require.NoError(t, manager.init(nil))
|
||||
|
||||
defer func() {
|
||||
assert.NoError(t, manager.Reset(), "shouldn't return error")
|
||||
err := manager.Reset()
|
||||
if err != nil {
|
||||
log.Errorf("failed to reset iptables manager: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
err = manager.AddNatRule(testCase.InputPair)
|
||||
require.NoError(t, err, "marking rule should be inserted")
|
||||
require.NoError(t, err, "forwarding pair should be inserted")
|
||||
|
||||
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
|
||||
markingRule := []string{
|
||||
"-i", ifaceMock.Name(),
|
||||
"-m", "conntrack",
|
||||
"--ctstate", "NEW",
|
||||
"-s", testCase.InputPair.Source.String(),
|
||||
"-d", testCase.InputPair.Destination.String(),
|
||||
"-j", "MARK", "--set-mark",
|
||||
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasquerade),
|
||||
}
|
||||
natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination, ifaceMock.Name(), false)
|
||||
|
||||
exists, err := iptablesClient.Exists(tableMangle, chainRTPRE, markingRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
|
||||
exists, err := iptablesClient.Exists(tableNat, chainRTNAT, natRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
|
||||
if testCase.InputPair.Masquerade {
|
||||
require.True(t, exists, "marking rule should be created")
|
||||
foundRule, found := manager.rules[natRuleKey]
|
||||
require.True(t, found, "marking rule should exist in the map")
|
||||
require.Equal(t, markingRule, foundRule, "stored marking rule should match")
|
||||
require.True(t, exists, "nat rule should be created")
|
||||
foundNatRule, foundNat := manager.rules[natRuleKey]
|
||||
require.True(t, foundNat, "nat rule should exist in the map")
|
||||
require.Equal(t, natRule[:4], foundNatRule[:4], "stored nat rule should match")
|
||||
} else {
|
||||
require.False(t, exists, "marking rule should not be created")
|
||||
_, found := manager.rules[natRuleKey]
|
||||
require.False(t, found, "marking rule should not exist in the map")
|
||||
require.False(t, exists, "nat rule should not be created")
|
||||
_, foundNat := manager.rules[natRuleKey]
|
||||
require.False(t, foundNat, "nat rule should not exist in the map")
|
||||
}
|
||||
|
||||
// Check inverse rule
|
||||
inversePair := firewall.GetInversePair(testCase.InputPair)
|
||||
inverseRuleKey := firewall.GenKey(firewall.NatFormat, inversePair)
|
||||
inverseMarkingRule := []string{
|
||||
"!", "-i", ifaceMock.Name(),
|
||||
"-m", "conntrack",
|
||||
"--ctstate", "NEW",
|
||||
"-s", inversePair.Source.String(),
|
||||
"-d", inversePair.Destination.String(),
|
||||
"-j", "MARK", "--set-mark",
|
||||
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasqueradeReturn),
|
||||
}
|
||||
inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair))
|
||||
inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInversePair(testCase.InputPair).Source, firewall.GetInversePair(testCase.InputPair).Destination, ifaceMock.Name(), true)
|
||||
|
||||
exists, err = iptablesClient.Exists(tableMangle, chainRTPRE, inverseMarkingRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
|
||||
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
|
||||
if testCase.InputPair.Masquerade {
|
||||
require.True(t, exists, "inverse marking rule should be created")
|
||||
foundRule, found := manager.rules[inverseRuleKey]
|
||||
require.True(t, found, "inverse marking rule should exist in the map")
|
||||
require.Equal(t, inverseMarkingRule, foundRule, "stored inverse marking rule should match")
|
||||
require.True(t, exists, "income nat rule should be created")
|
||||
foundNatRule, foundNat := manager.rules[inNatRuleKey]
|
||||
require.True(t, foundNat, "income nat rule should exist in the map")
|
||||
require.Equal(t, inNatRule[:4], foundNatRule[:4], "stored income nat rule should match")
|
||||
} else {
|
||||
require.False(t, exists, "inverse marking rule should not be created")
|
||||
_, found := manager.rules[inverseRuleKey]
|
||||
require.False(t, found, "inverse marking rule should not exist in the map")
|
||||
require.False(t, exists, "nat rule should not be created")
|
||||
_, foundNat := manager.rules[inNatRuleKey]
|
||||
require.False(t, foundNat, "income nat rule should not exist in the map")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIptablesManager_RemoveNatRule(t *testing.T) {
|
||||
|
||||
if !isIptablesSupported() {
|
||||
t.SkipNow()
|
||||
}
|
||||
@@ -155,52 +137,42 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
require.NoError(t, manager.init(nil))
|
||||
defer func() {
|
||||
assert.NoError(t, manager.Reset(), "shouldn't return error")
|
||||
_ = manager.Reset()
|
||||
}()
|
||||
|
||||
err = manager.AddNatRule(testCase.InputPair)
|
||||
require.NoError(t, err, "should add NAT rule without error")
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
|
||||
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
|
||||
natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination, ifaceMock.Name(), false)
|
||||
|
||||
err = iptablesClient.Insert(tableNat, chainRTNAT, 1, natRule...)
|
||||
require.NoError(t, err, "inserting rule should not return error")
|
||||
|
||||
inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair))
|
||||
inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInversePair(testCase.InputPair).Source, firewall.GetInversePair(testCase.InputPair).Destination, ifaceMock.Name(), true)
|
||||
|
||||
err = iptablesClient.Insert(tableNat, chainRTNAT, 1, inNatRule...)
|
||||
require.NoError(t, err, "inserting rule should not return error")
|
||||
|
||||
err = manager.Reset()
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
|
||||
err = manager.RemoveNatRule(testCase.InputPair)
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
|
||||
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
|
||||
markingRule := []string{
|
||||
"-i", ifaceMock.Name(),
|
||||
"-m", "conntrack",
|
||||
"--ctstate", "NEW",
|
||||
"-s", testCase.InputPair.Source.String(),
|
||||
"-d", testCase.InputPair.Destination.String(),
|
||||
"-j", "MARK", "--set-mark",
|
||||
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasquerade),
|
||||
}
|
||||
|
||||
exists, err := iptablesClient.Exists(tableMangle, chainRTPRE, markingRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
|
||||
require.False(t, exists, "marking rule should not exist")
|
||||
exists, err := iptablesClient.Exists(tableNat, chainRTNAT, natRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
|
||||
require.False(t, exists, "nat rule should not exist")
|
||||
|
||||
_, found := manager.rules[natRuleKey]
|
||||
require.False(t, found, "marking rule should not exist in the manager map")
|
||||
require.False(t, found, "nat rule should exist in the manager map")
|
||||
|
||||
// Check inverse rule removal
|
||||
inversePair := firewall.GetInversePair(testCase.InputPair)
|
||||
inverseRuleKey := firewall.GenKey(firewall.NatFormat, inversePair)
|
||||
inverseMarkingRule := []string{
|
||||
"!", "-i", ifaceMock.Name(),
|
||||
"-m", "conntrack",
|
||||
"--ctstate", "NEW",
|
||||
"-s", inversePair.Source.String(),
|
||||
"-d", inversePair.Destination.String(),
|
||||
"-j", "MARK", "--set-mark",
|
||||
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasqueradeReturn),
|
||||
}
|
||||
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
|
||||
require.False(t, exists, "income nat rule should not exist")
|
||||
|
||||
exists, err = iptablesClient.Exists(tableMangle, chainRTPRE, inverseMarkingRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
|
||||
require.False(t, exists, "inverse marking rule should not exist")
|
||||
|
||||
_, found = manager.rules[inverseRuleKey]
|
||||
require.False(t, found, "inverse marking rule should not exist in the map")
|
||||
_, found = manager.rules[inNatRuleKey]
|
||||
require.False(t, found, "income nat rule should exist in the manager map")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -37,11 +37,6 @@ func (s *ipList) UnmarshalJSON(data []byte) error {
|
||||
return err
|
||||
}
|
||||
s.ips = temp.IPs
|
||||
|
||||
if temp.IPs == nil {
|
||||
temp.IPs = make(map[string]struct{})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -94,10 +89,5 @@ func (s *ipsetStore) UnmarshalJSON(data []byte) error {
|
||||
return err
|
||||
}
|
||||
s.ipsets = temp.IPSets
|
||||
|
||||
if temp.IPSets == nil {
|
||||
temp.IPSets = make(map[string]*ipList)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -17,7 +17,6 @@ import (
|
||||
const (
|
||||
ForwardingFormatPrefix = "netbird-fwd-"
|
||||
ForwardingFormat = "netbird-fwd-%s-%t"
|
||||
PreroutingFormat = "netbird-prerouting-%s-%t"
|
||||
NatFormat = "netbird-nat-%s-%t"
|
||||
)
|
||||
|
||||
|
||||
@@ -520,7 +520,7 @@ func (m *AclManager) addPreroutingRule(preroutingChain *nftables.Chain) {
|
||||
},
|
||||
&expr.Immediate{
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
|
||||
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmark),
|
||||
},
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyMARK,
|
||||
@@ -543,7 +543,7 @@ func (m *AclManager) addFwmarkToForward(chainFwFilter *nftables.Chain) {
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
|
||||
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmark),
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictJump,
|
||||
|
||||
@@ -99,11 +99,9 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||
}
|
||||
|
||||
// persist early
|
||||
go func() {
|
||||
if err := stateManager.PersistState(context.Background()); err != nil {
|
||||
log.Errorf("failed to persist state: %v", err)
|
||||
}
|
||||
}()
|
||||
if err := stateManager.PersistState(context.Background()); err != nil {
|
||||
log.Errorf("failed to persist state: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -199,7 +197,7 @@ func (m *Manager) AllowNetbird() error {
|
||||
|
||||
var chain *nftables.Chain
|
||||
for _, c := range chains {
|
||||
if c.Table.Name == tableNameFilter && c.Name == chainNameInput {
|
||||
if c.Table.Name == tableNameFilter && c.Name == chainNameForward {
|
||||
chain = c
|
||||
break
|
||||
}
|
||||
@@ -276,7 +274,7 @@ func (m *Manager) resetNetbirdInputRules() error {
|
||||
|
||||
func (m *Manager) deleteNetbirdInputRules(chains []*nftables.Chain) {
|
||||
for _, c := range chains {
|
||||
if c.Table.Name == tableNameFilter && c.Name == chainNameInput {
|
||||
if c.Table.Name == "filter" && c.Name == "INPUT" {
|
||||
rules, err := m.rConn.GetRules(c.Table, c)
|
||||
if err != nil {
|
||||
log.Errorf("get rules for chain %q: %v", c.Name, err)
|
||||
@@ -351,9 +349,7 @@ func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) {
|
||||
Register: 1,
|
||||
Data: ifname(m.wgIface.Name()),
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
&expr.Verdict{},
|
||||
},
|
||||
UserData: []byte(allowNetbirdInputRuleID),
|
||||
}
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -227,105 +225,3 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func runIptablesSave(t *testing.T) (string, string) {
|
||||
t.Helper()
|
||||
var stdout, stderr bytes.Buffer
|
||||
cmd := exec.Command("iptables-save")
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
err := cmd.Run()
|
||||
require.NoError(t, err, "iptables-save failed to run")
|
||||
|
||||
return stdout.String(), stderr.String()
|
||||
}
|
||||
|
||||
func verifyIptablesOutput(t *testing.T, stdout, stderr string) {
|
||||
t.Helper()
|
||||
// Check for any incompatibility warnings
|
||||
require.NotContains(t,
|
||||
stderr,
|
||||
"incompatible",
|
||||
"iptables-save produced compatibility warning. Full stderr: %s",
|
||||
stderr,
|
||||
)
|
||||
|
||||
// Verify standard tables are present
|
||||
expectedTables := []string{
|
||||
"*filter",
|
||||
"*nat",
|
||||
"*mangle",
|
||||
}
|
||||
|
||||
for _, table := range expectedTables {
|
||||
require.Contains(t,
|
||||
stdout,
|
||||
table,
|
||||
"iptables-save output missing expected table: %s\nFull stdout: %s",
|
||||
table,
|
||||
stdout,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
|
||||
if check() != NFTABLES {
|
||||
t.Skip("nftables not supported on this system")
|
||||
}
|
||||
|
||||
if _, err := exec.LookPath("iptables-save"); err != nil {
|
||||
t.Skipf("iptables-save not available on this system: %v", err)
|
||||
}
|
||||
|
||||
// First ensure iptables-nft tables exist by running iptables-save
|
||||
stdout, stderr := runIptablesSave(t)
|
||||
verifyIptablesOutput(t, stdout, stderr)
|
||||
|
||||
manager, err := Create(ifaceMock)
|
||||
require.NoError(t, err, "failed to create manager")
|
||||
require.NoError(t, manager.Init(nil))
|
||||
|
||||
t.Cleanup(func() {
|
||||
err := manager.Reset(nil)
|
||||
require.NoError(t, err, "failed to reset manager state")
|
||||
|
||||
// Verify iptables output after reset
|
||||
stdout, stderr := runIptablesSave(t)
|
||||
verifyIptablesOutput(t, stdout, stderr)
|
||||
})
|
||||
|
||||
ip := net.ParseIP("100.96.0.1")
|
||||
_, err = manager.AddPeerFiltering(
|
||||
ip,
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
&fw.Port{Values: []int{80}},
|
||||
fw.RuleDirectionIN,
|
||||
fw.ActionAccept,
|
||||
"",
|
||||
"test rule",
|
||||
)
|
||||
require.NoError(t, err, "failed to add peer filtering rule")
|
||||
|
||||
_, err = manager.AddRouteFiltering(
|
||||
[]netip.Prefix{netip.MustParsePrefix("192.168.2.0/24")},
|
||||
netip.MustParsePrefix("10.1.0.0/24"),
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
&fw.Port{Values: []int{443}},
|
||||
fw.ActionAccept,
|
||||
)
|
||||
require.NoError(t, err, "failed to add route filtering rule")
|
||||
|
||||
pair := fw.RouterPair{
|
||||
Source: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
Destination: netip.MustParsePrefix("10.0.0.0/24"),
|
||||
Masquerade: true,
|
||||
}
|
||||
err = manager.AddNatRule(pair)
|
||||
require.NoError(t, err, "failed to add NAT rule")
|
||||
|
||||
stdout, stderr = runIptablesSave(t)
|
||||
verifyIptablesOutput(t, stdout, stderr)
|
||||
}
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/acl/id"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -125,6 +124,7 @@ func (r *router) createContainers() error {
|
||||
insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw])
|
||||
|
||||
prio := *nftables.ChainPriorityNATSource - 1
|
||||
|
||||
r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{
|
||||
Name: chainNameRoutingNat,
|
||||
Table: r.workTable,
|
||||
@@ -133,21 +133,6 @@ func (r *router) createContainers() error {
|
||||
Type: nftables.ChainTypeNAT,
|
||||
})
|
||||
|
||||
// Chain is created by acl manager
|
||||
// TODO: move creation to a common place
|
||||
r.chains[chainNamePrerouting] = &nftables.Chain{
|
||||
Name: chainNamePrerouting,
|
||||
Table: r.workTable,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Hooknum: nftables.ChainHookPrerouting,
|
||||
Priority: nftables.ChainPriorityMangle,
|
||||
}
|
||||
|
||||
// Add the single NAT rule that matches on mark
|
||||
if err := r.addPostroutingRules(); err != nil {
|
||||
return fmt.Errorf("add single nat rule: %v", err)
|
||||
}
|
||||
|
||||
if err := r.acceptForwardRules(); err != nil {
|
||||
log.Errorf("failed to add accept rules for the forward chain: %s", err)
|
||||
}
|
||||
@@ -437,149 +422,59 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
|
||||
sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
|
||||
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
|
||||
|
||||
op := expr.CmpOpEq
|
||||
dir := expr.MetaKeyIIFNAME
|
||||
notDir := expr.MetaKeyOIFNAME
|
||||
if pair.Inverse {
|
||||
op = expr.CmpOpNeq
|
||||
dir = expr.MetaKeyOIFNAME
|
||||
notDir = expr.MetaKeyIIFNAME
|
||||
}
|
||||
|
||||
lo := ifname("lo")
|
||||
intf := ifname(r.wgIface.Name())
|
||||
|
||||
exprs := []expr.Any{
|
||||
// We only care about NEW connections to mark them and later identify them in the postrouting chain for masquerading.
|
||||
// Masquerading will take care of the conntrack state, which means we won't need to mark established connections.
|
||||
&expr.Ct{
|
||||
Key: expr.CtKeySTATE,
|
||||
&expr.Meta{
|
||||
Key: dir,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Bitwise{
|
||||
SourceRegister: 1,
|
||||
DestRegister: 1,
|
||||
Len: 4,
|
||||
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitNEW),
|
||||
Xor: binaryutil.NativeEndian.PutUint32(0),
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: intf,
|
||||
},
|
||||
|
||||
// We need to exclude the loopback interface as this changes the ebpf proxy port
|
||||
&expr.Meta{
|
||||
Key: notDir,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpNeq,
|
||||
Register: 1,
|
||||
Data: []byte{0, 0, 0, 0},
|
||||
},
|
||||
|
||||
// interface matching
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyIIFNAME,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: op,
|
||||
Register: 1,
|
||||
Data: ifname(r.wgIface.Name()),
|
||||
Data: lo,
|
||||
},
|
||||
}
|
||||
|
||||
exprs = append(exprs, sourceExp...)
|
||||
exprs = append(exprs, destExp...)
|
||||
|
||||
var markValue uint32 = nbnet.PreroutingFwmarkMasquerade
|
||||
if pair.Inverse {
|
||||
markValue = nbnet.PreroutingFwmarkMasqueradeReturn
|
||||
}
|
||||
|
||||
exprs = append(exprs,
|
||||
&expr.Immediate{
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(markValue),
|
||||
},
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyMARK,
|
||||
SourceRegister: true,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Counter{}, &expr.Masq{},
|
||||
)
|
||||
|
||||
ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
|
||||
ruleKey := firewall.GenKey(firewall.NatFormat, pair)
|
||||
|
||||
if _, exists := r.rules[ruleKey]; exists {
|
||||
if err := r.removeNatRule(pair); err != nil {
|
||||
return fmt.Errorf("remove prerouting rule: %w", err)
|
||||
return fmt.Errorf("remove routing rule: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNamePrerouting],
|
||||
Chain: r.chains[chainNameRoutingNat],
|
||||
Exprs: exprs,
|
||||
UserData: []byte(ruleKey),
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// addPostroutingRules adds the masquerade rules
|
||||
func (r *router) addPostroutingRules() error {
|
||||
// First masquerade rule for traffic coming in from WireGuard interface
|
||||
exprs := []expr.Any{
|
||||
// Match on the first fwmark
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyMARK,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkMasquerade),
|
||||
},
|
||||
|
||||
// We need to exclude the loopback interface as this changes the ebpf proxy port
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyOIFNAME,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpNeq,
|
||||
Register: 1,
|
||||
Data: ifname("lo"),
|
||||
},
|
||||
&expr.Counter{},
|
||||
&expr.Masq{},
|
||||
}
|
||||
|
||||
r.conn.AddRule(&nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameRoutingNat],
|
||||
Exprs: exprs,
|
||||
})
|
||||
|
||||
// Second masquerade rule for traffic going out through WireGuard interface
|
||||
exprs2 := []expr.Any{
|
||||
// Match on the second fwmark
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyMARK,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkMasqueradeReturn),
|
||||
},
|
||||
|
||||
// Match WireGuard interface
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyOIFNAME,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(r.wgIface.Name()),
|
||||
},
|
||||
&expr.Counter{},
|
||||
&expr.Masq{},
|
||||
}
|
||||
|
||||
r.conn.AddRule(&nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameRoutingNat],
|
||||
Exprs: exprs2,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -828,18 +723,18 @@ func (r *router) removeAcceptForwardRulesIptables(ipt *iptables.IPTables) error
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
// RemoveNatRule removes the prerouting mark rule
|
||||
// RemoveNatRule removes a nftables rule pair from nat chains
|
||||
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
return fmt.Errorf(refreshRulesMapError, err)
|
||||
}
|
||||
|
||||
if err := r.removeNatRule(pair); err != nil {
|
||||
return fmt.Errorf("remove prerouting rule: %w", err)
|
||||
return fmt.Errorf("remove nat rule: %w", err)
|
||||
}
|
||||
|
||||
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
|
||||
return fmt.Errorf("remove inverse prerouting rule: %w", err)
|
||||
return fmt.Errorf("remove inverse nat rule: %w", err)
|
||||
}
|
||||
|
||||
if err := r.removeLegacyRouteRule(pair); err != nil {
|
||||
@@ -854,20 +749,21 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// removeNatRule adds a nftables rule to the removal queue and deletes it from the rules map
|
||||
func (r *router) removeNatRule(pair firewall.RouterPair) error {
|
||||
ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
|
||||
ruleKey := firewall.GenKey(firewall.NatFormat, pair)
|
||||
|
||||
if rule, exists := r.rules[ruleKey]; exists {
|
||||
err := r.conn.DelRule(rule)
|
||||
if err != nil {
|
||||
return fmt.Errorf("remove prerouting rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
||||
return fmt.Errorf("remove nat rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
||||
}
|
||||
|
||||
log.Debugf("nftables: removed prerouting rule %s -> %s", pair.Source, pair.Destination)
|
||||
log.Debugf("nftables: removed nat rule %s -> %s", pair.Source, pair.Destination)
|
||||
|
||||
delete(r.rules, ruleKey)
|
||||
} else {
|
||||
log.Debugf("nftables: prerouting rule %s not found", ruleKey)
|
||||
log.Debugf("nftables: nat rule %s not found", ruleKey)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
"github.com/google/nftables"
|
||||
"github.com/google/nftables/binaryutil"
|
||||
"github.com/google/nftables/expr"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -33,87 +32,100 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
|
||||
t.Skip("nftables not supported on this OS")
|
||||
}
|
||||
|
||||
table, err := createWorkTable()
|
||||
require.NoError(t, err, "Failed to create work table")
|
||||
|
||||
defer deleteWorkTable()
|
||||
|
||||
for _, testCase := range test.InsertRuleTestCases {
|
||||
t.Run(testCase.Name, func(t *testing.T) {
|
||||
// need fw manager to init both acl mgr and router for all chains to be present
|
||||
manager, err := Create(ifaceMock)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.Reset(nil))
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, manager.Init(nil))
|
||||
manager, err := newRouter(table, ifaceMock)
|
||||
require.NoError(t, err, "failed to create router")
|
||||
require.NoError(t, manager.init(table))
|
||||
|
||||
nftablesTestingClient := &nftables.Conn{}
|
||||
|
||||
rtr := manager.router
|
||||
err = rtr.AddNatRule(testCase.InputPair)
|
||||
defer func(manager *router) {
|
||||
require.NoError(t, manager.Reset(), "failed to reset rules")
|
||||
}(manager)
|
||||
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
|
||||
err = manager.AddNatRule(testCase.InputPair)
|
||||
require.NoError(t, err, "pair should be inserted")
|
||||
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, rtr.RemoveNatRule(testCase.InputPair), "failed to remove rule")
|
||||
})
|
||||
defer func(manager *router, pair firewall.RouterPair) {
|
||||
require.NoError(t, manager.RemoveNatRule(pair), "failed to remove rule")
|
||||
}(manager, testCase.InputPair)
|
||||
|
||||
if testCase.InputPair.Masquerade {
|
||||
// Build expected expressions for connection tracking
|
||||
conntrackExprs := []expr.Any{
|
||||
&expr.Ct{
|
||||
Key: expr.CtKeySTATE,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Bitwise{
|
||||
SourceRegister: 1,
|
||||
DestRegister: 1,
|
||||
Len: 4,
|
||||
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitNEW),
|
||||
Xor: binaryutil.NativeEndian.PutUint32(0),
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpNeq,
|
||||
Register: 1,
|
||||
Data: []byte{0, 0, 0, 0},
|
||||
},
|
||||
}
|
||||
|
||||
// Build interface matching expression
|
||||
ifaceExprs := []expr.Any{
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyIIFNAME,
|
||||
Register: 1,
|
||||
},
|
||||
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
|
||||
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
|
||||
testingExpression := append(sourceExp, destExp...) //nolint:gocritic
|
||||
testingExpression = append(testingExpression,
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(ifaceMock.Name()),
|
||||
},
|
||||
}
|
||||
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpNeq,
|
||||
Register: 1,
|
||||
Data: ifname("lo"),
|
||||
},
|
||||
)
|
||||
|
||||
// Build CIDR matching expressions
|
||||
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
|
||||
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
|
||||
|
||||
// Combine all expressions in the correct order
|
||||
// nolint:gocritic
|
||||
testingExpression := append(conntrackExprs, ifaceExprs...)
|
||||
testingExpression = append(testingExpression, sourceExp...)
|
||||
testingExpression = append(testingExpression, destExp...)
|
||||
|
||||
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair)
|
||||
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
|
||||
found := 0
|
||||
for _, chain := range rtr.chains {
|
||||
if chain.Name == chainNamePrerouting {
|
||||
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
||||
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
|
||||
// Compare expressions up to the mark setting expressions
|
||||
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "prerouting nat rule elements should match")
|
||||
found = 1
|
||||
}
|
||||
for _, chain := range manager.chains {
|
||||
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
||||
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
|
||||
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "nat rule elements should match")
|
||||
found = 1
|
||||
}
|
||||
}
|
||||
}
|
||||
require.Equal(t, 1, found, "should find at least 1 rule in prerouting chain")
|
||||
require.Equal(t, 1, found, "should find at least 1 rule to test")
|
||||
}
|
||||
|
||||
if testCase.InputPair.Masquerade {
|
||||
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
|
||||
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
|
||||
testingExpression := append(sourceExp, destExp...) //nolint:gocritic
|
||||
testingExpression = append(testingExpression,
|
||||
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(ifaceMock.Name()),
|
||||
},
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpNeq,
|
||||
Register: 1,
|
||||
Data: ifname("lo"),
|
||||
},
|
||||
)
|
||||
|
||||
inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair))
|
||||
found := 0
|
||||
for _, chain := range manager.chains {
|
||||
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
||||
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 && string(rule.UserData) == inNatRuleKey {
|
||||
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "income nat rule elements should match")
|
||||
found = 1
|
||||
}
|
||||
}
|
||||
}
|
||||
require.Equal(t, 1, found, "should find at least 1 rule to test")
|
||||
}
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -123,66 +135,68 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
|
||||
t.Skip("nftables not supported on this OS")
|
||||
}
|
||||
|
||||
table, err := createWorkTable()
|
||||
require.NoError(t, err, "Failed to create work table")
|
||||
|
||||
defer deleteWorkTable()
|
||||
|
||||
for _, testCase := range test.RemoveRuleTestCases {
|
||||
t.Run(testCase.Name, func(t *testing.T) {
|
||||
manager, err := Create(ifaceMock)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.Reset(nil))
|
||||
manager, err := newRouter(table, ifaceMock)
|
||||
require.NoError(t, err, "failed to create router")
|
||||
require.NoError(t, manager.init(table))
|
||||
|
||||
nftablesTestingClient := &nftables.Conn{}
|
||||
|
||||
defer func(manager *router) {
|
||||
require.NoError(t, manager.Reset(), "failed to reset rules")
|
||||
}(manager)
|
||||
|
||||
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
|
||||
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
|
||||
|
||||
natExp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic
|
||||
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
|
||||
|
||||
insertedNat := nftablesTestingClient.InsertRule(&nftables.Rule{
|
||||
Table: manager.workTable,
|
||||
Chain: manager.chains[chainNameRoutingNat],
|
||||
Exprs: natExp,
|
||||
UserData: []byte(natRuleKey),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, manager.Init(nil))
|
||||
|
||||
rtr := manager.router
|
||||
sourceExp = generateCIDRMatcherExpressions(true, firewall.GetInversePair(testCase.InputPair).Source)
|
||||
destExp = generateCIDRMatcherExpressions(false, firewall.GetInversePair(testCase.InputPair).Destination)
|
||||
|
||||
// First add the NAT rule using the router's method
|
||||
err = rtr.AddNatRule(testCase.InputPair)
|
||||
require.NoError(t, err, "should add NAT rule")
|
||||
natExp = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic
|
||||
inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair))
|
||||
|
||||
// Verify the rule was added
|
||||
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair)
|
||||
found := false
|
||||
rules, err := rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNamePrerouting])
|
||||
require.NoError(t, err, "should list rules")
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
require.True(t, found, "NAT rule should exist before removal")
|
||||
insertedInNat := nftablesTestingClient.InsertRule(&nftables.Rule{
|
||||
Table: manager.workTable,
|
||||
Chain: manager.chains[chainNameRoutingNat],
|
||||
Exprs: natExp,
|
||||
UserData: []byte(inNatRuleKey),
|
||||
})
|
||||
|
||||
// Now remove the rule
|
||||
err = rtr.RemoveNatRule(testCase.InputPair)
|
||||
require.NoError(t, err, "shouldn't return error when removing rule")
|
||||
err = nftablesTestingClient.Flush()
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
|
||||
// Verify the rule was removed
|
||||
found = false
|
||||
rules, err = rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNamePrerouting])
|
||||
require.NoError(t, err, "should list rules after removal")
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
require.False(t, found, "NAT rule should not exist after removal")
|
||||
err = manager.Reset()
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
|
||||
// Verify the static postrouting rules still exist
|
||||
rules, err = rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameRoutingNat])
|
||||
require.NoError(t, err, "should list postrouting rules")
|
||||
foundCounter := false
|
||||
for _, rule := range rules {
|
||||
for _, e := range rule.Exprs {
|
||||
if _, ok := e.(*expr.Counter); ok {
|
||||
foundCounter = true
|
||||
break
|
||||
err = manager.RemoveNatRule(testCase.InputPair)
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
|
||||
for _, chain := range manager.chains {
|
||||
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
||||
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 {
|
||||
require.NotEqual(t, insertedNat.UserData, rule.UserData, "nat rule should not exist")
|
||||
require.NotEqual(t, insertedInNat.UserData, rule.UserData, "income nat rule should not exist")
|
||||
}
|
||||
}
|
||||
if foundCounter {
|
||||
break
|
||||
}
|
||||
}
|
||||
require.True(t, foundCounter, "static postrouting rule should remain")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
1
client/firewall/nftables/state.go
Normal file
1
client/firewall/nftables/state.go
Normal file
@@ -0,0 +1 @@
|
||||
package nftables
|
||||
@@ -239,7 +239,7 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
|
||||
// SetLegacyManagement doesn't need to be implemented for this manager
|
||||
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
||||
if m.nativeFirewall == nil {
|
||||
return nil
|
||||
return errRouteNotSupported
|
||||
}
|
||||
return m.nativeFirewall.SetLegacyManagement(isLegacy)
|
||||
}
|
||||
|
||||
@@ -1,12 +0,0 @@
|
||||
package bind
|
||||
|
||||
import (
|
||||
wireguard "golang.zx2c4.com/wireguard/conn"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
func init() {
|
||||
// ControlFns is not thread safe and should only be modified during init.
|
||||
*wireguard.ControlFns = append(*wireguard.ControlFns, nbnet.ControlProtectSocket)
|
||||
}
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
"github.com/pion/transport/v3"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/net/ipv4"
|
||||
"golang.org/x/net/ipv6"
|
||||
wgConn "golang.zx2c4.com/wireguard/conn"
|
||||
)
|
||||
|
||||
@@ -25,8 +24,8 @@ type receiverCreator struct {
|
||||
iceBind *ICEBind
|
||||
}
|
||||
|
||||
func (rc receiverCreator) CreateIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool, msgPool *sync.Pool) wgConn.ReceiveFunc {
|
||||
return rc.iceBind.createIPv4ReceiverFn(pc, conn, rxOffload, msgPool)
|
||||
func (rc receiverCreator) CreateIPv4ReceiverFn(msgPool *sync.Pool, pc *ipv4.PacketConn, conn *net.UDPConn) wgConn.ReceiveFunc {
|
||||
return rc.iceBind.createIPv4ReceiverFn(msgPool, pc, conn)
|
||||
}
|
||||
|
||||
// ICEBind is a bind implementation with two main features:
|
||||
@@ -155,7 +154,7 @@ func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool, msgsPool *sync.Pool) wgConn.ReceiveFunc {
|
||||
func (s *ICEBind) createIPv4ReceiverFn(ipv4MsgsPool *sync.Pool, pc *ipv4.PacketConn, conn *net.UDPConn) wgConn.ReceiveFunc {
|
||||
s.muUDPMux.Lock()
|
||||
defer s.muUDPMux.Unlock()
|
||||
|
||||
@@ -167,30 +166,16 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
|
||||
},
|
||||
)
|
||||
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
|
||||
msgs := getMessages(msgsPool)
|
||||
msgs := ipv4MsgsPool.Get().(*[]ipv4.Message)
|
||||
defer ipv4MsgsPool.Put(msgs)
|
||||
for i := range bufs {
|
||||
(*msgs)[i].Buffers[0] = bufs[i]
|
||||
(*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)]
|
||||
}
|
||||
defer putMessages(msgs, msgsPool)
|
||||
var numMsgs int
|
||||
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
|
||||
if rxOffload {
|
||||
readAt := len(*msgs) - (wgConn.IdealBatchSize / wgConn.UdpSegmentMaxDatagrams)
|
||||
//nolint
|
||||
numMsgs, err = pc.ReadBatch((*msgs)[readAt:], 0)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
numMsgs, err = wgConn.SplitCoalescedMessages(*msgs, readAt, wgConn.GetGSOSize)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
} else {
|
||||
numMsgs, err = pc.ReadBatch(*msgs, 0)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if runtime.GOOS == "linux" {
|
||||
numMsgs, err = pc.ReadBatch(*msgs, 0)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
} else {
|
||||
msg := &(*msgs)[0]
|
||||
@@ -206,12 +191,11 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
|
||||
// todo: handle err
|
||||
ok, _ := s.filterOutStunMessages(msg.Buffers, msg.N, msg.Addr)
|
||||
if ok {
|
||||
continue
|
||||
}
|
||||
sizes[i] = msg.N
|
||||
if sizes[i] == 0 {
|
||||
continue
|
||||
sizes[i] = 0
|
||||
} else {
|
||||
sizes[i] = msg.N
|
||||
}
|
||||
|
||||
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
|
||||
ep := &wgConn.StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
|
||||
wgConn.GetSrcFromControl(msg.OOB[:msg.NN], ep)
|
||||
@@ -289,15 +273,3 @@ func fakeAddress(peerAddress *net.UDPAddr) (*net.UDPAddr, error) {
|
||||
}
|
||||
return newAddr, nil
|
||||
}
|
||||
|
||||
func getMessages(msgsPool *sync.Pool) *[]ipv6.Message {
|
||||
return msgsPool.Get().(*[]ipv6.Message)
|
||||
}
|
||||
|
||||
func putMessages(msgs *[]ipv6.Message, msgsPool *sync.Pool) {
|
||||
for i := range *msgs {
|
||||
(*msgs)[i].OOB = (*msgs)[i].OOB[:0]
|
||||
(*msgs)[i] = ipv6.Message{Buffers: (*msgs)[i].Buffers, OOB: (*msgs)[i].OOB}
|
||||
}
|
||||
msgsPool.Put(msgs)
|
||||
}
|
||||
|
||||
@@ -162,13 +162,12 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
|
||||
params.Logger.Warn("UDPMuxDefault should not listening on unspecified address, use NewMultiUDPMuxFromPort instead")
|
||||
var networks []ice.NetworkType
|
||||
switch {
|
||||
case addr.IP.To4() != nil:
|
||||
networks = []ice.NetworkType{ice.NetworkTypeUDP4}
|
||||
|
||||
case addr.IP.To16() != nil:
|
||||
networks = []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6}
|
||||
|
||||
case addr.IP.To4() != nil:
|
||||
networks = []ice.NetworkType{ice.NetworkTypeUDP4}
|
||||
|
||||
default:
|
||||
params.Logger.Errorf("LocalAddr expected IPV4 or IPV6, got %T", params.UDPConn.LocalAddr())
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ package bind
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
@@ -95,10 +94,7 @@ func (p *ProxyBind) close() error {
|
||||
|
||||
p.Bind.RemoveEndpoint(p.wgAddr)
|
||||
|
||||
if rErr := p.remoteConn.Close(); rErr != nil && !errors.Is(rErr, net.ErrClosed) {
|
||||
return rErr
|
||||
}
|
||||
return nil
|
||||
return p.remoteConn.Close()
|
||||
}
|
||||
|
||||
func (p *ProxyBind) proxyToLocal(ctx context.Context) {
|
||||
@@ -108,8 +104,8 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) {
|
||||
}
|
||||
}()
|
||||
|
||||
buf := make([]byte, 1500)
|
||||
for {
|
||||
buf := make([]byte, 1500)
|
||||
n, err := p.remoteConn.Read(buf)
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
|
||||
@@ -77,7 +77,7 @@ func (e *ProxyWrapper) CloseConn() error {
|
||||
|
||||
e.cancel()
|
||||
|
||||
if err := e.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
||||
if err := e.remoteConn.Close(); err != nil {
|
||||
return fmt.Errorf("failed to close remote conn: %w", err)
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -116,7 +116,7 @@ func (p *WGUDPProxy) close() error {
|
||||
p.cancel()
|
||||
|
||||
var result *multierror.Error
|
||||
if err := p.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
||||
if err := p.remoteConn.Close(); err != nil {
|
||||
result = multierror.Append(result, fmt.Errorf("remote conn: %s", err))
|
||||
}
|
||||
|
||||
|
||||
@@ -164,7 +164,7 @@ func UpdateOrCreateConfig(input ConfigInput) (*Config, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = util.WriteJsonWithRestrictedPermission(context.Background(), input.ConfigPath, cfg)
|
||||
err = util.WriteJsonWithRestrictedPermission(input.ConfigPath, cfg)
|
||||
return cfg, err
|
||||
}
|
||||
|
||||
@@ -185,7 +185,7 @@ func CreateInMemoryConfig(input ConfigInput) (*Config, error) {
|
||||
|
||||
// WriteOutConfig write put the prepared config to the given path
|
||||
func WriteOutConfig(path string, config *Config) error {
|
||||
return util.WriteJson(context.Background(), path, config)
|
||||
return util.WriteJson(path, config)
|
||||
}
|
||||
|
||||
// createNewConfig creates a new config generating a new Wireguard key and saving to file
|
||||
@@ -215,7 +215,7 @@ func update(input ConfigInput) (*Config, error) {
|
||||
}
|
||||
|
||||
if updated {
|
||||
if err := util.WriteJson(context.Background(), input.ConfigPath, config); err != nil {
|
||||
if err := util.WriteJson(input.ConfigPath, config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -40,8 +40,6 @@ type ConnectClient struct {
|
||||
statusRecorder *peer.Status
|
||||
engine *Engine
|
||||
engineMutex sync.Mutex
|
||||
|
||||
persistNetworkMap bool
|
||||
}
|
||||
|
||||
func NewConnectClient(
|
||||
@@ -159,8 +157,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
|
||||
|
||||
engineCtx, cancel := context.WithCancel(c.ctx)
|
||||
defer func() {
|
||||
_, err := state.Status()
|
||||
c.statusRecorder.MarkManagementDisconnected(err)
|
||||
c.statusRecorder.MarkManagementDisconnected(state.err)
|
||||
c.statusRecorder.CleanLocalPeerState()
|
||||
cancel()
|
||||
}()
|
||||
@@ -210,8 +207,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
|
||||
|
||||
c.statusRecorder.MarkSignalDisconnected(nil)
|
||||
defer func() {
|
||||
_, err := state.Status()
|
||||
c.statusRecorder.MarkSignalDisconnected(err)
|
||||
c.statusRecorder.MarkSignalDisconnected(state.err)
|
||||
}()
|
||||
|
||||
// with the global Wiretrustee config in hand connect (just a connection, no stream yet) Signal
|
||||
@@ -234,7 +230,6 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
|
||||
|
||||
relayURLs, token := parseRelayInfo(loginResp)
|
||||
relayManager := relayClient.NewManager(engineCtx, relayURLs, myPrivateKey.PublicKey().String())
|
||||
c.statusRecorder.SetRelayMgr(relayManager)
|
||||
if len(relayURLs) > 0 {
|
||||
if token != nil {
|
||||
if err := relayManager.UpdateToken(token); err != nil {
|
||||
@@ -245,7 +240,9 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
|
||||
log.Infof("connecting to the Relay service(s): %s", strings.Join(relayURLs, ", "))
|
||||
if err = relayManager.Serve(); err != nil {
|
||||
log.Error(err)
|
||||
return wrapErr(err)
|
||||
}
|
||||
c.statusRecorder.SetRelayMgr(relayManager)
|
||||
}
|
||||
|
||||
peerConfig := loginResp.GetPeerConfig()
|
||||
@@ -260,7 +257,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
|
||||
|
||||
c.engineMutex.Lock()
|
||||
c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, probes, checks)
|
||||
c.engine.SetNetworkMapPersistence(c.persistNetworkMap)
|
||||
|
||||
c.engineMutex.Unlock()
|
||||
|
||||
if err := c.engine.Start(); err != nil {
|
||||
@@ -338,19 +335,6 @@ func (c *ConnectClient) Engine() *Engine {
|
||||
return e
|
||||
}
|
||||
|
||||
// Status returns the current client status
|
||||
func (c *ConnectClient) Status() StatusType {
|
||||
if c == nil {
|
||||
return StatusIdle
|
||||
}
|
||||
status, err := CtxGetState(c.ctx).Status()
|
||||
if err != nil {
|
||||
return StatusIdle
|
||||
}
|
||||
|
||||
return status
|
||||
}
|
||||
|
||||
func (c *ConnectClient) Stop() error {
|
||||
if c == nil {
|
||||
return nil
|
||||
@@ -377,22 +361,6 @@ func (c *ConnectClient) isContextCancelled() bool {
|
||||
}
|
||||
}
|
||||
|
||||
// SetNetworkMapPersistence enables or disables network map persistence.
|
||||
// When enabled, the last received network map will be stored and can be retrieved
|
||||
// through the Engine's getLatestNetworkMap method. When disabled, any stored
|
||||
// network map will be cleared. This functionality is primarily used for debugging
|
||||
// and should not be enabled during normal operation.
|
||||
func (c *ConnectClient) SetNetworkMapPersistence(enabled bool) {
|
||||
c.engineMutex.Lock()
|
||||
c.persistNetworkMap = enabled
|
||||
c.engineMutex.Unlock()
|
||||
|
||||
engine := c.Engine()
|
||||
if engine != nil {
|
||||
engine.SetNetworkMapPersistence(enabled)
|
||||
}
|
||||
}
|
||||
|
||||
// createEngineConfig converts configuration received from Management Service to EngineConfig
|
||||
func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig) (*EngineConfig, error) {
|
||||
nm := false
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/mitchellh/hashstructure/v2"
|
||||
@@ -322,12 +323,12 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
||||
log.Error(err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
// persist dns state right away
|
||||
if err := s.stateManager.PersistState(s.ctx); err != nil {
|
||||
log.Errorf("Failed to persist dns state: %v", err)
|
||||
}
|
||||
}()
|
||||
// persist dns state right away
|
||||
ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second)
|
||||
defer cancel()
|
||||
if err := s.stateManager.PersistState(ctx); err != nil {
|
||||
log.Errorf("Failed to persist dns state: %v", err)
|
||||
}
|
||||
|
||||
if s.searchDomainNotifier != nil {
|
||||
s.searchDomainNotifier.onNewSearchDomains(s.SearchDomains())
|
||||
@@ -532,11 +533,12 @@ func (s *DefaultServer) upstreamCallbacks(
|
||||
l.Errorf("Failed to apply nameserver deactivation on the host: %v", err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
if err := s.stateManager.PersistState(s.ctx); err != nil {
|
||||
l.Errorf("Failed to persist dns state: %v", err)
|
||||
}
|
||||
}()
|
||||
// persist dns state right away
|
||||
ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second)
|
||||
defer cancel()
|
||||
if err := s.stateManager.PersistState(ctx); err != nil {
|
||||
l.Errorf("Failed to persist dns state: %v", err)
|
||||
}
|
||||
|
||||
if runtime.GOOS == "android" && nsGroup.Primary && len(s.hostsDNSHolder.get()) > 0 {
|
||||
s.addHostRootZone()
|
||||
|
||||
@@ -782,7 +782,7 @@ func TestDNSPermanent_matchOnly(t *testing.T) {
|
||||
Port: 53,
|
||||
},
|
||||
},
|
||||
Domains: []string{"google.com"},
|
||||
Domains: []string{"customdomain.com"},
|
||||
Primary: false,
|
||||
},
|
||||
},
|
||||
@@ -804,7 +804,7 @@ func TestDNSPermanent_matchOnly(t *testing.T) {
|
||||
if ips[0] != zoneRecords[0].RData {
|
||||
t.Fatalf("invalid zone record: %v", err)
|
||||
}
|
||||
_, err = resolver.LookupHost(context.Background(), "google.com")
|
||||
_, err = resolver.LookupHost(context.Background(), "customdomain.com")
|
||||
if err != nil {
|
||||
t.Errorf("failed to resolve: %s", err)
|
||||
}
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"reflect"
|
||||
"runtime"
|
||||
"slices"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
@@ -21,7 +20,6 @@ import (
|
||||
"github.com/pion/stun/v2"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall"
|
||||
"github.com/netbirdio/netbird/client/firewall/manager"
|
||||
@@ -40,6 +38,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
|
||||
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
@@ -172,11 +171,7 @@ type Engine struct {
|
||||
|
||||
relayManager *relayClient.Manager
|
||||
stateManager *statemanager.Manager
|
||||
srWatcher *guard.SRWatcher
|
||||
|
||||
// Network map persistence
|
||||
persistNetworkMap bool
|
||||
latestNetworkMap *mgmProto.NetworkMap
|
||||
srWatcher *guard.SRWatcher
|
||||
}
|
||||
|
||||
// Peer is an instance of the Connection Peer
|
||||
@@ -302,7 +297,7 @@ func (e *Engine) Stop() error {
|
||||
if err := e.stateManager.Stop(ctx); err != nil {
|
||||
return fmt.Errorf("failed to stop state manager: %w", err)
|
||||
}
|
||||
if err := e.stateManager.PersistState(context.Background()); err != nil {
|
||||
if err := e.stateManager.PersistState(ctx); err != nil {
|
||||
log.Errorf("failed to persist state: %v", err)
|
||||
}
|
||||
|
||||
@@ -354,17 +349,8 @@ func (e *Engine) Start() error {
|
||||
}
|
||||
e.dnsServer = dnsServer
|
||||
|
||||
e.routeManager = routemanager.NewManager(
|
||||
e.ctx,
|
||||
e.config.WgPrivateKey.PublicKey().String(),
|
||||
e.config.DNSRouteInterval,
|
||||
e.wgInterface,
|
||||
e.statusRecorder,
|
||||
e.relayManager,
|
||||
initialRoutes,
|
||||
e.stateManager,
|
||||
)
|
||||
beforePeerHook, afterPeerHook, err := e.routeManager.Init()
|
||||
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.config.DNSRouteInterval, e.wgInterface, e.statusRecorder, e.relayManager, initialRoutes)
|
||||
beforePeerHook, afterPeerHook, err := e.routeManager.Init(e.stateManager)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to initialize route manager: %s", err)
|
||||
} else {
|
||||
@@ -552,7 +538,6 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
|
||||
relayMsg := wCfg.GetRelay()
|
||||
if relayMsg != nil {
|
||||
// when we receive token we expect valid address list too
|
||||
c := &auth.Token{
|
||||
Payload: relayMsg.GetTokenPayload(),
|
||||
Signature: relayMsg.GetTokenSignature(),
|
||||
@@ -561,16 +546,9 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
log.Errorf("failed to update relay token: %v", err)
|
||||
return fmt.Errorf("update relay token: %w", err)
|
||||
}
|
||||
|
||||
e.relayManager.UpdateServerURLs(relayMsg.Urls)
|
||||
|
||||
// Just in case the agent started with an MGM server where the relay was disabled but was later enabled.
|
||||
// We can ignore all errors because the guard will manage the reconnection retries.
|
||||
_ = e.relayManager.Serve()
|
||||
} else {
|
||||
e.relayManager.UpdateServerURLs(nil)
|
||||
}
|
||||
|
||||
// todo update relay address in the relay manager
|
||||
// todo update signal
|
||||
}
|
||||
|
||||
@@ -578,22 +556,13 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
return err
|
||||
}
|
||||
|
||||
nm := update.GetNetworkMap()
|
||||
if nm == nil {
|
||||
return nil
|
||||
if update.GetNetworkMap() != nil {
|
||||
// only apply new changes and ignore old ones
|
||||
err := e.updateNetworkMap(update.GetNetworkMap())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Store network map if persistence is enabled
|
||||
if e.persistNetworkMap {
|
||||
e.latestNetworkMap = nm
|
||||
log.Debugf("network map persisted with serial %d", nm.GetSerial())
|
||||
}
|
||||
|
||||
// only apply new changes and ignore old ones
|
||||
if err := e.updateNetworkMap(nm); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -672,10 +641,6 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
|
||||
}
|
||||
|
||||
func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
||||
if e.wgInterface == nil {
|
||||
return errors.New("wireguard interface is not initialized")
|
||||
}
|
||||
|
||||
if e.wgInterface.Address().String() != conf.Address {
|
||||
oldAddr := e.wgInterface.Address().String()
|
||||
log.Debugf("updating peer address from %s to %s", oldAddr, conf.Address)
|
||||
@@ -1514,59 +1479,8 @@ func (e *Engine) stopDNSServer() {
|
||||
e.statusRecorder.UpdateDNSStates(nsGroupStates)
|
||||
}
|
||||
|
||||
// SetNetworkMapPersistence enables or disables network map persistence
|
||||
func (e *Engine) SetNetworkMapPersistence(enabled bool) {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
|
||||
if enabled == e.persistNetworkMap {
|
||||
return
|
||||
}
|
||||
e.persistNetworkMap = enabled
|
||||
log.Debugf("Network map persistence is set to %t", enabled)
|
||||
|
||||
if !enabled {
|
||||
e.latestNetworkMap = nil
|
||||
}
|
||||
}
|
||||
|
||||
// GetLatestNetworkMap returns the stored network map if persistence is enabled
|
||||
func (e *Engine) GetLatestNetworkMap() (*mgmProto.NetworkMap, error) {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
|
||||
if !e.persistNetworkMap {
|
||||
return nil, errors.New("network map persistence is disabled")
|
||||
}
|
||||
|
||||
if e.latestNetworkMap == nil {
|
||||
//nolint:nilnil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Create a deep copy to avoid external modifications
|
||||
nm, ok := proto.Clone(e.latestNetworkMap).(*mgmProto.NetworkMap)
|
||||
if !ok {
|
||||
|
||||
return nil, fmt.Errorf("failed to clone network map")
|
||||
}
|
||||
|
||||
return nm, nil
|
||||
}
|
||||
|
||||
// isChecksEqual checks if two slices of checks are equal.
|
||||
func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool {
|
||||
for _, check := range checks {
|
||||
sort.Slice(check.Files, func(i, j int) bool {
|
||||
return check.Files[i] < check.Files[j]
|
||||
})
|
||||
}
|
||||
for _, oCheck := range oChecks {
|
||||
sort.Slice(oCheck.Files, func(i, j int) bool {
|
||||
return oCheck.Files[i] < oCheck.Files[j]
|
||||
})
|
||||
}
|
||||
|
||||
return slices.EqualFunc(checks, oChecks, func(checks, oChecks *mgmProto.Checks) bool {
|
||||
return slices.Equal(checks.Files, oChecks.Files)
|
||||
})
|
||||
|
||||
@@ -245,15 +245,12 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
||||
nil)
|
||||
|
||||
wgIface := &iface.MockWGIface{
|
||||
NameFunc: func() string { return "utun102" },
|
||||
RemovePeerFunc: func(peerKey string) error {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
engine.wgInterface = wgIface
|
||||
engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), time.Minute, engine.wgInterface, engine.statusRecorder, relayMgr, nil, nil)
|
||||
_, _, err = engine.routeManager.Init()
|
||||
require.NoError(t, err)
|
||||
engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), time.Minute, engine.wgInterface, engine.statusRecorder, relayMgr, nil)
|
||||
engine.dnsServer = &dns.MockServer{
|
||||
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
||||
}
|
||||
@@ -1009,99 +1006,6 @@ func Test_ParseNATExternalIPMappings(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func Test_CheckFilesEqual(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
inputChecks1 []*mgmtProto.Checks
|
||||
inputChecks2 []*mgmtProto.Checks
|
||||
expectedBool bool
|
||||
}{
|
||||
{
|
||||
name: "Equal Files In Equal Order Should Return True",
|
||||
inputChecks1: []*mgmtProto.Checks{
|
||||
{
|
||||
Files: []string{
|
||||
"testfile1",
|
||||
"testfile2",
|
||||
},
|
||||
},
|
||||
},
|
||||
inputChecks2: []*mgmtProto.Checks{
|
||||
{
|
||||
Files: []string{
|
||||
"testfile1",
|
||||
"testfile2",
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedBool: true,
|
||||
},
|
||||
{
|
||||
name: "Equal Files In Reverse Order Should Return True",
|
||||
inputChecks1: []*mgmtProto.Checks{
|
||||
{
|
||||
Files: []string{
|
||||
"testfile1",
|
||||
"testfile2",
|
||||
},
|
||||
},
|
||||
},
|
||||
inputChecks2: []*mgmtProto.Checks{
|
||||
{
|
||||
Files: []string{
|
||||
"testfile2",
|
||||
"testfile1",
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedBool: true,
|
||||
},
|
||||
{
|
||||
name: "Unequal Files Should Return False",
|
||||
inputChecks1: []*mgmtProto.Checks{
|
||||
{
|
||||
Files: []string{
|
||||
"testfile1",
|
||||
"testfile2",
|
||||
},
|
||||
},
|
||||
},
|
||||
inputChecks2: []*mgmtProto.Checks{
|
||||
{
|
||||
Files: []string{
|
||||
"testfile1",
|
||||
"testfile3",
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedBool: false,
|
||||
},
|
||||
{
|
||||
name: "Compared With Empty Should Return False",
|
||||
inputChecks1: []*mgmtProto.Checks{
|
||||
{
|
||||
Files: []string{
|
||||
"testfile1",
|
||||
"testfile2",
|
||||
},
|
||||
},
|
||||
},
|
||||
inputChecks2: []*mgmtProto.Checks{
|
||||
{
|
||||
Files: []string{},
|
||||
},
|
||||
},
|
||||
expectedBool: false,
|
||||
},
|
||||
}
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
result := isChecksEqual(testCase.inputChecks1, testCase.inputChecks2)
|
||||
assert.Equal(t, testCase.expectedBool, result, "result should match expected bool")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey string, i int, mgmtAddr string, signalAddr string) (*Engine, error) {
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
|
||||
@@ -83,6 +83,7 @@ type Conn struct {
|
||||
signaler *Signaler
|
||||
relayManager *relayClient.Manager
|
||||
allowedIP net.IP
|
||||
allowedNet string
|
||||
handshaker *Handshaker
|
||||
|
||||
onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)
|
||||
@@ -110,7 +111,7 @@ type Conn struct {
|
||||
// NewConn creates a new not opened Conn to the remote peer.
|
||||
// To establish a connection run Conn.Open
|
||||
func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Status, signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, relayManager *relayClient.Manager, srWatcher *guard.SRWatcher) (*Conn, error) {
|
||||
allowedIP, _, err := net.ParseCIDR(config.WgConfig.AllowedIps)
|
||||
allowedIP, allowedNet, err := net.ParseCIDR(config.WgConfig.AllowedIps)
|
||||
if err != nil {
|
||||
log.Errorf("failed to parse allowedIPS: %v", err)
|
||||
return nil, err
|
||||
@@ -128,6 +129,7 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu
|
||||
signaler: signaler,
|
||||
relayManager: relayManager,
|
||||
allowedIP: allowedIP,
|
||||
allowedNet: allowedNet.String(),
|
||||
statusRelay: NewAtomicConnStatus(),
|
||||
statusICE: NewAtomicConnStatus(),
|
||||
}
|
||||
@@ -307,11 +309,6 @@ func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICECon
|
||||
return
|
||||
}
|
||||
|
||||
if remoteConnNil(conn.log, iceConnInfo.RemoteConn) {
|
||||
conn.log.Errorf("remote ICE connection is nil")
|
||||
return
|
||||
}
|
||||
|
||||
conn.log.Debugf("ICE connection is ready")
|
||||
|
||||
if conn.currentConnPriority > priority {
|
||||
@@ -336,9 +333,12 @@ func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICECon
|
||||
ep = wgProxy.EndpointAddr()
|
||||
conn.wgProxyICE = wgProxy
|
||||
} else {
|
||||
conn.log.Infof("direct iceConnInfo: %v", iceConnInfo.RemoteConn)
|
||||
agentCheck(conn.log, iceConnInfo.Agent)
|
||||
nilCheck(conn.log, iceConnInfo.RemoteConn)
|
||||
directEp, err := net.ResolveUDPAddr("udp", iceConnInfo.RemoteConn.RemoteAddr().String())
|
||||
if err != nil {
|
||||
log.Errorf("failed to resolveUDPaddr")
|
||||
conn.log.Errorf("failed to resolveUDPaddr")
|
||||
conn.handleConfigurationFailure(err, nil)
|
||||
return
|
||||
}
|
||||
@@ -440,7 +440,7 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
|
||||
|
||||
if conn.iceP2PIsActive() {
|
||||
conn.log.Debugf("do not switch to relay because current priority is: %v", conn.currentConnPriority)
|
||||
conn.setRelayedProxy(wgProxy)
|
||||
conn.wgProxyRelay = wgProxy
|
||||
conn.statusRelay.Set(StatusConnected)
|
||||
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
|
||||
return
|
||||
@@ -463,7 +463,7 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
|
||||
wgConfigWorkaround()
|
||||
conn.currentConnPriority = connPriorityRelay
|
||||
conn.statusRelay.Set(StatusConnected)
|
||||
conn.setRelayedProxy(wgProxy)
|
||||
conn.wgProxyRelay = wgProxy
|
||||
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
|
||||
conn.log.Infof("start to communicate with peer via relay")
|
||||
conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr)
|
||||
@@ -592,7 +592,7 @@ func (conn *Conn) doOnConnected(remoteRosenpassPubKey []byte, remoteRosenpassAdd
|
||||
}
|
||||
|
||||
if conn.onConnected != nil {
|
||||
conn.onConnected(conn.config.Key, remoteRosenpassPubKey, conn.allowedIP.String(), remoteRosenpassAddr)
|
||||
conn.onConnected(conn.config.Key, remoteRosenpassPubKey, conn.allowedNet, remoteRosenpassAddr)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -734,15 +734,6 @@ func (conn *Conn) logTraceConnState() {
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *Conn) setRelayedProxy(proxy wgproxy.Proxy) {
|
||||
if conn.wgProxyRelay != nil {
|
||||
if err := conn.wgProxyRelay.CloseConn(); err != nil {
|
||||
conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err)
|
||||
}
|
||||
}
|
||||
conn.wgProxyRelay = proxy
|
||||
}
|
||||
|
||||
func isController(config ConnConfig) bool {
|
||||
return config.LocalKey > config.Key
|
||||
}
|
||||
|
||||
@@ -2,20 +2,53 @@ package peer
|
||||
|
||||
import (
|
||||
"net"
|
||||
"reflect"
|
||||
|
||||
"github.com/pion/ice/v3"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func remoteConnNil(log *log.Entry, conn net.Conn) bool {
|
||||
func nilCheck(log *log.Entry, conn net.Conn) {
|
||||
if conn == nil {
|
||||
log.Errorf("ice conn is nil")
|
||||
return true
|
||||
log.Infof("conn is nil")
|
||||
return
|
||||
}
|
||||
|
||||
if conn.RemoteAddr() == nil {
|
||||
log.Errorf("ICE remote address is nil")
|
||||
return true
|
||||
log.Infof("conn.RemoteAddr() is nil")
|
||||
return
|
||||
}
|
||||
|
||||
return false
|
||||
if reflect.ValueOf(conn.RemoteAddr()).IsNil() {
|
||||
log.Infof("value of conn.RemoteAddr() is nil")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func agentCheck(log *log.Entry, agent *ice.Agent) {
|
||||
if agent == nil {
|
||||
log.Errorf("agent is nil")
|
||||
return
|
||||
}
|
||||
|
||||
pair, err := agent.GetSelectedCandidatePair()
|
||||
if err != nil {
|
||||
log.Errorf("error getting selected candidate pair: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if pair == nil {
|
||||
log.Errorf("pair is nil")
|
||||
return
|
||||
}
|
||||
|
||||
if pair.Remote == nil {
|
||||
log.Errorf("pair.Remote is nil")
|
||||
return
|
||||
}
|
||||
|
||||
if pair.Remote.Address() == "" {
|
||||
log.Errorf("address is empty")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -67,7 +67,7 @@ func (s *State) DeleteRoute(network string) {
|
||||
func (s *State) GetRoutes() map[string]struct{} {
|
||||
s.Mux.RLock()
|
||||
defer s.Mux.RUnlock()
|
||||
return maps.Clone(s.routes)
|
||||
return s.routes
|
||||
}
|
||||
|
||||
// LocalPeerState contains the latest state of the local peer
|
||||
@@ -237,6 +237,10 @@ func (d *Status) UpdatePeerState(receivedState State) error {
|
||||
peerState.IP = receivedState.IP
|
||||
}
|
||||
|
||||
if receivedState.GetRoutes() != nil {
|
||||
peerState.SetRoutes(receivedState.GetRoutes())
|
||||
}
|
||||
|
||||
skipNotification := shouldSkipNotify(receivedState.ConnStatus, peerState)
|
||||
|
||||
if receivedState.ConnStatus != peerState.ConnStatus {
|
||||
@@ -257,40 +261,12 @@ func (d *Status) UpdatePeerState(receivedState State) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
d.notifyPeerListChanged()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Status) AddPeerStateRoute(peer string, route string) error {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
peerState, ok := d.peers[peer]
|
||||
if !ok {
|
||||
return errors.New("peer doesn't exist")
|
||||
ch, found := d.changeNotify[receivedState.PubKey]
|
||||
if found && ch != nil {
|
||||
close(ch)
|
||||
d.changeNotify[receivedState.PubKey] = nil
|
||||
}
|
||||
|
||||
peerState.AddRoute(route)
|
||||
d.peers[peer] = peerState
|
||||
|
||||
// todo: consider to make sense of this notification or not
|
||||
d.notifyPeerListChanged()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Status) RemovePeerStateRoute(peer string, route string) error {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
peerState, ok := d.peers[peer]
|
||||
if !ok {
|
||||
return errors.New("peer doesn't exist")
|
||||
}
|
||||
|
||||
peerState.DeleteRoute(route)
|
||||
d.peers[peer] = peerState
|
||||
|
||||
// todo: consider to make sense of this notification or not
|
||||
d.notifyPeerListChanged()
|
||||
return nil
|
||||
}
|
||||
@@ -325,7 +301,12 @@ func (d *Status) UpdatePeerICEState(receivedState State) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
d.notifyPeerStateChangeListeners(receivedState.PubKey)
|
||||
ch, found := d.changeNotify[receivedState.PubKey]
|
||||
if found && ch != nil {
|
||||
close(ch)
|
||||
d.changeNotify[receivedState.PubKey] = nil
|
||||
}
|
||||
|
||||
d.notifyPeerListChanged()
|
||||
return nil
|
||||
}
|
||||
@@ -353,7 +334,12 @@ func (d *Status) UpdatePeerRelayedState(receivedState State) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
d.notifyPeerStateChangeListeners(receivedState.PubKey)
|
||||
ch, found := d.changeNotify[receivedState.PubKey]
|
||||
if found && ch != nil {
|
||||
close(ch)
|
||||
d.changeNotify[receivedState.PubKey] = nil
|
||||
}
|
||||
|
||||
d.notifyPeerListChanged()
|
||||
return nil
|
||||
}
|
||||
@@ -380,7 +366,12 @@ func (d *Status) UpdatePeerRelayedStateToDisconnected(receivedState State) error
|
||||
return nil
|
||||
}
|
||||
|
||||
d.notifyPeerStateChangeListeners(receivedState.PubKey)
|
||||
ch, found := d.changeNotify[receivedState.PubKey]
|
||||
if found && ch != nil {
|
||||
close(ch)
|
||||
d.changeNotify[receivedState.PubKey] = nil
|
||||
}
|
||||
|
||||
d.notifyPeerListChanged()
|
||||
return nil
|
||||
}
|
||||
@@ -410,7 +401,12 @@ func (d *Status) UpdatePeerICEStateToDisconnected(receivedState State) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
d.notifyPeerStateChangeListeners(receivedState.PubKey)
|
||||
ch, found := d.changeNotify[receivedState.PubKey]
|
||||
if found && ch != nil {
|
||||
close(ch)
|
||||
d.changeNotify[receivedState.PubKey] = nil
|
||||
}
|
||||
|
||||
d.notifyPeerListChanged()
|
||||
return nil
|
||||
}
|
||||
@@ -481,14 +477,11 @@ func (d *Status) FinishPeerListModifications() {
|
||||
func (d *Status) GetPeerStateChangeNotifier(peer string) <-chan struct{} {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
ch, found := d.changeNotify[peer]
|
||||
if found {
|
||||
return ch
|
||||
if !found || ch == nil {
|
||||
ch = make(chan struct{})
|
||||
d.changeNotify[peer] = ch
|
||||
}
|
||||
|
||||
ch = make(chan struct{})
|
||||
d.changeNotify[peer] = ch
|
||||
return ch
|
||||
}
|
||||
|
||||
@@ -676,23 +669,25 @@ func (d *Status) GetRelayStates() []relay.ProbeResult {
|
||||
// extend the list of stun, turn servers with relay address
|
||||
relayStates := slices.Clone(d.relayStates)
|
||||
|
||||
var relayState relay.ProbeResult
|
||||
|
||||
// if the server connection is not established then we will use the general address
|
||||
// in case of connection we will use the instance specific address
|
||||
instanceAddr, err := d.relayMgr.RelayInstanceAddress()
|
||||
if err != nil {
|
||||
// TODO add their status
|
||||
for _, r := range d.relayMgr.ServerURLs() {
|
||||
relayStates = append(relayStates, relay.ProbeResult{
|
||||
URI: r,
|
||||
Err: err,
|
||||
})
|
||||
if errors.Is(err, relayClient.ErrRelayClientNotConnected) {
|
||||
for _, r := range d.relayMgr.ServerURLs() {
|
||||
relayStates = append(relayStates, relay.ProbeResult{
|
||||
URI: r,
|
||||
})
|
||||
}
|
||||
return relayStates
|
||||
}
|
||||
return relayStates
|
||||
relayState.Err = err
|
||||
}
|
||||
|
||||
relayState := relay.ProbeResult{
|
||||
URI: instanceAddr,
|
||||
}
|
||||
relayState.URI = instanceAddr
|
||||
return append(relayStates, relayState)
|
||||
}
|
||||
|
||||
@@ -760,17 +755,6 @@ func (d *Status) onConnectionChanged() {
|
||||
d.notifier.updateServerStates(d.managementState, d.signalState)
|
||||
}
|
||||
|
||||
// notifyPeerStateChangeListeners notifies route manager about the change in peer state
|
||||
func (d *Status) notifyPeerStateChangeListeners(peerID string) {
|
||||
ch, found := d.changeNotify[peerID]
|
||||
if !found {
|
||||
return
|
||||
}
|
||||
|
||||
close(ch)
|
||||
delete(d.changeNotify, peerID)
|
||||
}
|
||||
|
||||
func (d *Status) notifyPeerListChanged() {
|
||||
d.notifier.peerListChanged(d.numOfPeers())
|
||||
}
|
||||
|
||||
@@ -93,7 +93,7 @@ func TestGetPeerStateChangeNotifierLogic(t *testing.T) {
|
||||
|
||||
peerState.IP = ip
|
||||
|
||||
err := status.UpdatePeerRelayedStateToDisconnected(peerState)
|
||||
err := status.UpdatePeerState(peerState)
|
||||
assert.NoError(t, err, "shouldn't return error")
|
||||
|
||||
select {
|
||||
|
||||
@@ -29,6 +29,7 @@ type ICEConnInfo struct {
|
||||
LocalIceCandidateEndpoint string
|
||||
Relayed bool
|
||||
RelayedOnLocal bool
|
||||
Agent *ice.Agent
|
||||
}
|
||||
|
||||
type WorkerICECallbacks struct {
|
||||
@@ -46,6 +47,8 @@ type WorkerICE struct {
|
||||
hasRelayOnLocally bool
|
||||
conn WorkerICECallbacks
|
||||
|
||||
selectedPriority ConnPriority
|
||||
|
||||
agent *ice.Agent
|
||||
muxAgent sync.Mutex
|
||||
|
||||
@@ -55,9 +58,6 @@ type WorkerICE struct {
|
||||
|
||||
localUfrag string
|
||||
localPwd string
|
||||
|
||||
// we record the last known state of the ICE agent to avoid duplicate on disconnected events
|
||||
lastKnownState ice.ConnectionState
|
||||
}
|
||||
|
||||
func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, signaler *Signaler, ifaceDiscover stdnet.ExternalIFaceDiscover, statusRecorder *Status, hasRelayOnLocally bool, callBacks WorkerICECallbacks) (*WorkerICE, error) {
|
||||
@@ -93,8 +93,10 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
||||
|
||||
var preferredCandidateTypes []ice.CandidateType
|
||||
if w.hasRelayOnLocally && remoteOfferAnswer.RelaySrvAddress != "" {
|
||||
w.selectedPriority = connPriorityICEP2P
|
||||
preferredCandidateTypes = icemaker.CandidateTypesP2P()
|
||||
} else {
|
||||
w.selectedPriority = connPriorityICETurn
|
||||
preferredCandidateTypes = icemaker.CandidateTypes()
|
||||
}
|
||||
|
||||
@@ -125,6 +127,9 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
||||
w.log.Debugf("failed to dial the remote peer: %s", err)
|
||||
return
|
||||
}
|
||||
w.log.Infof("check remoteConn: %v", remoteConn)
|
||||
w.log.Infof("check remoteConn.RemoteAddr: %v", remoteConn.RemoteAddr())
|
||||
nilCheck(w.log, remoteConn)
|
||||
w.log.Debugf("agent dial succeeded")
|
||||
|
||||
pair, err := w.agent.GetSelectedCandidatePair()
|
||||
@@ -153,9 +158,10 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
||||
RemoteIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Remote.Address(), pair.Remote.Port()),
|
||||
Relayed: isRelayed(pair),
|
||||
RelayedOnLocal: isRelayCandidate(pair.Local),
|
||||
Agent: agent,
|
||||
}
|
||||
w.log.Debugf("on ICE conn read to use ready")
|
||||
go w.conn.OnConnReady(selectedPriority(pair), ci)
|
||||
go w.conn.OnConnReady(w.selectedPriority, ci)
|
||||
}
|
||||
|
||||
// OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer.
|
||||
@@ -193,7 +199,8 @@ func (w *WorkerICE) Close() {
|
||||
return
|
||||
}
|
||||
|
||||
if err := w.agent.Close(); err != nil {
|
||||
err := w.agent.Close()
|
||||
if err != nil {
|
||||
w.log.Warnf("failed to close ICE agent: %s", err)
|
||||
}
|
||||
}
|
||||
@@ -213,18 +220,15 @@ func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, candidates []i
|
||||
|
||||
err = agent.OnConnectionStateChange(func(state ice.ConnectionState) {
|
||||
w.log.Debugf("ICE ConnectionState has changed to %s", state.String())
|
||||
switch state {
|
||||
case ice.ConnectionStateConnected:
|
||||
w.lastKnownState = ice.ConnectionStateConnected
|
||||
return
|
||||
case ice.ConnectionStateFailed, ice.ConnectionStateDisconnected:
|
||||
if w.lastKnownState != ice.ConnectionStateDisconnected {
|
||||
w.lastKnownState = ice.ConnectionStateDisconnected
|
||||
w.conn.OnStatusChanged(StatusDisconnected)
|
||||
}
|
||||
w.closeAgent(agentCancel)
|
||||
default:
|
||||
return
|
||||
if state == ice.ConnectionStateFailed || state == ice.ConnectionStateDisconnected {
|
||||
w.conn.OnStatusChanged(StatusDisconnected)
|
||||
|
||||
w.muxAgent.Lock()
|
||||
agentCancel()
|
||||
_ = agent.Close()
|
||||
w.agent = nil
|
||||
|
||||
w.muxAgent.Unlock()
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
@@ -250,17 +254,6 @@ func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, candidates []i
|
||||
return agent, nil
|
||||
}
|
||||
|
||||
func (w *WorkerICE) closeAgent(cancel context.CancelFunc) {
|
||||
w.muxAgent.Lock()
|
||||
defer w.muxAgent.Unlock()
|
||||
|
||||
cancel()
|
||||
if err := w.agent.Close(); err != nil {
|
||||
w.log.Warnf("failed to close ICE agent: %s", err)
|
||||
}
|
||||
w.agent = nil
|
||||
}
|
||||
|
||||
func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) {
|
||||
// wait local endpoint configuration
|
||||
time.Sleep(time.Second)
|
||||
@@ -334,8 +327,10 @@ func (w *WorkerICE) shouldSendExtraSrflxCandidate(candidate ice.Candidate) bool
|
||||
func (w *WorkerICE) turnAgentDial(ctx context.Context, remoteOfferAnswer *OfferAnswer) (*ice.Conn, error) {
|
||||
isControlling := w.config.LocalKey > w.config.Key
|
||||
if isControlling {
|
||||
w.log.Infof("dialing remote peer %s as controlling", w.config.Key)
|
||||
return w.agent.Dial(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
|
||||
} else {
|
||||
w.log.Infof("dialing remote peer %s as controlled", w.config.Key)
|
||||
return w.agent.Accept(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
|
||||
}
|
||||
}
|
||||
@@ -390,11 +385,3 @@ func isRelayed(pair *ice.CandidatePair) bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func selectedPriority(pair *ice.CandidatePair) ConnPriority {
|
||||
if isRelayed(pair) {
|
||||
return connPriorityICETurn
|
||||
} else {
|
||||
return connPriorityICEP2P
|
||||
}
|
||||
}
|
||||
|
||||
@@ -122,20 +122,13 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]
|
||||
tempScore = float64(metricDiff) * 10
|
||||
}
|
||||
|
||||
// in some temporal cases, latency can be 0, so we set it to 999ms to not block but try to avoid this route
|
||||
latency := 999 * time.Millisecond
|
||||
// in some temporal cases, latency can be 0, so we set it to 1s to not block but try to avoid this route
|
||||
latency := time.Second
|
||||
if peerStatus.latency != 0 {
|
||||
latency = peerStatus.latency
|
||||
} else {
|
||||
log.Tracef("peer %s has 0 latency, range %s", r.Peer, c.handler)
|
||||
log.Warnf("peer %s has 0 latency", r.Peer)
|
||||
}
|
||||
|
||||
// avoid negative tempScore on the higher latency calculation
|
||||
if latency > 1*time.Second {
|
||||
latency = 999 * time.Millisecond
|
||||
}
|
||||
|
||||
// higher latency is worse score
|
||||
tempScore += 1 - latency.Seconds()
|
||||
|
||||
if !peerStatus.relayed {
|
||||
@@ -157,8 +150,6 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("chosen route: %s, chosen score: %f, current route: %s, current score: %f", chosen, chosenScore, currID, currScore)
|
||||
|
||||
switch {
|
||||
case chosen == "":
|
||||
var peers []string
|
||||
@@ -204,20 +195,15 @@ func (c *clientNetwork) watchPeerStatusChanges(ctx context.Context, peerKey stri
|
||||
func (c *clientNetwork) startPeersStatusChangeWatcher() {
|
||||
for _, r := range c.routes {
|
||||
_, found := c.routePeersNotifiers[r.Peer]
|
||||
if found {
|
||||
continue
|
||||
if !found {
|
||||
c.routePeersNotifiers[r.Peer] = make(chan struct{})
|
||||
go c.watchPeerStatusChanges(c.ctx, r.Peer, c.peerStateUpdate, c.routePeersNotifiers[r.Peer])
|
||||
}
|
||||
|
||||
closerChan := make(chan struct{})
|
||||
c.routePeersNotifiers[r.Peer] = closerChan
|
||||
go c.watchPeerStatusChanges(c.ctx, r.Peer, c.peerStateUpdate, closerChan)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *clientNetwork) removeRouteFromWireGuardPeer() error {
|
||||
if err := c.statusRecorder.RemovePeerStateRoute(c.currentChosen.Peer, c.handler.String()); err != nil {
|
||||
log.Warnf("Failed to update peer state: %v", err)
|
||||
}
|
||||
func (c *clientNetwork) removeRouteFromWireguardPeer() error {
|
||||
c.removeStateRoute()
|
||||
|
||||
if err := c.handler.RemoveAllowedIPs(); err != nil {
|
||||
return fmt.Errorf("remove allowed IPs: %w", err)
|
||||
@@ -232,7 +218,7 @@ func (c *clientNetwork) removeRouteFromPeerAndSystem() error {
|
||||
|
||||
var merr *multierror.Error
|
||||
|
||||
if err := c.removeRouteFromWireGuardPeer(); err != nil {
|
||||
if err := c.removeRouteFromWireguardPeer(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove allowed IPs for peer %s: %w", c.currentChosen.Peer, err))
|
||||
}
|
||||
if err := c.handler.RemoveRoute(); err != nil {
|
||||
@@ -271,7 +257,7 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
|
||||
}
|
||||
} else {
|
||||
// Otherwise, remove the allowed IPs from the previous peer first
|
||||
if err := c.removeRouteFromWireGuardPeer(); err != nil {
|
||||
if err := c.removeRouteFromWireguardPeer(); err != nil {
|
||||
return fmt.Errorf("remove allowed IPs for peer %s: %w", c.currentChosen.Peer, err)
|
||||
}
|
||||
}
|
||||
@@ -282,13 +268,37 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
|
||||
return fmt.Errorf("add allowed IPs for peer %s: %w", c.currentChosen.Peer, err)
|
||||
}
|
||||
|
||||
err := c.statusRecorder.AddPeerStateRoute(c.currentChosen.Peer, c.handler.String())
|
||||
if err != nil {
|
||||
return fmt.Errorf("add peer state route: %w", err)
|
||||
}
|
||||
c.addStateRoute()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *clientNetwork) addStateRoute() {
|
||||
state, err := c.statusRecorder.GetPeer(c.currentChosen.Peer)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to get peer state: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
state.AddRoute(c.handler.String())
|
||||
if err := c.statusRecorder.UpdatePeerState(state); err != nil {
|
||||
log.Warnf("Failed to update peer state: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *clientNetwork) removeStateRoute() {
|
||||
state, err := c.statusRecorder.GetPeer(c.currentChosen.Peer)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to get peer state: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
state.DeleteRoute(c.handler.String())
|
||||
if err := c.statusRecorder.UpdatePeerState(state); err != nil {
|
||||
log.Warnf("Failed to update peer state: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *clientNetwork) sendUpdateToClientNetworkWatcher(update routesUpdate) {
|
||||
go func() {
|
||||
c.routeUpdate <- update
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -228,64 +227,6 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
||||
currentRoute: "route1",
|
||||
expectedRouteID: "route1",
|
||||
},
|
||||
{
|
||||
name: "relayed routes with latency 0 should maintain previous choice",
|
||||
statuses: map[route.ID]routerPeerStatus{
|
||||
"route1": {
|
||||
connected: true,
|
||||
relayed: true,
|
||||
latency: 0 * time.Millisecond,
|
||||
},
|
||||
"route2": {
|
||||
connected: true,
|
||||
relayed: true,
|
||||
latency: 0 * time.Millisecond,
|
||||
},
|
||||
},
|
||||
existingRoutes: map[route.ID]*route.Route{
|
||||
"route1": {
|
||||
ID: "route1",
|
||||
Metric: route.MaxMetric,
|
||||
Peer: "peer1",
|
||||
},
|
||||
"route2": {
|
||||
ID: "route2",
|
||||
Metric: route.MaxMetric,
|
||||
Peer: "peer2",
|
||||
},
|
||||
},
|
||||
currentRoute: "route1",
|
||||
expectedRouteID: "route1",
|
||||
},
|
||||
{
|
||||
name: "p2p routes with latency 0 should maintain previous choice",
|
||||
statuses: map[route.ID]routerPeerStatus{
|
||||
"route1": {
|
||||
connected: true,
|
||||
relayed: false,
|
||||
latency: 0 * time.Millisecond,
|
||||
},
|
||||
"route2": {
|
||||
connected: true,
|
||||
relayed: false,
|
||||
latency: 0 * time.Millisecond,
|
||||
},
|
||||
},
|
||||
existingRoutes: map[route.ID]*route.Route{
|
||||
"route1": {
|
||||
ID: "route1",
|
||||
Metric: route.MaxMetric,
|
||||
Peer: "peer1",
|
||||
},
|
||||
"route2": {
|
||||
ID: "route2",
|
||||
Metric: route.MaxMetric,
|
||||
Peer: "peer2",
|
||||
},
|
||||
},
|
||||
currentRoute: "route1",
|
||||
expectedRouteID: "route1",
|
||||
},
|
||||
{
|
||||
name: "current route with bad score should be changed to route with better score",
|
||||
statuses: map[route.ID]routerPeerStatus{
|
||||
@@ -346,45 +287,6 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
// fill the test data with random routes
|
||||
for _, tc := range testCases {
|
||||
for i := 0; i < 50; i++ {
|
||||
dummyRoute := &route.Route{
|
||||
ID: route.ID(fmt.Sprintf("dummy_p1_%d", i)),
|
||||
Metric: route.MinMetric,
|
||||
Peer: fmt.Sprintf("dummy_p1_%d", i),
|
||||
}
|
||||
tc.existingRoutes[dummyRoute.ID] = dummyRoute
|
||||
}
|
||||
for i := 0; i < 50; i++ {
|
||||
dummyRoute := &route.Route{
|
||||
ID: route.ID(fmt.Sprintf("dummy_p2_%d", i)),
|
||||
Metric: route.MinMetric,
|
||||
Peer: fmt.Sprintf("dummy_p1_%d", i),
|
||||
}
|
||||
tc.existingRoutes[dummyRoute.ID] = dummyRoute
|
||||
}
|
||||
|
||||
for i := 0; i < 50; i++ {
|
||||
id := route.ID(fmt.Sprintf("dummy_p1_%d", i))
|
||||
dummyStatus := routerPeerStatus{
|
||||
connected: false,
|
||||
relayed: true,
|
||||
latency: 0,
|
||||
}
|
||||
tc.statuses[id] = dummyStatus
|
||||
}
|
||||
for i := 0; i < 50; i++ {
|
||||
id := route.ID(fmt.Sprintf("dummy_p2_%d", i))
|
||||
dummyStatus := routerPeerStatus{
|
||||
connected: false,
|
||||
relayed: true,
|
||||
latency: 0,
|
||||
}
|
||||
tc.statuses[id] = dummyStatus
|
||||
}
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
currentRoute := &route.Route{
|
||||
|
||||
@@ -32,7 +32,7 @@ import (
|
||||
|
||||
// Manager is a route manager interface
|
||||
type Manager interface {
|
||||
Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error)
|
||||
Init(*statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error)
|
||||
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error)
|
||||
TriggerSelection(route.HAMap)
|
||||
GetRouteSelector() *routeselector.RouteSelector
|
||||
@@ -59,7 +59,6 @@ type DefaultManager struct {
|
||||
routeRefCounter *refcounter.RouteRefCounter
|
||||
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter
|
||||
dnsRouteInterval time.Duration
|
||||
stateManager *statemanager.Manager
|
||||
}
|
||||
|
||||
func NewManager(
|
||||
@@ -70,7 +69,6 @@ func NewManager(
|
||||
statusRecorder *peer.Status,
|
||||
relayMgr *relayClient.Manager,
|
||||
initialRoutes []*route.Route,
|
||||
stateManager *statemanager.Manager,
|
||||
) *DefaultManager {
|
||||
mCTX, cancel := context.WithCancel(ctx)
|
||||
notifier := notifier.NewNotifier()
|
||||
@@ -82,12 +80,12 @@ func NewManager(
|
||||
dnsRouteInterval: dnsRouteInterval,
|
||||
clientNetworks: make(map[route.HAUniqueID]*clientNetwork),
|
||||
relayMgr: relayMgr,
|
||||
routeSelector: routeselector.NewRouteSelector(),
|
||||
sysOps: sysOps,
|
||||
statusRecorder: statusRecorder,
|
||||
wgInterface: wgInterface,
|
||||
pubKey: pubKey,
|
||||
notifier: notifier,
|
||||
stateManager: stateManager,
|
||||
}
|
||||
|
||||
dm.routeRefCounter = refcounter.New(
|
||||
@@ -123,7 +121,7 @@ func NewManager(
|
||||
}
|
||||
|
||||
// Init sets up the routing
|
||||
func (m *DefaultManager) Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
||||
func (m *DefaultManager) Init(stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
||||
if nbnet.CustomRoutingDisabled() {
|
||||
return nil, nil, nil
|
||||
}
|
||||
@@ -139,38 +137,14 @@ func (m *DefaultManager) Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error)
|
||||
|
||||
ips := resolveURLsToIPs(initialAddresses)
|
||||
|
||||
beforePeerHook, afterPeerHook, err := m.sysOps.SetupRouting(ips, m.stateManager)
|
||||
beforePeerHook, afterPeerHook, err := m.sysOps.SetupRouting(ips, stateManager)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("setup routing: %w", err)
|
||||
}
|
||||
|
||||
m.routeSelector = m.initSelector()
|
||||
|
||||
log.Info("Routing setup complete")
|
||||
return beforePeerHook, afterPeerHook, nil
|
||||
}
|
||||
|
||||
func (m *DefaultManager) initSelector() *routeselector.RouteSelector {
|
||||
var state *SelectorState
|
||||
m.stateManager.RegisterState(state)
|
||||
|
||||
// restore selector state if it exists
|
||||
if err := m.stateManager.LoadState(state); err != nil {
|
||||
log.Warnf("failed to load state: %v", err)
|
||||
return routeselector.NewRouteSelector()
|
||||
}
|
||||
|
||||
if state := m.stateManager.GetState(state); state != nil {
|
||||
if selector, ok := state.(*SelectorState); ok {
|
||||
return (*routeselector.RouteSelector)(selector)
|
||||
}
|
||||
|
||||
log.Warnf("failed to convert state with type %T to SelectorState", state)
|
||||
}
|
||||
|
||||
return routeselector.NewRouteSelector()
|
||||
}
|
||||
|
||||
func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error {
|
||||
var err error
|
||||
m.serverRouter, err = newServerRouter(m.ctx, m.wgInterface, firewall, m.statusRecorder)
|
||||
@@ -278,10 +252,6 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) {
|
||||
go clientNetworkWatcher.peersStateAndUpdateWatcher()
|
||||
clientNetworkWatcher.sendUpdateToClientNetworkWatcher(routesUpdate{routes: routes})
|
||||
}
|
||||
|
||||
if err := m.stateManager.UpdateState((*SelectorState)(m.routeSelector)); err != nil {
|
||||
log.Errorf("failed to update state: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// stopObsoleteClients stops the client network watcher for the networks that are not in the new list
|
||||
|
||||
@@ -424,9 +424,9 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
|
||||
statusRecorder := peer.NewRecorder("https://mgm")
|
||||
ctx := context.TODO()
|
||||
routeManager := NewManager(ctx, localPeerKey, 0, wgInterface, statusRecorder, nil, nil, nil)
|
||||
routeManager := NewManager(ctx, localPeerKey, 0, wgInterface, statusRecorder, nil, nil)
|
||||
|
||||
_, _, err = routeManager.Init()
|
||||
_, _, err = routeManager.Init(nil)
|
||||
|
||||
require.NoError(t, err, "should init route manager")
|
||||
defer routeManager.Stop(nil)
|
||||
|
||||
@@ -21,7 +21,7 @@ type MockManager struct {
|
||||
StopFunc func(manager *statemanager.Manager)
|
||||
}
|
||||
|
||||
func (m *MockManager) Init() (net.AddHookFunc, net.RemoveHookFunc, error) {
|
||||
func (m *MockManager) Init(*statemanager.Manager) (net.AddHookFunc, net.RemoveHookFunc, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -47,9 +47,10 @@ type RemoveFunc[Key, O any] func(key Key, out O) error
|
||||
type Counter[Key comparable, I, O any] struct {
|
||||
// refCountMap keeps track of the reference Ref for keys
|
||||
refCountMap map[Key]Ref[O]
|
||||
mu sync.Mutex
|
||||
refCountMu sync.Mutex
|
||||
// idMap keeps track of the keys associated with an ID for removal
|
||||
idMap map[string][]Key
|
||||
idMu sync.Mutex
|
||||
add AddFunc[Key, I, O]
|
||||
remove RemoveFunc[Key, O]
|
||||
}
|
||||
@@ -71,14 +72,13 @@ func New[Key comparable, I, O any](add AddFunc[Key, I, O], remove RemoveFunc[Key
|
||||
}
|
||||
|
||||
// LoadData loads the data from the existing counter
|
||||
// The passed counter should not be used any longer after calling this function.
|
||||
func (rm *Counter[Key, I, O]) LoadData(
|
||||
existingCounter *Counter[Key, I, O],
|
||||
) {
|
||||
rm.mu.Lock()
|
||||
defer rm.mu.Unlock()
|
||||
existingCounter.mu.Lock()
|
||||
defer existingCounter.mu.Unlock()
|
||||
rm.refCountMu.Lock()
|
||||
defer rm.refCountMu.Unlock()
|
||||
rm.idMu.Lock()
|
||||
defer rm.idMu.Unlock()
|
||||
|
||||
rm.refCountMap = existingCounter.refCountMap
|
||||
rm.idMap = existingCounter.idMap
|
||||
@@ -87,8 +87,8 @@ func (rm *Counter[Key, I, O]) LoadData(
|
||||
// Get retrieves the current reference count and associated data for a key.
|
||||
// If the key doesn't exist, it returns a zero value Ref and false.
|
||||
func (rm *Counter[Key, I, O]) Get(key Key) (Ref[O], bool) {
|
||||
rm.mu.Lock()
|
||||
defer rm.mu.Unlock()
|
||||
rm.refCountMu.Lock()
|
||||
defer rm.refCountMu.Unlock()
|
||||
|
||||
ref, ok := rm.refCountMap[key]
|
||||
return ref, ok
|
||||
@@ -97,13 +97,9 @@ func (rm *Counter[Key, I, O]) Get(key Key) (Ref[O], bool) {
|
||||
// Increment increments the reference count for the given key.
|
||||
// If this is the first reference to the key, the AddFunc is called.
|
||||
func (rm *Counter[Key, I, O]) Increment(key Key, in I) (Ref[O], error) {
|
||||
rm.mu.Lock()
|
||||
defer rm.mu.Unlock()
|
||||
rm.refCountMu.Lock()
|
||||
defer rm.refCountMu.Unlock()
|
||||
|
||||
return rm.increment(key, in)
|
||||
}
|
||||
|
||||
func (rm *Counter[Key, I, O]) increment(key Key, in I) (Ref[O], error) {
|
||||
ref := rm.refCountMap[key]
|
||||
logCallerF("Increasing ref count [%d -> %d] for key %v with In [%v] Out [%v]", ref.Count, ref.Count+1, key, in, ref.Out)
|
||||
|
||||
@@ -130,10 +126,10 @@ func (rm *Counter[Key, I, O]) increment(key Key, in I) (Ref[O], error) {
|
||||
// IncrementWithID increments the reference count for the given key and groups it under the given ID.
|
||||
// If this is the first reference to the key, the AddFunc is called.
|
||||
func (rm *Counter[Key, I, O]) IncrementWithID(id string, key Key, in I) (Ref[O], error) {
|
||||
rm.mu.Lock()
|
||||
defer rm.mu.Unlock()
|
||||
rm.idMu.Lock()
|
||||
defer rm.idMu.Unlock()
|
||||
|
||||
ref, err := rm.increment(key, in)
|
||||
ref, err := rm.Increment(key, in)
|
||||
if err != nil {
|
||||
return ref, fmt.Errorf("with ID: %w", err)
|
||||
}
|
||||
@@ -145,12 +141,9 @@ func (rm *Counter[Key, I, O]) IncrementWithID(id string, key Key, in I) (Ref[O],
|
||||
// Decrement decrements the reference count for the given key.
|
||||
// If the reference count reaches 0, the RemoveFunc is called.
|
||||
func (rm *Counter[Key, I, O]) Decrement(key Key) (Ref[O], error) {
|
||||
rm.mu.Lock()
|
||||
defer rm.mu.Unlock()
|
||||
return rm.decrement(key)
|
||||
}
|
||||
rm.refCountMu.Lock()
|
||||
defer rm.refCountMu.Unlock()
|
||||
|
||||
func (rm *Counter[Key, I, O]) decrement(key Key) (Ref[O], error) {
|
||||
ref, ok := rm.refCountMap[key]
|
||||
if !ok {
|
||||
logCallerF("No reference found for key %v", key)
|
||||
@@ -175,12 +168,12 @@ func (rm *Counter[Key, I, O]) decrement(key Key) (Ref[O], error) {
|
||||
// DecrementWithID decrements the reference count for all keys associated with the given ID.
|
||||
// If the reference count reaches 0, the RemoveFunc is called.
|
||||
func (rm *Counter[Key, I, O]) DecrementWithID(id string) error {
|
||||
rm.mu.Lock()
|
||||
defer rm.mu.Unlock()
|
||||
rm.idMu.Lock()
|
||||
defer rm.idMu.Unlock()
|
||||
|
||||
var merr *multierror.Error
|
||||
for _, key := range rm.idMap[id] {
|
||||
if _, err := rm.decrement(key); err != nil {
|
||||
if _, err := rm.Decrement(key); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
}
|
||||
@@ -191,8 +184,10 @@ func (rm *Counter[Key, I, O]) DecrementWithID(id string) error {
|
||||
|
||||
// Flush removes all references and calls RemoveFunc for each key.
|
||||
func (rm *Counter[Key, I, O]) Flush() error {
|
||||
rm.mu.Lock()
|
||||
defer rm.mu.Unlock()
|
||||
rm.refCountMu.Lock()
|
||||
defer rm.refCountMu.Unlock()
|
||||
rm.idMu.Lock()
|
||||
defer rm.idMu.Unlock()
|
||||
|
||||
var merr *multierror.Error
|
||||
for key := range rm.refCountMap {
|
||||
@@ -211,8 +206,10 @@ func (rm *Counter[Key, I, O]) Flush() error {
|
||||
|
||||
// Clear removes all references without calling RemoveFunc.
|
||||
func (rm *Counter[Key, I, O]) Clear() {
|
||||
rm.mu.Lock()
|
||||
defer rm.mu.Unlock()
|
||||
rm.refCountMu.Lock()
|
||||
defer rm.refCountMu.Unlock()
|
||||
rm.idMu.Lock()
|
||||
defer rm.idMu.Unlock()
|
||||
|
||||
clear(rm.refCountMap)
|
||||
clear(rm.idMap)
|
||||
@@ -220,9 +217,6 @@ func (rm *Counter[Key, I, O]) Clear() {
|
||||
|
||||
// MarshalJSON implements the json.Marshaler interface for Counter.
|
||||
func (rm *Counter[Key, I, O]) MarshalJSON() ([]byte, error) {
|
||||
rm.mu.Lock()
|
||||
defer rm.mu.Unlock()
|
||||
|
||||
return json.Marshal(struct {
|
||||
RefCountMap map[Key]Ref[O] `json:"refCountMap"`
|
||||
IDMap map[string][]Key `json:"idMap"`
|
||||
@@ -234,9 +228,6 @@ func (rm *Counter[Key, I, O]) MarshalJSON() ([]byte, error) {
|
||||
|
||||
// UnmarshalJSON implements the json.Unmarshaler interface for Counter.
|
||||
func (rm *Counter[Key, I, O]) UnmarshalJSON(data []byte) error {
|
||||
rm.mu.Lock()
|
||||
defer rm.mu.Unlock()
|
||||
|
||||
var temp struct {
|
||||
RefCountMap map[Key]Ref[O] `json:"refCountMap"`
|
||||
IDMap map[string][]Key `json:"idMap"`
|
||||
@@ -247,13 +238,6 @@ func (rm *Counter[Key, I, O]) UnmarshalJSON(data []byte) error {
|
||||
rm.refCountMap = temp.RefCountMap
|
||||
rm.idMap = temp.IDMap
|
||||
|
||||
if temp.RefCountMap == nil {
|
||||
temp.RefCountMap = map[Key]Ref[O]{}
|
||||
}
|
||||
if temp.IDMap == nil {
|
||||
temp.IDMap = map[string][]Key{}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -1,19 +0,0 @@
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/client/internal/routeselector"
|
||||
)
|
||||
|
||||
type SelectorState routeselector.RouteSelector
|
||||
|
||||
func (s *SelectorState) Name() string {
|
||||
return "routeselector_state"
|
||||
}
|
||||
|
||||
func (s *SelectorState) MarshalJSON() ([]byte, error) {
|
||||
return (*routeselector.RouteSelector)(s).MarshalJSON()
|
||||
}
|
||||
|
||||
func (s *SelectorState) UnmarshalJSON(data []byte) error {
|
||||
return (*routeselector.RouteSelector)(s).UnmarshalJSON(data)
|
||||
}
|
||||
@@ -2,28 +2,31 @@ package systemops
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
)
|
||||
|
||||
type ShutdownState ExclusionCounter
|
||||
type ShutdownState struct {
|
||||
Counter *ExclusionCounter `json:"counter,omitempty"`
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func (s *ShutdownState) Name() string {
|
||||
return "route_state"
|
||||
}
|
||||
|
||||
func (s *ShutdownState) Cleanup() error {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
if s.Counter == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
sysops := NewSysOps(nil, nil)
|
||||
sysops.refCounter = refcounter.New[netip.Prefix, struct{}, Nexthop](nil, sysops.removeFromRouteTable)
|
||||
sysops.refCounter.LoadData((*ExclusionCounter)(s))
|
||||
sysops.refCounter.LoadData(s.Counter)
|
||||
|
||||
return sysops.refCounter.Flush()
|
||||
}
|
||||
|
||||
func (s *ShutdownState) MarshalJSON() ([]byte, error) {
|
||||
return (*ExclusionCounter)(s).MarshalJSON()
|
||||
}
|
||||
|
||||
func (s *ShutdownState) UnmarshalJSON(data []byte) error {
|
||||
return (*ExclusionCounter)(s).UnmarshalJSON(data)
|
||||
}
|
||||
|
||||
@@ -57,19 +57,30 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemana
|
||||
return nexthop, refcounter.ErrIgnore
|
||||
}
|
||||
|
||||
r.updateState(stateManager)
|
||||
|
||||
return nexthop, err
|
||||
},
|
||||
r.removeFromRouteTable,
|
||||
func(prefix netip.Prefix, nexthop Nexthop) error {
|
||||
// remove from state even if we have trouble removing it from the route table
|
||||
// it could be already gone
|
||||
r.updateState(stateManager)
|
||||
|
||||
return r.removeFromRouteTable(prefix, nexthop)
|
||||
},
|
||||
)
|
||||
|
||||
r.refCounter = refCounter
|
||||
|
||||
return r.setupHooks(initAddresses, stateManager)
|
||||
return r.setupHooks(initAddresses)
|
||||
}
|
||||
|
||||
// updateState updates state on every change so it will be persisted regularly
|
||||
func (r *SysOps) updateState(stateManager *statemanager.Manager) {
|
||||
if err := stateManager.UpdateState((*ShutdownState)(r.refCounter)); err != nil {
|
||||
state := getState(stateManager)
|
||||
|
||||
state.Counter = r.refCounter
|
||||
|
||||
if err := stateManager.UpdateState(state); err != nil {
|
||||
log.Errorf("failed to update state: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -325,7 +336,7 @@ func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface)
|
||||
return r.removeFromRouteTable(prefix, nextHop)
|
||||
}
|
||||
|
||||
func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
||||
func (r *SysOps) setupHooks(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
||||
beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error {
|
||||
prefix, err := util.GetPrefixFromIP(ip)
|
||||
if err != nil {
|
||||
@@ -336,8 +347,6 @@ func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.M
|
||||
return fmt.Errorf("adding route reference: %v", err)
|
||||
}
|
||||
|
||||
r.updateState(stateManager)
|
||||
|
||||
return nil
|
||||
}
|
||||
afterHook := func(connID nbnet.ConnectionID) error {
|
||||
@@ -345,8 +354,6 @@ func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.M
|
||||
return fmt.Errorf("remove route reference: %w", err)
|
||||
}
|
||||
|
||||
r.updateState(stateManager)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -525,3 +532,14 @@ func isVpnRoute(addr netip.Addr, vpnRoutes []netip.Prefix, localRoutes []netip.P
|
||||
// Return true if the longest matching prefix is from vpnRoutes
|
||||
return isVpn, longestPrefix
|
||||
}
|
||||
|
||||
func getState(stateManager *statemanager.Manager) *ShutdownState {
|
||||
var shutdownState *ShutdownState
|
||||
if state := stateManager.GetState(shutdownState); state != nil {
|
||||
shutdownState = state.(*ShutdownState)
|
||||
} else {
|
||||
shutdownState = &ShutdownState{}
|
||||
}
|
||||
|
||||
return shutdownState
|
||||
}
|
||||
|
||||
@@ -55,7 +55,7 @@ type ruleParams struct {
|
||||
|
||||
// isLegacy determines whether to use the legacy routing setup
|
||||
func isLegacy() bool {
|
||||
return os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled() || nbnet.SkipSocketMark()
|
||||
return os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled()
|
||||
}
|
||||
|
||||
// setIsLegacy sets the legacy routing setup
|
||||
@@ -92,6 +92,17 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager
|
||||
return r.setupRefCounter(initAddresses, stateManager)
|
||||
}
|
||||
|
||||
if err = addRoutingTableName(); err != nil {
|
||||
log.Errorf("Error adding routing table name: %v", err)
|
||||
}
|
||||
|
||||
originalValues, err := sysctl.Setup(r.wgInterface)
|
||||
if err != nil {
|
||||
log.Errorf("Error setting up sysctl: %v", err)
|
||||
sysctlFailed = true
|
||||
}
|
||||
originalSysctl = originalValues
|
||||
|
||||
defer func() {
|
||||
if err != nil {
|
||||
if cleanErr := r.CleanupRouting(stateManager); cleanErr != nil {
|
||||
@@ -112,17 +123,6 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager
|
||||
}
|
||||
}
|
||||
|
||||
if err = addRoutingTableName(); err != nil {
|
||||
log.Errorf("Error adding routing table name: %v", err)
|
||||
}
|
||||
|
||||
originalValues, err := sysctl.Setup(r.wgInterface)
|
||||
if err != nil {
|
||||
log.Errorf("Error setting up sysctl: %v", err)
|
||||
sysctlFailed = true
|
||||
}
|
||||
originalSysctl = originalValues
|
||||
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
@@ -450,7 +450,7 @@ func addRule(params ruleParams) error {
|
||||
rule.Invert = params.invert
|
||||
rule.SuppressPrefixlen = params.suppressPrefix
|
||||
|
||||
if err := netlink.RuleAdd(rule); err != nil && !errors.Is(err, syscall.EEXIST) {
|
||||
if err := netlink.RuleAdd(rule); err != nil && !errors.Is(err, syscall.EEXIST) && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
||||
return fmt.Errorf("add routing rule: %w", err)
|
||||
}
|
||||
|
||||
@@ -467,7 +467,7 @@ func removeRule(params ruleParams) error {
|
||||
rule.Priority = params.priority
|
||||
rule.SuppressPrefixlen = params.suppressPrefix
|
||||
|
||||
if err := netlink.RuleDel(rule); err != nil && !errors.Is(err, syscall.ENOENT) {
|
||||
if err := netlink.RuleDel(rule); err != nil && !errors.Is(err, syscall.ENOENT) && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
||||
return fmt.Errorf("remove routing rule: %w", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -230,13 +230,10 @@ func (rm *RouteMonitor) parseUpdate(row *MIB_IPFORWARD_ROW2, notificationType MI
|
||||
if idx != 0 {
|
||||
intf, err := net.InterfaceByIndex(idx)
|
||||
if err != nil {
|
||||
log.Warnf("failed to get interface name for index %d: %v", idx, err)
|
||||
update.Interface = &net.Interface{
|
||||
Index: idx,
|
||||
}
|
||||
} else {
|
||||
update.Interface = intf
|
||||
return update, fmt.Errorf("get interface name: %w", err)
|
||||
}
|
||||
|
||||
update.Interface = intf
|
||||
}
|
||||
|
||||
log.Tracef("Received route update with destination %v, next hop %v, interface %v", row.DestinationPrefix, row.NextHop, update.Interface)
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
package routeselector
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"slices"
|
||||
"sync"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"golang.org/x/exp/maps"
|
||||
@@ -14,7 +12,6 @@ import (
|
||||
)
|
||||
|
||||
type RouteSelector struct {
|
||||
mu sync.RWMutex
|
||||
selectedRoutes map[route.NetID]struct{}
|
||||
selectAll bool
|
||||
}
|
||||
@@ -29,9 +26,6 @@ func NewRouteSelector() *RouteSelector {
|
||||
|
||||
// SelectRoutes updates the selected routes based on the provided route IDs.
|
||||
func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, allRoutes []route.NetID) error {
|
||||
rs.mu.Lock()
|
||||
defer rs.mu.Unlock()
|
||||
|
||||
if !appendRoute {
|
||||
rs.selectedRoutes = map[route.NetID]struct{}{}
|
||||
}
|
||||
@@ -52,9 +46,6 @@ func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, al
|
||||
|
||||
// SelectAllRoutes sets the selector to select all routes.
|
||||
func (rs *RouteSelector) SelectAllRoutes() {
|
||||
rs.mu.Lock()
|
||||
defer rs.mu.Unlock()
|
||||
|
||||
rs.selectAll = true
|
||||
rs.selectedRoutes = map[route.NetID]struct{}{}
|
||||
}
|
||||
@@ -62,9 +53,6 @@ func (rs *RouteSelector) SelectAllRoutes() {
|
||||
// DeselectRoutes removes specific routes from the selection.
|
||||
// If the selector is in "select all" mode, it will transition to "select specific" mode.
|
||||
func (rs *RouteSelector) DeselectRoutes(routes []route.NetID, allRoutes []route.NetID) error {
|
||||
rs.mu.Lock()
|
||||
defer rs.mu.Unlock()
|
||||
|
||||
if rs.selectAll {
|
||||
rs.selectAll = false
|
||||
rs.selectedRoutes = map[route.NetID]struct{}{}
|
||||
@@ -88,18 +76,12 @@ func (rs *RouteSelector) DeselectRoutes(routes []route.NetID, allRoutes []route.
|
||||
|
||||
// DeselectAllRoutes deselects all routes, effectively disabling route selection.
|
||||
func (rs *RouteSelector) DeselectAllRoutes() {
|
||||
rs.mu.Lock()
|
||||
defer rs.mu.Unlock()
|
||||
|
||||
rs.selectAll = false
|
||||
rs.selectedRoutes = map[route.NetID]struct{}{}
|
||||
}
|
||||
|
||||
// IsSelected checks if a specific route is selected.
|
||||
func (rs *RouteSelector) IsSelected(routeID route.NetID) bool {
|
||||
rs.mu.RLock()
|
||||
defer rs.mu.RUnlock()
|
||||
|
||||
if rs.selectAll {
|
||||
return true
|
||||
}
|
||||
@@ -109,9 +91,6 @@ func (rs *RouteSelector) IsSelected(routeID route.NetID) bool {
|
||||
|
||||
// FilterSelected removes unselected routes from the provided map.
|
||||
func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap {
|
||||
rs.mu.RLock()
|
||||
defer rs.mu.RUnlock()
|
||||
|
||||
if rs.selectAll {
|
||||
return maps.Clone(routes)
|
||||
}
|
||||
@@ -124,49 +103,3 @@ func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap {
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
// MarshalJSON implements the json.Marshaler interface
|
||||
func (rs *RouteSelector) MarshalJSON() ([]byte, error) {
|
||||
rs.mu.RLock()
|
||||
defer rs.mu.RUnlock()
|
||||
|
||||
return json.Marshal(struct {
|
||||
SelectedRoutes map[route.NetID]struct{} `json:"selected_routes"`
|
||||
SelectAll bool `json:"select_all"`
|
||||
}{
|
||||
SelectAll: rs.selectAll,
|
||||
SelectedRoutes: rs.selectedRoutes,
|
||||
})
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements the json.Unmarshaler interface
|
||||
// If the JSON is empty or null, it will initialize like a NewRouteSelector.
|
||||
func (rs *RouteSelector) UnmarshalJSON(data []byte) error {
|
||||
rs.mu.Lock()
|
||||
defer rs.mu.Unlock()
|
||||
|
||||
// Check for null or empty JSON
|
||||
if len(data) == 0 || string(data) == "null" {
|
||||
rs.selectedRoutes = map[route.NetID]struct{}{}
|
||||
rs.selectAll = true
|
||||
return nil
|
||||
}
|
||||
|
||||
var temp struct {
|
||||
SelectedRoutes map[route.NetID]struct{} `json:"selected_routes"`
|
||||
SelectAll bool `json:"select_all"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(data, &temp); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rs.selectedRoutes = temp.SelectedRoutes
|
||||
rs.selectAll = temp.SelectAll
|
||||
|
||||
if rs.selectedRoutes == nil {
|
||||
rs.selectedRoutes = map[route.NetID]struct{}{}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -16,39 +16,14 @@ import (
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
const (
|
||||
errStateNotRegistered = "state %s not registered"
|
||||
errLoadStateFile = "load state file: %w"
|
||||
)
|
||||
|
||||
// State interface defines the methods that all state types must implement
|
||||
type State interface {
|
||||
Name() string
|
||||
}
|
||||
|
||||
// CleanableState interface extends State with cleanup capability
|
||||
type CleanableState interface {
|
||||
State
|
||||
Cleanup() error
|
||||
}
|
||||
|
||||
// RawState wraps raw JSON data for unregistered states
|
||||
type RawState struct {
|
||||
data json.RawMessage
|
||||
}
|
||||
|
||||
func (r *RawState) Name() string {
|
||||
return "" // This is a placeholder implementation
|
||||
}
|
||||
|
||||
// MarshalJSON implements json.Marshaler to preserve the original JSON
|
||||
func (r *RawState) MarshalJSON() ([]byte, error) {
|
||||
return r.data, nil
|
||||
}
|
||||
|
||||
// Manager handles the persistence and management of various states
|
||||
type Manager struct {
|
||||
mu sync.Mutex
|
||||
@@ -98,15 +73,15 @@ func (m *Manager) Stop(ctx context.Context) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if m.cancel == nil {
|
||||
return nil
|
||||
}
|
||||
m.cancel()
|
||||
if m.cancel != nil {
|
||||
m.cancel()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-m.done:
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-m.done:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -164,7 +139,7 @@ func (m *Manager) setState(name string, state State) error {
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if _, exists := m.states[name]; !exists {
|
||||
return fmt.Errorf(errStateNotRegistered, name)
|
||||
return fmt.Errorf("state %s not registered", name)
|
||||
}
|
||||
|
||||
m.states[name] = state
|
||||
@@ -173,63 +148,6 @@ func (m *Manager) setState(name string, state State) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteStateByName handles deletion of states without cleanup.
|
||||
// It doesn't require the state to be registered.
|
||||
func (m *Manager) DeleteStateByName(stateName string) error {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
rawStates, err := m.loadStateFile(false)
|
||||
if err != nil {
|
||||
return fmt.Errorf(errLoadStateFile, err)
|
||||
}
|
||||
if rawStates == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, exists := rawStates[stateName]; !exists {
|
||||
return fmt.Errorf("state %s not found", stateName)
|
||||
}
|
||||
|
||||
// Mark state as deleted by setting it to nil and marking it dirty
|
||||
m.states[stateName] = nil
|
||||
m.dirty[stateName] = struct{}{}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteAllStates removes all states.
|
||||
func (m *Manager) DeleteAllStates() (int, error) {
|
||||
if m == nil {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
rawStates, err := m.loadStateFile(false)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf(errLoadStateFile, err)
|
||||
}
|
||||
if rawStates == nil {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
count := len(rawStates)
|
||||
|
||||
// Mark all states as deleted and dirty
|
||||
for name := range rawStates {
|
||||
m.states[name] = nil
|
||||
m.dirty[name] = struct{}{}
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (m *Manager) periodicStateSave(ctx context.Context) {
|
||||
ticker := time.NewTicker(10 * time.Second)
|
||||
defer ticker.Stop()
|
||||
@@ -260,18 +178,25 @@ func (m *Manager) PersistState(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
bs, err := marshalWithPanicRecovery(m.states)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal states: %w", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
done := make(chan error, 1)
|
||||
start := time.Now()
|
||||
|
||||
go func() {
|
||||
done <- util.WriteBytesWithRestrictedPermission(ctx, m.filePath, bs)
|
||||
data, err := json.MarshalIndent(m.states, "", " ")
|
||||
if err != nil {
|
||||
done <- fmt.Errorf("marshal states: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
// nolint:gosec
|
||||
if err := os.WriteFile(m.filePath, data, 0640); err != nil {
|
||||
done <- fmt.Errorf("write state file: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
done <- nil
|
||||
}()
|
||||
|
||||
select {
|
||||
@@ -283,175 +208,63 @@ func (m *Manager) PersistState(ctx context.Context) error {
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("persisted states: %v, took %v", maps.Keys(m.dirty), time.Since(start))
|
||||
log.Debugf("persisted shutdown states: %v", maps.Keys(m.dirty))
|
||||
|
||||
clear(m.dirty)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadStateFile reads and unmarshals the state file into a map of raw JSON messages
|
||||
func (m *Manager) loadStateFile(deleteCorrupt bool) (map[string]json.RawMessage, error) {
|
||||
// loadState loads the existing state from the state file
|
||||
func (m *Manager) loadState() error {
|
||||
data, err := os.ReadFile(m.filePath)
|
||||
if err != nil {
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
log.Debug("state file does not exist")
|
||||
return nil, nil // nolint:nilnil
|
||||
return nil
|
||||
}
|
||||
return nil, fmt.Errorf("read state file: %w", err)
|
||||
return fmt.Errorf("read state file: %w", err)
|
||||
}
|
||||
|
||||
var rawStates map[string]json.RawMessage
|
||||
if err := json.Unmarshal(data, &rawStates); err != nil {
|
||||
if deleteCorrupt {
|
||||
log.Warn("State file appears to be corrupted, attempting to delete it", err)
|
||||
if err := os.Remove(m.filePath); err != nil {
|
||||
log.Errorf("Failed to delete corrupted state file: %v", err)
|
||||
} else {
|
||||
log.Info("State file deleted")
|
||||
}
|
||||
log.Warn("State file appears to be corrupted, attempting to delete it")
|
||||
if err := os.Remove(m.filePath); err != nil {
|
||||
log.Errorf("Failed to delete corrupted state file: %v", err)
|
||||
} else {
|
||||
log.Info("State file deleted")
|
||||
}
|
||||
return nil, fmt.Errorf("unmarshal states: %w", err)
|
||||
return fmt.Errorf("unmarshal states: %w", err)
|
||||
}
|
||||
|
||||
return rawStates, nil
|
||||
}
|
||||
var merr *multierror.Error
|
||||
|
||||
// loadSingleRawState unmarshals a raw state into a concrete state object
|
||||
func (m *Manager) loadSingleRawState(name string, rawState json.RawMessage) (State, error) {
|
||||
stateType, ok := m.stateTypes[name]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf(errStateNotRegistered, name)
|
||||
}
|
||||
for name, rawState := range rawStates {
|
||||
stateType, ok := m.stateTypes[name]
|
||||
if !ok {
|
||||
merr = multierror.Append(merr, fmt.Errorf("unknown state type: %s", name))
|
||||
continue
|
||||
}
|
||||
|
||||
if string(rawState) == "null" {
|
||||
return nil, nil //nolint:nilnil
|
||||
}
|
||||
if string(rawState) == "null" {
|
||||
continue
|
||||
}
|
||||
|
||||
statePtr := reflect.New(stateType).Interface().(State)
|
||||
if err := json.Unmarshal(rawState, statePtr); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal state %s: %w", name, err)
|
||||
}
|
||||
statePtr := reflect.New(stateType).Interface().(State)
|
||||
if err := json.Unmarshal(rawState, statePtr); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("unmarshal state %s: %w", name, err))
|
||||
continue
|
||||
}
|
||||
|
||||
return statePtr, nil
|
||||
}
|
||||
|
||||
// LoadState loads a specific state from the state file
|
||||
func (m *Manager) LoadState(state State) error {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
rawStates, err := m.loadStateFile(false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if rawStates == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
name := state.Name()
|
||||
rawState, exists := rawStates[name]
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
loadedState, err := m.loadSingleRawState(name, rawState)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m.states[name] = loadedState
|
||||
if loadedState != nil {
|
||||
m.states[name] = statePtr
|
||||
log.Debugf("loaded state: %s", name)
|
||||
}
|
||||
|
||||
return nil
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
// cleanupSingleState handles the cleanup of a specific state and returns any error.
|
||||
// The caller must hold the mutex.
|
||||
func (m *Manager) cleanupSingleState(name string, rawState json.RawMessage) error {
|
||||
// For unregistered states, preserve the raw JSON
|
||||
if _, registered := m.stateTypes[name]; !registered {
|
||||
m.states[name] = &RawState{data: rawState}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Load the state
|
||||
loadedState, err := m.loadSingleRawState(name, rawState)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if loadedState == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if state supports cleanup
|
||||
cleanableState, isCleanable := loadedState.(CleanableState)
|
||||
if !isCleanable {
|
||||
// If it doesn't support cleanup, keep it as-is
|
||||
m.states[name] = loadedState
|
||||
return nil
|
||||
}
|
||||
|
||||
// Perform cleanup
|
||||
log.Infof("cleaning up state %s", name)
|
||||
if err := cleanableState.Cleanup(); err != nil {
|
||||
// On cleanup error, preserve the state
|
||||
m.states[name] = loadedState
|
||||
return fmt.Errorf("cleanup state: %w", err)
|
||||
}
|
||||
|
||||
// Successfully cleaned up - mark for deletion
|
||||
m.states[name] = nil
|
||||
m.dirty[name] = struct{}{}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CleanupStateByName loads and cleans up a specific state by name if it implements CleanableState.
|
||||
// Returns an error if the state doesn't exist, isn't registered, or cleanup fails.
|
||||
func (m *Manager) CleanupStateByName(name string) error {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
// Check if state is registered
|
||||
if _, registered := m.stateTypes[name]; !registered {
|
||||
return fmt.Errorf(errStateNotRegistered, name)
|
||||
}
|
||||
|
||||
// Load raw states from file
|
||||
rawStates, err := m.loadStateFile(false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if rawStates == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if state exists in file
|
||||
rawState, exists := rawStates[name]
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := m.cleanupSingleState(name, rawState); err != nil {
|
||||
return fmt.Errorf("%s: %w", name, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// PerformCleanup retrieves all states from the state file and calls Cleanup on registered states that support it.
|
||||
// Unregistered states are preserved in their original state.
|
||||
// PerformCleanup retrieves all states from the state file for the registered states and calls Cleanup on them.
|
||||
// If the cleanup is successful, the state is marked for deletion.
|
||||
func (m *Manager) PerformCleanup() error {
|
||||
if m == nil {
|
||||
return nil
|
||||
@@ -460,63 +273,26 @@ func (m *Manager) PerformCleanup() error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
// Load raw states from file
|
||||
rawStates, err := m.loadStateFile(true)
|
||||
if err != nil {
|
||||
return fmt.Errorf(errLoadStateFile, err)
|
||||
}
|
||||
if rawStates == nil {
|
||||
return nil
|
||||
if err := m.loadState(); err != nil {
|
||||
log.Warnf("Failed to load state during cleanup: %v", err)
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
for name, state := range m.states {
|
||||
if state == nil {
|
||||
// If no state was found in the state file, we don't mark the state dirty nor return an error
|
||||
continue
|
||||
}
|
||||
|
||||
// Process each state in the file
|
||||
for name, rawState := range rawStates {
|
||||
if err := m.cleanupSingleState(name, rawState); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("%s: %w", name, err))
|
||||
log.Infof("client was not shut down properly, cleaning up %s", name)
|
||||
if err := state.Cleanup(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("cleanup state for %s: %w", name, err))
|
||||
} else {
|
||||
// mark for deletion on cleanup success
|
||||
m.states[name] = nil
|
||||
m.dirty[name] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
// GetSavedStateNames returns all state names that are currently saved in the state file.
|
||||
func (m *Manager) GetSavedStateNames() ([]string, error) {
|
||||
if m == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
rawStates, err := m.loadStateFile(false)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(errLoadStateFile, err)
|
||||
}
|
||||
if rawStates == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var states []string
|
||||
for name, state := range rawStates {
|
||||
if len(state) != 0 && string(state) != "null" {
|
||||
states = append(states, name)
|
||||
}
|
||||
}
|
||||
|
||||
return states, nil
|
||||
}
|
||||
|
||||
func marshalWithPanicRecovery(v any) ([]byte, error) {
|
||||
var bs []byte
|
||||
var err error
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = fmt.Errorf("panic during marshal: %v", r)
|
||||
}
|
||||
}()
|
||||
bs, err = json.Marshal(v)
|
||||
}()
|
||||
|
||||
return bs, err
|
||||
}
|
||||
|
||||
@@ -4,20 +4,32 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// GetDefaultStatePath returns the path to the state file based on the operating system
|
||||
// It returns an empty string if the path cannot be determined.
|
||||
// It returns an empty string if the path cannot be determined. It also creates the directory if it does not exist.
|
||||
func GetDefaultStatePath() string {
|
||||
var path string
|
||||
|
||||
switch runtime.GOOS {
|
||||
case "windows":
|
||||
return filepath.Join(os.Getenv("PROGRAMDATA"), "Netbird", "state.json")
|
||||
path = filepath.Join(os.Getenv("PROGRAMDATA"), "Netbird", "state.json")
|
||||
case "darwin", "linux":
|
||||
return "/var/lib/netbird/state.json"
|
||||
path = "/var/lib/netbird/state.json"
|
||||
case "freebsd", "openbsd", "netbsd", "dragonfly":
|
||||
return "/var/db/netbird/state.json"
|
||||
path = "/var/db/netbird/state.json"
|
||||
// ios/android don't need state
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
|
||||
return ""
|
||||
dir := filepath.Dir(path)
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
log.Errorf("Error creating directory %s: %v. Continuing without state support.", dir, err)
|
||||
return ""
|
||||
}
|
||||
|
||||
return path
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// versions:
|
||||
// protoc-gen-go v1.26.0
|
||||
// protoc v4.23.4
|
||||
// protoc v3.21.12
|
||||
// source: daemon.proto
|
||||
|
||||
package proto
|
||||
@@ -2103,434 +2103,6 @@ func (*SetLogLevelResponse) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{30}
|
||||
}
|
||||
|
||||
// State represents a daemon state entry
|
||||
type State struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"`
|
||||
}
|
||||
|
||||
func (x *State) Reset() {
|
||||
*x = State{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_daemon_proto_msgTypes[31]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
}
|
||||
|
||||
func (x *State) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*State) ProtoMessage() {}
|
||||
|
||||
func (x *State) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[31]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use State.ProtoReflect.Descriptor instead.
|
||||
func (*State) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{31}
|
||||
}
|
||||
|
||||
func (x *State) GetName() string {
|
||||
if x != nil {
|
||||
return x.Name
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// ListStatesRequest is empty as it requires no parameters
|
||||
type ListStatesRequest struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
}
|
||||
|
||||
func (x *ListStatesRequest) Reset() {
|
||||
*x = ListStatesRequest{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_daemon_proto_msgTypes[32]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
}
|
||||
|
||||
func (x *ListStatesRequest) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*ListStatesRequest) ProtoMessage() {}
|
||||
|
||||
func (x *ListStatesRequest) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[32]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use ListStatesRequest.ProtoReflect.Descriptor instead.
|
||||
func (*ListStatesRequest) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{32}
|
||||
}
|
||||
|
||||
// ListStatesResponse contains a list of states
|
||||
type ListStatesResponse struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
States []*State `protobuf:"bytes,1,rep,name=states,proto3" json:"states,omitempty"`
|
||||
}
|
||||
|
||||
func (x *ListStatesResponse) Reset() {
|
||||
*x = ListStatesResponse{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_daemon_proto_msgTypes[33]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
}
|
||||
|
||||
func (x *ListStatesResponse) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*ListStatesResponse) ProtoMessage() {}
|
||||
|
||||
func (x *ListStatesResponse) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[33]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use ListStatesResponse.ProtoReflect.Descriptor instead.
|
||||
func (*ListStatesResponse) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{33}
|
||||
}
|
||||
|
||||
func (x *ListStatesResponse) GetStates() []*State {
|
||||
if x != nil {
|
||||
return x.States
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CleanStateRequest for cleaning states
|
||||
type CleanStateRequest struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
StateName string `protobuf:"bytes,1,opt,name=state_name,json=stateName,proto3" json:"state_name,omitempty"`
|
||||
All bool `protobuf:"varint,2,opt,name=all,proto3" json:"all,omitempty"`
|
||||
}
|
||||
|
||||
func (x *CleanStateRequest) Reset() {
|
||||
*x = CleanStateRequest{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_daemon_proto_msgTypes[34]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
}
|
||||
|
||||
func (x *CleanStateRequest) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*CleanStateRequest) ProtoMessage() {}
|
||||
|
||||
func (x *CleanStateRequest) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[34]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use CleanStateRequest.ProtoReflect.Descriptor instead.
|
||||
func (*CleanStateRequest) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{34}
|
||||
}
|
||||
|
||||
func (x *CleanStateRequest) GetStateName() string {
|
||||
if x != nil {
|
||||
return x.StateName
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *CleanStateRequest) GetAll() bool {
|
||||
if x != nil {
|
||||
return x.All
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// CleanStateResponse contains the result of the clean operation
|
||||
type CleanStateResponse struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
CleanedStates int32 `protobuf:"varint,1,opt,name=cleaned_states,json=cleanedStates,proto3" json:"cleaned_states,omitempty"`
|
||||
}
|
||||
|
||||
func (x *CleanStateResponse) Reset() {
|
||||
*x = CleanStateResponse{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_daemon_proto_msgTypes[35]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
}
|
||||
|
||||
func (x *CleanStateResponse) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*CleanStateResponse) ProtoMessage() {}
|
||||
|
||||
func (x *CleanStateResponse) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[35]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use CleanStateResponse.ProtoReflect.Descriptor instead.
|
||||
func (*CleanStateResponse) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{35}
|
||||
}
|
||||
|
||||
func (x *CleanStateResponse) GetCleanedStates() int32 {
|
||||
if x != nil {
|
||||
return x.CleanedStates
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// DeleteStateRequest for deleting states
|
||||
type DeleteStateRequest struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
StateName string `protobuf:"bytes,1,opt,name=state_name,json=stateName,proto3" json:"state_name,omitempty"`
|
||||
All bool `protobuf:"varint,2,opt,name=all,proto3" json:"all,omitempty"`
|
||||
}
|
||||
|
||||
func (x *DeleteStateRequest) Reset() {
|
||||
*x = DeleteStateRequest{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_daemon_proto_msgTypes[36]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
}
|
||||
|
||||
func (x *DeleteStateRequest) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*DeleteStateRequest) ProtoMessage() {}
|
||||
|
||||
func (x *DeleteStateRequest) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[36]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use DeleteStateRequest.ProtoReflect.Descriptor instead.
|
||||
func (*DeleteStateRequest) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{36}
|
||||
}
|
||||
|
||||
func (x *DeleteStateRequest) GetStateName() string {
|
||||
if x != nil {
|
||||
return x.StateName
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *DeleteStateRequest) GetAll() bool {
|
||||
if x != nil {
|
||||
return x.All
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// DeleteStateResponse contains the result of the delete operation
|
||||
type DeleteStateResponse struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
DeletedStates int32 `protobuf:"varint,1,opt,name=deleted_states,json=deletedStates,proto3" json:"deleted_states,omitempty"`
|
||||
}
|
||||
|
||||
func (x *DeleteStateResponse) Reset() {
|
||||
*x = DeleteStateResponse{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_daemon_proto_msgTypes[37]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
}
|
||||
|
||||
func (x *DeleteStateResponse) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*DeleteStateResponse) ProtoMessage() {}
|
||||
|
||||
func (x *DeleteStateResponse) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[37]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use DeleteStateResponse.ProtoReflect.Descriptor instead.
|
||||
func (*DeleteStateResponse) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{37}
|
||||
}
|
||||
|
||||
func (x *DeleteStateResponse) GetDeletedStates() int32 {
|
||||
if x != nil {
|
||||
return x.DeletedStates
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
type SetNetworkMapPersistenceRequest struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
Enabled bool `protobuf:"varint,1,opt,name=enabled,proto3" json:"enabled,omitempty"`
|
||||
}
|
||||
|
||||
func (x *SetNetworkMapPersistenceRequest) Reset() {
|
||||
*x = SetNetworkMapPersistenceRequest{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_daemon_proto_msgTypes[38]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
}
|
||||
|
||||
func (x *SetNetworkMapPersistenceRequest) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*SetNetworkMapPersistenceRequest) ProtoMessage() {}
|
||||
|
||||
func (x *SetNetworkMapPersistenceRequest) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[38]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use SetNetworkMapPersistenceRequest.ProtoReflect.Descriptor instead.
|
||||
func (*SetNetworkMapPersistenceRequest) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{38}
|
||||
}
|
||||
|
||||
func (x *SetNetworkMapPersistenceRequest) GetEnabled() bool {
|
||||
if x != nil {
|
||||
return x.Enabled
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type SetNetworkMapPersistenceResponse struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
}
|
||||
|
||||
func (x *SetNetworkMapPersistenceResponse) Reset() {
|
||||
*x = SetNetworkMapPersistenceResponse{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_daemon_proto_msgTypes[39]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
}
|
||||
|
||||
func (x *SetNetworkMapPersistenceResponse) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*SetNetworkMapPersistenceResponse) ProtoMessage() {}
|
||||
|
||||
func (x *SetNetworkMapPersistenceResponse) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[39]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use SetNetworkMapPersistenceResponse.ProtoReflect.Descriptor instead.
|
||||
func (*SetNetworkMapPersistenceResponse) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{39}
|
||||
}
|
||||
|
||||
var File_daemon_proto protoreflect.FileDescriptor
|
||||
|
||||
var file_daemon_proto_rawDesc = []byte{
|
||||
@@ -2827,116 +2399,66 @@ var file_daemon_proto_rawDesc = []byte{
|
||||
0x32, 0x10, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76,
|
||||
0x65, 0x6c, 0x52, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x22, 0x15, 0x0a, 0x13, 0x53, 0x65, 0x74,
|
||||
0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65,
|
||||
0x22, 0x1b, 0x0a, 0x05, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d,
|
||||
0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x22, 0x13, 0x0a,
|
||||
0x11, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65,
|
||||
0x73, 0x74, 0x22, 0x3b, 0x0a, 0x12, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73,
|
||||
0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x25, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74,
|
||||
0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x0d, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f,
|
||||
0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x65, 0x73, 0x22,
|
||||
0x44, 0x0a, 0x11, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71,
|
||||
0x75, 0x65, 0x73, 0x74, 0x12, 0x1d, 0x0a, 0x0a, 0x73, 0x74, 0x61, 0x74, 0x65, 0x5f, 0x6e, 0x61,
|
||||
0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x73, 0x74, 0x61, 0x74, 0x65, 0x4e,
|
||||
0x61, 0x6d, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x61, 0x6c, 0x6c, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08,
|
||||
0x52, 0x03, 0x61, 0x6c, 0x6c, 0x22, 0x3b, 0x0a, 0x12, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74,
|
||||
0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x25, 0x0a, 0x0e, 0x63,
|
||||
0x6c, 0x65, 0x61, 0x6e, 0x65, 0x64, 0x5f, 0x73, 0x74, 0x61, 0x74, 0x65, 0x73, 0x18, 0x01, 0x20,
|
||||
0x01, 0x28, 0x05, 0x52, 0x0d, 0x63, 0x6c, 0x65, 0x61, 0x6e, 0x65, 0x64, 0x53, 0x74, 0x61, 0x74,
|
||||
0x65, 0x73, 0x22, 0x45, 0x0a, 0x12, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53, 0x74, 0x61, 0x74,
|
||||
0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1d, 0x0a, 0x0a, 0x73, 0x74, 0x61, 0x74,
|
||||
0x65, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x73, 0x74,
|
||||
0x61, 0x74, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x61, 0x6c, 0x6c, 0x18, 0x02,
|
||||
0x20, 0x01, 0x28, 0x08, 0x52, 0x03, 0x61, 0x6c, 0x6c, 0x22, 0x3c, 0x0a, 0x13, 0x44, 0x65, 0x6c,
|
||||
0x65, 0x74, 0x65, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65,
|
||||
0x12, 0x25, 0x0a, 0x0e, 0x64, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x64, 0x5f, 0x73, 0x74, 0x61, 0x74,
|
||||
0x65, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0d, 0x64, 0x65, 0x6c, 0x65, 0x74, 0x65,
|
||||
0x64, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x22, 0x3b, 0x0a, 0x1f, 0x53, 0x65, 0x74, 0x4e, 0x65,
|
||||
0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65,
|
||||
0x6e, 0x63, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x18, 0x0a, 0x07, 0x65, 0x6e,
|
||||
0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x65, 0x6e, 0x61,
|
||||
0x62, 0x6c, 0x65, 0x64, 0x22, 0x22, 0x0a, 0x20, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f,
|
||||
0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65,
|
||||
0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x2a, 0x62, 0x0a, 0x08, 0x4c, 0x6f, 0x67, 0x4c,
|
||||
0x65, 0x76, 0x65, 0x6c, 0x12, 0x0b, 0x0a, 0x07, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10,
|
||||
0x00, 0x12, 0x09, 0x0a, 0x05, 0x50, 0x41, 0x4e, 0x49, 0x43, 0x10, 0x01, 0x12, 0x09, 0x0a, 0x05,
|
||||
0x46, 0x41, 0x54, 0x41, 0x4c, 0x10, 0x02, 0x12, 0x09, 0x0a, 0x05, 0x45, 0x52, 0x52, 0x4f, 0x52,
|
||||
0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x57, 0x41, 0x52, 0x4e, 0x10, 0x04, 0x12, 0x08, 0x0a, 0x04,
|
||||
0x49, 0x4e, 0x46, 0x4f, 0x10, 0x05, 0x12, 0x09, 0x0a, 0x05, 0x44, 0x45, 0x42, 0x55, 0x47, 0x10,
|
||||
0x06, 0x12, 0x09, 0x0a, 0x05, 0x54, 0x52, 0x41, 0x43, 0x45, 0x10, 0x07, 0x32, 0x81, 0x09, 0x0a,
|
||||
0x0d, 0x44, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x36,
|
||||
0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e,
|
||||
0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x15, 0x2e,
|
||||
0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70,
|
||||
0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53,
|
||||
0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e,
|
||||
0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75,
|
||||
0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69,
|
||||
0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73,
|
||||
0x65, 0x22, 0x00, 0x12, 0x2d, 0x0a, 0x02, 0x55, 0x70, 0x12, 0x11, 0x2e, 0x64, 0x61, 0x65, 0x6d,
|
||||
0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x12, 0x2e, 0x64,
|
||||
0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65,
|
||||
0x22, 0x00, 0x12, 0x39, 0x0a, 0x06, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x15, 0x2e, 0x64,
|
||||
0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x71, 0x75,
|
||||
0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61,
|
||||
0x74, 0x75, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a,
|
||||
0x04, 0x44, 0x6f, 0x77, 0x6e, 0x12, 0x13, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44,
|
||||
0x6f, 0x77, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x14, 0x2e, 0x64, 0x61, 0x65,
|
||||
0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65,
|
||||
0x22, 0x00, 0x12, 0x42, 0x0a, 0x09, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12,
|
||||
0x18, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66,
|
||||
0x69, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d,
|
||||
0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x73, 0x70,
|
||||
0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x45, 0x0a, 0x0a, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f,
|
||||
0x75, 0x74, 0x65, 0x73, 0x12, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69,
|
||||
0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a,
|
||||
0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75,
|
||||
0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a,
|
||||
0x0c, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x1b, 0x2e,
|
||||
0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75,
|
||||
0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65,
|
||||
0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73,
|
||||
0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4d, 0x0a, 0x0e, 0x44, 0x65,
|
||||
0x73, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x1b, 0x2e, 0x64,
|
||||
0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74,
|
||||
0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d,
|
||||
0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52,
|
||||
0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x44, 0x65, 0x62,
|
||||
0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f,
|
||||
0x6e, 0x2e, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x71,
|
||||
0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65,
|
||||
0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73,
|
||||
0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76,
|
||||
0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x4c,
|
||||
0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b,
|
||||
0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65,
|
||||
0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a,
|
||||
0x0b, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64,
|
||||
0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65,
|
||||
0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f,
|
||||
0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73,
|
||||
0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x45, 0x0a, 0x0a, 0x4c, 0x69, 0x73, 0x74, 0x53,
|
||||
0x74, 0x61, 0x74, 0x65, 0x73, 0x12, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c,
|
||||
0x69, 0x73, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74,
|
||||
0x1a, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74,
|
||||
0x61, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x45,
|
||||
0x0a, 0x0a, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x19, 0x2e, 0x64,
|
||||
0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65,
|
||||
0x2a, 0x62, 0x0a, 0x08, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x0b, 0x0a, 0x07,
|
||||
0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x09, 0x0a, 0x05, 0x50, 0x41, 0x4e,
|
||||
0x49, 0x43, 0x10, 0x01, 0x12, 0x09, 0x0a, 0x05, 0x46, 0x41, 0x54, 0x41, 0x4c, 0x10, 0x02, 0x12,
|
||||
0x09, 0x0a, 0x05, 0x45, 0x52, 0x52, 0x4f, 0x52, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x57, 0x41,
|
||||
0x52, 0x4e, 0x10, 0x04, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x4e, 0x46, 0x4f, 0x10, 0x05, 0x12, 0x09,
|
||||
0x0a, 0x05, 0x44, 0x45, 0x42, 0x55, 0x47, 0x10, 0x06, 0x12, 0x09, 0x0a, 0x05, 0x54, 0x52, 0x41,
|
||||
0x43, 0x45, 0x10, 0x07, 0x32, 0xb8, 0x06, 0x0a, 0x0d, 0x44, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x53,
|
||||
0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x36, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12,
|
||||
0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65,
|
||||
0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c,
|
||||
0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b,
|
||||
0x0a, 0x0c, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1b,
|
||||
0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c,
|
||||
0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61,
|
||||
0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69,
|
||||
0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x2d, 0x0a, 0x02, 0x55,
|
||||
0x70, 0x12, 0x11, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x71,
|
||||
0x75, 0x65, 0x73, 0x74, 0x1a, 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70,
|
||||
0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x39, 0x0a, 0x06, 0x53, 0x74,
|
||||
0x61, 0x74, 0x75, 0x73, 0x12, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74,
|
||||
0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x64, 0x61,
|
||||
0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f,
|
||||
0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, 0x04, 0x44, 0x6f, 0x77, 0x6e, 0x12, 0x13, 0x2e,
|
||||
0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65,
|
||||
0x73, 0x74, 0x1a, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e,
|
||||
0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x42, 0x0a, 0x09, 0x47, 0x65,
|
||||
0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x18, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e,
|
||||
0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73,
|
||||
0x74, 0x1a, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f,
|
||||
0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x45,
|
||||
0x0a, 0x0a, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x19, 0x2e, 0x64,
|
||||
0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73,
|
||||
0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e,
|
||||
0x2e, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f,
|
||||
0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53,
|
||||
0x74, 0x61, 0x74, 0x65, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65,
|
||||
0x6c, 0x65, 0x74, 0x65, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74,
|
||||
0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65,
|
||||
0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12,
|
||||
0x6f, 0x0a, 0x18, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70,
|
||||
0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x12, 0x27, 0x2e, 0x64, 0x61,
|
||||
0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d,
|
||||
0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x71,
|
||||
0x75, 0x65, 0x73, 0x74, 0x1a, 0x28, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65,
|
||||
0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69,
|
||||
0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00,
|
||||
0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74,
|
||||
0x6f, 0x33,
|
||||
0x2e, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f,
|
||||
0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52,
|
||||
0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53,
|
||||
0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65,
|
||||
0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65,
|
||||
0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65,
|
||||
0x22, 0x00, 0x12, 0x4d, 0x0a, 0x0e, 0x44, 0x65, 0x73, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f,
|
||||
0x75, 0x74, 0x65, 0x73, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65,
|
||||
0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73,
|
||||
0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63,
|
||||
0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22,
|
||||
0x00, 0x12, 0x48, 0x0a, 0x0b, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65,
|
||||
0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42,
|
||||
0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64,
|
||||
0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c,
|
||||
0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x47,
|
||||
0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65,
|
||||
0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52,
|
||||
0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e,
|
||||
0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f,
|
||||
0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c,
|
||||
0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65,
|
||||
0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74,
|
||||
0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67,
|
||||
0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x42,
|
||||
0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f,
|
||||
0x33,
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -2952,59 +2474,50 @@ func file_daemon_proto_rawDescGZIP() []byte {
|
||||
}
|
||||
|
||||
var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 1)
|
||||
var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 41)
|
||||
var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 32)
|
||||
var file_daemon_proto_goTypes = []interface{}{
|
||||
(LogLevel)(0), // 0: daemon.LogLevel
|
||||
(*LoginRequest)(nil), // 1: daemon.LoginRequest
|
||||
(*LoginResponse)(nil), // 2: daemon.LoginResponse
|
||||
(*WaitSSOLoginRequest)(nil), // 3: daemon.WaitSSOLoginRequest
|
||||
(*WaitSSOLoginResponse)(nil), // 4: daemon.WaitSSOLoginResponse
|
||||
(*UpRequest)(nil), // 5: daemon.UpRequest
|
||||
(*UpResponse)(nil), // 6: daemon.UpResponse
|
||||
(*StatusRequest)(nil), // 7: daemon.StatusRequest
|
||||
(*StatusResponse)(nil), // 8: daemon.StatusResponse
|
||||
(*DownRequest)(nil), // 9: daemon.DownRequest
|
||||
(*DownResponse)(nil), // 10: daemon.DownResponse
|
||||
(*GetConfigRequest)(nil), // 11: daemon.GetConfigRequest
|
||||
(*GetConfigResponse)(nil), // 12: daemon.GetConfigResponse
|
||||
(*PeerState)(nil), // 13: daemon.PeerState
|
||||
(*LocalPeerState)(nil), // 14: daemon.LocalPeerState
|
||||
(*SignalState)(nil), // 15: daemon.SignalState
|
||||
(*ManagementState)(nil), // 16: daemon.ManagementState
|
||||
(*RelayState)(nil), // 17: daemon.RelayState
|
||||
(*NSGroupState)(nil), // 18: daemon.NSGroupState
|
||||
(*FullStatus)(nil), // 19: daemon.FullStatus
|
||||
(*ListRoutesRequest)(nil), // 20: daemon.ListRoutesRequest
|
||||
(*ListRoutesResponse)(nil), // 21: daemon.ListRoutesResponse
|
||||
(*SelectRoutesRequest)(nil), // 22: daemon.SelectRoutesRequest
|
||||
(*SelectRoutesResponse)(nil), // 23: daemon.SelectRoutesResponse
|
||||
(*IPList)(nil), // 24: daemon.IPList
|
||||
(*Route)(nil), // 25: daemon.Route
|
||||
(*DebugBundleRequest)(nil), // 26: daemon.DebugBundleRequest
|
||||
(*DebugBundleResponse)(nil), // 27: daemon.DebugBundleResponse
|
||||
(*GetLogLevelRequest)(nil), // 28: daemon.GetLogLevelRequest
|
||||
(*GetLogLevelResponse)(nil), // 29: daemon.GetLogLevelResponse
|
||||
(*SetLogLevelRequest)(nil), // 30: daemon.SetLogLevelRequest
|
||||
(*SetLogLevelResponse)(nil), // 31: daemon.SetLogLevelResponse
|
||||
(*State)(nil), // 32: daemon.State
|
||||
(*ListStatesRequest)(nil), // 33: daemon.ListStatesRequest
|
||||
(*ListStatesResponse)(nil), // 34: daemon.ListStatesResponse
|
||||
(*CleanStateRequest)(nil), // 35: daemon.CleanStateRequest
|
||||
(*CleanStateResponse)(nil), // 36: daemon.CleanStateResponse
|
||||
(*DeleteStateRequest)(nil), // 37: daemon.DeleteStateRequest
|
||||
(*DeleteStateResponse)(nil), // 38: daemon.DeleteStateResponse
|
||||
(*SetNetworkMapPersistenceRequest)(nil), // 39: daemon.SetNetworkMapPersistenceRequest
|
||||
(*SetNetworkMapPersistenceResponse)(nil), // 40: daemon.SetNetworkMapPersistenceResponse
|
||||
nil, // 41: daemon.Route.ResolvedIPsEntry
|
||||
(*durationpb.Duration)(nil), // 42: google.protobuf.Duration
|
||||
(*timestamppb.Timestamp)(nil), // 43: google.protobuf.Timestamp
|
||||
(LogLevel)(0), // 0: daemon.LogLevel
|
||||
(*LoginRequest)(nil), // 1: daemon.LoginRequest
|
||||
(*LoginResponse)(nil), // 2: daemon.LoginResponse
|
||||
(*WaitSSOLoginRequest)(nil), // 3: daemon.WaitSSOLoginRequest
|
||||
(*WaitSSOLoginResponse)(nil), // 4: daemon.WaitSSOLoginResponse
|
||||
(*UpRequest)(nil), // 5: daemon.UpRequest
|
||||
(*UpResponse)(nil), // 6: daemon.UpResponse
|
||||
(*StatusRequest)(nil), // 7: daemon.StatusRequest
|
||||
(*StatusResponse)(nil), // 8: daemon.StatusResponse
|
||||
(*DownRequest)(nil), // 9: daemon.DownRequest
|
||||
(*DownResponse)(nil), // 10: daemon.DownResponse
|
||||
(*GetConfigRequest)(nil), // 11: daemon.GetConfigRequest
|
||||
(*GetConfigResponse)(nil), // 12: daemon.GetConfigResponse
|
||||
(*PeerState)(nil), // 13: daemon.PeerState
|
||||
(*LocalPeerState)(nil), // 14: daemon.LocalPeerState
|
||||
(*SignalState)(nil), // 15: daemon.SignalState
|
||||
(*ManagementState)(nil), // 16: daemon.ManagementState
|
||||
(*RelayState)(nil), // 17: daemon.RelayState
|
||||
(*NSGroupState)(nil), // 18: daemon.NSGroupState
|
||||
(*FullStatus)(nil), // 19: daemon.FullStatus
|
||||
(*ListRoutesRequest)(nil), // 20: daemon.ListRoutesRequest
|
||||
(*ListRoutesResponse)(nil), // 21: daemon.ListRoutesResponse
|
||||
(*SelectRoutesRequest)(nil), // 22: daemon.SelectRoutesRequest
|
||||
(*SelectRoutesResponse)(nil), // 23: daemon.SelectRoutesResponse
|
||||
(*IPList)(nil), // 24: daemon.IPList
|
||||
(*Route)(nil), // 25: daemon.Route
|
||||
(*DebugBundleRequest)(nil), // 26: daemon.DebugBundleRequest
|
||||
(*DebugBundleResponse)(nil), // 27: daemon.DebugBundleResponse
|
||||
(*GetLogLevelRequest)(nil), // 28: daemon.GetLogLevelRequest
|
||||
(*GetLogLevelResponse)(nil), // 29: daemon.GetLogLevelResponse
|
||||
(*SetLogLevelRequest)(nil), // 30: daemon.SetLogLevelRequest
|
||||
(*SetLogLevelResponse)(nil), // 31: daemon.SetLogLevelResponse
|
||||
nil, // 32: daemon.Route.ResolvedIPsEntry
|
||||
(*durationpb.Duration)(nil), // 33: google.protobuf.Duration
|
||||
(*timestamppb.Timestamp)(nil), // 34: google.protobuf.Timestamp
|
||||
}
|
||||
var file_daemon_proto_depIdxs = []int32{
|
||||
42, // 0: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
|
||||
33, // 0: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
|
||||
19, // 1: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus
|
||||
43, // 2: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp
|
||||
43, // 3: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp
|
||||
42, // 4: daemon.PeerState.latency:type_name -> google.protobuf.Duration
|
||||
34, // 2: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp
|
||||
34, // 3: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp
|
||||
33, // 4: daemon.PeerState.latency:type_name -> google.protobuf.Duration
|
||||
16, // 5: daemon.FullStatus.managementState:type_name -> daemon.ManagementState
|
||||
15, // 6: daemon.FullStatus.signalState:type_name -> daemon.SignalState
|
||||
14, // 7: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState
|
||||
@@ -3012,48 +2525,39 @@ var file_daemon_proto_depIdxs = []int32{
|
||||
17, // 9: daemon.FullStatus.relays:type_name -> daemon.RelayState
|
||||
18, // 10: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState
|
||||
25, // 11: daemon.ListRoutesResponse.routes:type_name -> daemon.Route
|
||||
41, // 12: daemon.Route.resolvedIPs:type_name -> daemon.Route.ResolvedIPsEntry
|
||||
32, // 12: daemon.Route.resolvedIPs:type_name -> daemon.Route.ResolvedIPsEntry
|
||||
0, // 13: daemon.GetLogLevelResponse.level:type_name -> daemon.LogLevel
|
||||
0, // 14: daemon.SetLogLevelRequest.level:type_name -> daemon.LogLevel
|
||||
32, // 15: daemon.ListStatesResponse.states:type_name -> daemon.State
|
||||
24, // 16: daemon.Route.ResolvedIPsEntry.value:type_name -> daemon.IPList
|
||||
1, // 17: daemon.DaemonService.Login:input_type -> daemon.LoginRequest
|
||||
3, // 18: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest
|
||||
5, // 19: daemon.DaemonService.Up:input_type -> daemon.UpRequest
|
||||
7, // 20: daemon.DaemonService.Status:input_type -> daemon.StatusRequest
|
||||
9, // 21: daemon.DaemonService.Down:input_type -> daemon.DownRequest
|
||||
11, // 22: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest
|
||||
20, // 23: daemon.DaemonService.ListRoutes:input_type -> daemon.ListRoutesRequest
|
||||
22, // 24: daemon.DaemonService.SelectRoutes:input_type -> daemon.SelectRoutesRequest
|
||||
22, // 25: daemon.DaemonService.DeselectRoutes:input_type -> daemon.SelectRoutesRequest
|
||||
26, // 26: daemon.DaemonService.DebugBundle:input_type -> daemon.DebugBundleRequest
|
||||
28, // 27: daemon.DaemonService.GetLogLevel:input_type -> daemon.GetLogLevelRequest
|
||||
30, // 28: daemon.DaemonService.SetLogLevel:input_type -> daemon.SetLogLevelRequest
|
||||
33, // 29: daemon.DaemonService.ListStates:input_type -> daemon.ListStatesRequest
|
||||
35, // 30: daemon.DaemonService.CleanState:input_type -> daemon.CleanStateRequest
|
||||
37, // 31: daemon.DaemonService.DeleteState:input_type -> daemon.DeleteStateRequest
|
||||
39, // 32: daemon.DaemonService.SetNetworkMapPersistence:input_type -> daemon.SetNetworkMapPersistenceRequest
|
||||
2, // 33: daemon.DaemonService.Login:output_type -> daemon.LoginResponse
|
||||
4, // 34: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse
|
||||
6, // 35: daemon.DaemonService.Up:output_type -> daemon.UpResponse
|
||||
8, // 36: daemon.DaemonService.Status:output_type -> daemon.StatusResponse
|
||||
10, // 37: daemon.DaemonService.Down:output_type -> daemon.DownResponse
|
||||
12, // 38: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse
|
||||
21, // 39: daemon.DaemonService.ListRoutes:output_type -> daemon.ListRoutesResponse
|
||||
23, // 40: daemon.DaemonService.SelectRoutes:output_type -> daemon.SelectRoutesResponse
|
||||
23, // 41: daemon.DaemonService.DeselectRoutes:output_type -> daemon.SelectRoutesResponse
|
||||
27, // 42: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse
|
||||
29, // 43: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse
|
||||
31, // 44: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse
|
||||
34, // 45: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse
|
||||
36, // 46: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse
|
||||
38, // 47: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse
|
||||
40, // 48: daemon.DaemonService.SetNetworkMapPersistence:output_type -> daemon.SetNetworkMapPersistenceResponse
|
||||
33, // [33:49] is the sub-list for method output_type
|
||||
17, // [17:33] is the sub-list for method input_type
|
||||
17, // [17:17] is the sub-list for extension type_name
|
||||
17, // [17:17] is the sub-list for extension extendee
|
||||
0, // [0:17] is the sub-list for field type_name
|
||||
24, // 15: daemon.Route.ResolvedIPsEntry.value:type_name -> daemon.IPList
|
||||
1, // 16: daemon.DaemonService.Login:input_type -> daemon.LoginRequest
|
||||
3, // 17: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest
|
||||
5, // 18: daemon.DaemonService.Up:input_type -> daemon.UpRequest
|
||||
7, // 19: daemon.DaemonService.Status:input_type -> daemon.StatusRequest
|
||||
9, // 20: daemon.DaemonService.Down:input_type -> daemon.DownRequest
|
||||
11, // 21: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest
|
||||
20, // 22: daemon.DaemonService.ListRoutes:input_type -> daemon.ListRoutesRequest
|
||||
22, // 23: daemon.DaemonService.SelectRoutes:input_type -> daemon.SelectRoutesRequest
|
||||
22, // 24: daemon.DaemonService.DeselectRoutes:input_type -> daemon.SelectRoutesRequest
|
||||
26, // 25: daemon.DaemonService.DebugBundle:input_type -> daemon.DebugBundleRequest
|
||||
28, // 26: daemon.DaemonService.GetLogLevel:input_type -> daemon.GetLogLevelRequest
|
||||
30, // 27: daemon.DaemonService.SetLogLevel:input_type -> daemon.SetLogLevelRequest
|
||||
2, // 28: daemon.DaemonService.Login:output_type -> daemon.LoginResponse
|
||||
4, // 29: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse
|
||||
6, // 30: daemon.DaemonService.Up:output_type -> daemon.UpResponse
|
||||
8, // 31: daemon.DaemonService.Status:output_type -> daemon.StatusResponse
|
||||
10, // 32: daemon.DaemonService.Down:output_type -> daemon.DownResponse
|
||||
12, // 33: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse
|
||||
21, // 34: daemon.DaemonService.ListRoutes:output_type -> daemon.ListRoutesResponse
|
||||
23, // 35: daemon.DaemonService.SelectRoutes:output_type -> daemon.SelectRoutesResponse
|
||||
23, // 36: daemon.DaemonService.DeselectRoutes:output_type -> daemon.SelectRoutesResponse
|
||||
27, // 37: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse
|
||||
29, // 38: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse
|
||||
31, // 39: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse
|
||||
28, // [28:40] is the sub-list for method output_type
|
||||
16, // [16:28] is the sub-list for method input_type
|
||||
16, // [16:16] is the sub-list for extension type_name
|
||||
16, // [16:16] is the sub-list for extension extendee
|
||||
0, // [0:16] is the sub-list for field type_name
|
||||
}
|
||||
|
||||
func init() { file_daemon_proto_init() }
|
||||
@@ -3434,114 +2938,6 @@ func file_daemon_proto_init() {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_daemon_proto_msgTypes[31].Exporter = func(v interface{}, i int) interface{} {
|
||||
switch v := v.(*State); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_daemon_proto_msgTypes[32].Exporter = func(v interface{}, i int) interface{} {
|
||||
switch v := v.(*ListStatesRequest); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_daemon_proto_msgTypes[33].Exporter = func(v interface{}, i int) interface{} {
|
||||
switch v := v.(*ListStatesResponse); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_daemon_proto_msgTypes[34].Exporter = func(v interface{}, i int) interface{} {
|
||||
switch v := v.(*CleanStateRequest); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_daemon_proto_msgTypes[35].Exporter = func(v interface{}, i int) interface{} {
|
||||
switch v := v.(*CleanStateResponse); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_daemon_proto_msgTypes[36].Exporter = func(v interface{}, i int) interface{} {
|
||||
switch v := v.(*DeleteStateRequest); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_daemon_proto_msgTypes[37].Exporter = func(v interface{}, i int) interface{} {
|
||||
switch v := v.(*DeleteStateResponse); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_daemon_proto_msgTypes[38].Exporter = func(v interface{}, i int) interface{} {
|
||||
switch v := v.(*SetNetworkMapPersistenceRequest); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_daemon_proto_msgTypes[39].Exporter = func(v interface{}, i int) interface{} {
|
||||
switch v := v.(*SetNetworkMapPersistenceResponse); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
file_daemon_proto_msgTypes[0].OneofWrappers = []interface{}{}
|
||||
type x struct{}
|
||||
@@ -3550,7 +2946,7 @@ func file_daemon_proto_init() {
|
||||
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
||||
RawDescriptor: file_daemon_proto_rawDesc,
|
||||
NumEnums: 1,
|
||||
NumMessages: 41,
|
||||
NumMessages: 32,
|
||||
NumExtensions: 0,
|
||||
NumServices: 1,
|
||||
},
|
||||
|
||||
@@ -45,20 +45,7 @@ service DaemonService {
|
||||
|
||||
// SetLogLevel sets the log level of the daemon
|
||||
rpc SetLogLevel(SetLogLevelRequest) returns (SetLogLevelResponse) {}
|
||||
|
||||
// List all states
|
||||
rpc ListStates(ListStatesRequest) returns (ListStatesResponse) {}
|
||||
|
||||
// Clean specific state or all states
|
||||
rpc CleanState(CleanStateRequest) returns (CleanStateResponse) {}
|
||||
|
||||
// Delete specific state or all states
|
||||
rpc DeleteState(DeleteStateRequest) returns (DeleteStateResponse) {}
|
||||
|
||||
// SetNetworkMapPersistence enables or disables network map persistence
|
||||
rpc SetNetworkMapPersistence(SetNetworkMapPersistenceRequest) returns (SetNetworkMapPersistenceResponse) {}
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
message LoginRequest {
|
||||
// setupKey wiretrustee setup key.
|
||||
@@ -306,46 +293,4 @@ message SetLogLevelRequest {
|
||||
}
|
||||
|
||||
message SetLogLevelResponse {
|
||||
}
|
||||
|
||||
// State represents a daemon state entry
|
||||
message State {
|
||||
string name = 1;
|
||||
}
|
||||
|
||||
// ListStatesRequest is empty as it requires no parameters
|
||||
message ListStatesRequest {}
|
||||
|
||||
// ListStatesResponse contains a list of states
|
||||
message ListStatesResponse {
|
||||
repeated State states = 1;
|
||||
}
|
||||
|
||||
// CleanStateRequest for cleaning states
|
||||
message CleanStateRequest {
|
||||
string state_name = 1;
|
||||
bool all = 2;
|
||||
}
|
||||
|
||||
// CleanStateResponse contains the result of the clean operation
|
||||
message CleanStateResponse {
|
||||
int32 cleaned_states = 1;
|
||||
}
|
||||
|
||||
// DeleteStateRequest for deleting states
|
||||
message DeleteStateRequest {
|
||||
string state_name = 1;
|
||||
bool all = 2;
|
||||
}
|
||||
|
||||
// DeleteStateResponse contains the result of the delete operation
|
||||
message DeleteStateResponse {
|
||||
int32 deleted_states = 1;
|
||||
}
|
||||
|
||||
|
||||
message SetNetworkMapPersistenceRequest {
|
||||
bool enabled = 1;
|
||||
}
|
||||
|
||||
message SetNetworkMapPersistenceResponse {}
|
||||
}
|
||||
@@ -43,14 +43,6 @@ type DaemonServiceClient interface {
|
||||
GetLogLevel(ctx context.Context, in *GetLogLevelRequest, opts ...grpc.CallOption) (*GetLogLevelResponse, error)
|
||||
// SetLogLevel sets the log level of the daemon
|
||||
SetLogLevel(ctx context.Context, in *SetLogLevelRequest, opts ...grpc.CallOption) (*SetLogLevelResponse, error)
|
||||
// List all states
|
||||
ListStates(ctx context.Context, in *ListStatesRequest, opts ...grpc.CallOption) (*ListStatesResponse, error)
|
||||
// Clean specific state or all states
|
||||
CleanState(ctx context.Context, in *CleanStateRequest, opts ...grpc.CallOption) (*CleanStateResponse, error)
|
||||
// Delete specific state or all states
|
||||
DeleteState(ctx context.Context, in *DeleteStateRequest, opts ...grpc.CallOption) (*DeleteStateResponse, error)
|
||||
// SetNetworkMapPersistence enables or disables network map persistence
|
||||
SetNetworkMapPersistence(ctx context.Context, in *SetNetworkMapPersistenceRequest, opts ...grpc.CallOption) (*SetNetworkMapPersistenceResponse, error)
|
||||
}
|
||||
|
||||
type daemonServiceClient struct {
|
||||
@@ -169,42 +161,6 @@ func (c *daemonServiceClient) SetLogLevel(ctx context.Context, in *SetLogLevelRe
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *daemonServiceClient) ListStates(ctx context.Context, in *ListStatesRequest, opts ...grpc.CallOption) (*ListStatesResponse, error) {
|
||||
out := new(ListStatesResponse)
|
||||
err := c.cc.Invoke(ctx, "/daemon.DaemonService/ListStates", in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *daemonServiceClient) CleanState(ctx context.Context, in *CleanStateRequest, opts ...grpc.CallOption) (*CleanStateResponse, error) {
|
||||
out := new(CleanStateResponse)
|
||||
err := c.cc.Invoke(ctx, "/daemon.DaemonService/CleanState", in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *daemonServiceClient) DeleteState(ctx context.Context, in *DeleteStateRequest, opts ...grpc.CallOption) (*DeleteStateResponse, error) {
|
||||
out := new(DeleteStateResponse)
|
||||
err := c.cc.Invoke(ctx, "/daemon.DaemonService/DeleteState", in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *daemonServiceClient) SetNetworkMapPersistence(ctx context.Context, in *SetNetworkMapPersistenceRequest, opts ...grpc.CallOption) (*SetNetworkMapPersistenceResponse, error) {
|
||||
out := new(SetNetworkMapPersistenceResponse)
|
||||
err := c.cc.Invoke(ctx, "/daemon.DaemonService/SetNetworkMapPersistence", in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// DaemonServiceServer is the server API for DaemonService service.
|
||||
// All implementations must embed UnimplementedDaemonServiceServer
|
||||
// for forward compatibility
|
||||
@@ -234,14 +190,6 @@ type DaemonServiceServer interface {
|
||||
GetLogLevel(context.Context, *GetLogLevelRequest) (*GetLogLevelResponse, error)
|
||||
// SetLogLevel sets the log level of the daemon
|
||||
SetLogLevel(context.Context, *SetLogLevelRequest) (*SetLogLevelResponse, error)
|
||||
// List all states
|
||||
ListStates(context.Context, *ListStatesRequest) (*ListStatesResponse, error)
|
||||
// Clean specific state or all states
|
||||
CleanState(context.Context, *CleanStateRequest) (*CleanStateResponse, error)
|
||||
// Delete specific state or all states
|
||||
DeleteState(context.Context, *DeleteStateRequest) (*DeleteStateResponse, error)
|
||||
// SetNetworkMapPersistence enables or disables network map persistence
|
||||
SetNetworkMapPersistence(context.Context, *SetNetworkMapPersistenceRequest) (*SetNetworkMapPersistenceResponse, error)
|
||||
mustEmbedUnimplementedDaemonServiceServer()
|
||||
}
|
||||
|
||||
@@ -285,18 +233,6 @@ func (UnimplementedDaemonServiceServer) GetLogLevel(context.Context, *GetLogLeve
|
||||
func (UnimplementedDaemonServiceServer) SetLogLevel(context.Context, *SetLogLevelRequest) (*SetLogLevelResponse, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method SetLogLevel not implemented")
|
||||
}
|
||||
func (UnimplementedDaemonServiceServer) ListStates(context.Context, *ListStatesRequest) (*ListStatesResponse, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method ListStates not implemented")
|
||||
}
|
||||
func (UnimplementedDaemonServiceServer) CleanState(context.Context, *CleanStateRequest) (*CleanStateResponse, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method CleanState not implemented")
|
||||
}
|
||||
func (UnimplementedDaemonServiceServer) DeleteState(context.Context, *DeleteStateRequest) (*DeleteStateResponse, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method DeleteState not implemented")
|
||||
}
|
||||
func (UnimplementedDaemonServiceServer) SetNetworkMapPersistence(context.Context, *SetNetworkMapPersistenceRequest) (*SetNetworkMapPersistenceResponse, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method SetNetworkMapPersistence not implemented")
|
||||
}
|
||||
func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {}
|
||||
|
||||
// UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service.
|
||||
@@ -526,78 +462,6 @@ func _DaemonService_SetLogLevel_Handler(srv interface{}, ctx context.Context, de
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _DaemonService_ListStates_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(ListStatesRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(DaemonServiceServer).ListStates(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/daemon.DaemonService/ListStates",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(DaemonServiceServer).ListStates(ctx, req.(*ListStatesRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _DaemonService_CleanState_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(CleanStateRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(DaemonServiceServer).CleanState(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/daemon.DaemonService/CleanState",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(DaemonServiceServer).CleanState(ctx, req.(*CleanStateRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _DaemonService_DeleteState_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(DeleteStateRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(DaemonServiceServer).DeleteState(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/daemon.DaemonService/DeleteState",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(DaemonServiceServer).DeleteState(ctx, req.(*DeleteStateRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _DaemonService_SetNetworkMapPersistence_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(SetNetworkMapPersistenceRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(DaemonServiceServer).SetNetworkMapPersistence(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/daemon.DaemonService/SetNetworkMapPersistence",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(DaemonServiceServer).SetNetworkMapPersistence(ctx, req.(*SetNetworkMapPersistenceRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
// DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service.
|
||||
// It's only intended for direct use with grpc.RegisterService,
|
||||
// and not to be introspected or modified (even as a copy)
|
||||
@@ -653,22 +517,6 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{
|
||||
MethodName: "SetLogLevel",
|
||||
Handler: _DaemonService_SetLogLevel_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "ListStates",
|
||||
Handler: _DaemonService_ListStates_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "CleanState",
|
||||
Handler: _DaemonService_CleanState_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "DeleteState",
|
||||
Handler: _DaemonService_DeleteState_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "SetNetworkMapPersistence",
|
||||
Handler: _DaemonService_SetNetworkMapPersistence_Handler,
|
||||
},
|
||||
},
|
||||
Streams: []grpc.StreamDesc{},
|
||||
Metadata: "daemon.proto",
|
||||
|
||||
@@ -5,44 +5,32 @@ package server
|
||||
import (
|
||||
"archive/zip"
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
|
||||
"github.com/netbirdio/netbird/client/anonymize"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||
)
|
||||
|
||||
const readmeContent = `Netbird debug bundle
|
||||
This debug bundle contains the following files:
|
||||
|
||||
status.txt: Anonymized status information of the NetBird client.
|
||||
client.log: Most recent, anonymized client log file of the NetBird client.
|
||||
netbird.err: Most recent, anonymized stderr log file of the NetBird client.
|
||||
netbird.out: Most recent, anonymized stdout log file of the NetBird client.
|
||||
client.log: Most recent, anonymized log file of the NetBird client.
|
||||
routes.txt: Anonymized system routes, if --system-info flag was provided.
|
||||
interfaces.txt: Anonymized network interface information, if --system-info flag was provided.
|
||||
config.txt: Anonymized configuration information of the NetBird client.
|
||||
network_map.json: Anonymized network map containing peer configurations, routes, DNS settings, and firewall rules.
|
||||
state.json: Anonymized client state dump containing netbird states.
|
||||
|
||||
|
||||
Anonymization Process
|
||||
@@ -62,32 +50,8 @@ Domains
|
||||
All domain names (except for the netbird domains) are replaced with randomly generated strings ending in ".domain". Anonymized domains are consistent across all files in the bundle.
|
||||
Reoccuring domain names are replaced with the same anonymized domain.
|
||||
|
||||
Network Map
|
||||
The network_map.json file contains the following anonymized information:
|
||||
- Peer configurations (addresses, FQDNs, DNS settings)
|
||||
- Remote and offline peer information (allowed IPs, FQDNs)
|
||||
- Routes (network ranges, associated domains)
|
||||
- DNS configuration (nameservers, domains, custom zones)
|
||||
- Firewall rules (peer IPs, source/destination ranges)
|
||||
|
||||
SSH keys in the network map are replaced with a placeholder value. All IP addresses and domains in the network map follow the same anonymization rules as described above.
|
||||
|
||||
State File
|
||||
The state.json file contains anonymized internal state information of the NetBird client, including:
|
||||
- DNS settings and configuration
|
||||
- Firewall rules
|
||||
- Exclusion routes
|
||||
- Route selection
|
||||
- Other internal states that may be present
|
||||
|
||||
The state file follows the same anonymization rules as other files:
|
||||
- IP addresses (both individual and CIDR ranges) are anonymized while preserving their structure
|
||||
- Domain names are consistently anonymized
|
||||
- Technical identifiers and non-sensitive data remain unchanged
|
||||
|
||||
Routes
|
||||
For anonymized routes, the IP addresses are replaced as described above. The prefix length remains unchanged. Note that for prefixes, the anonymized IP might not be a network address, but the prefix length is still correct.
|
||||
|
||||
Network Interfaces
|
||||
The interfaces.txt file contains information about network interfaces, including:
|
||||
- Interface name
|
||||
@@ -108,12 +72,6 @@ The config.txt file contains anonymized configuration information of the NetBird
|
||||
Other non-sensitive configuration options are included without anonymization.
|
||||
`
|
||||
|
||||
const (
|
||||
clientLogFile = "client.log"
|
||||
errorLogFile = "netbird.err"
|
||||
stdoutLogFile = "netbird.out"
|
||||
)
|
||||
|
||||
// DebugBundle creates a debug bundle and returns the location.
|
||||
func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (resp *proto.DebugBundleResponse, err error) {
|
||||
s.mutex.Lock()
|
||||
@@ -161,27 +119,19 @@ func (s *Server) createArchive(bundlePath *os.File, req *proto.DebugBundleReques
|
||||
seedFromStatus(anonymizer, &status)
|
||||
|
||||
if err := s.addConfig(req, anonymizer, archive); err != nil {
|
||||
log.Errorf("Failed to add config to debug bundle: %v", err)
|
||||
return fmt.Errorf("add config: %w", err)
|
||||
}
|
||||
|
||||
if req.GetSystemInfo() {
|
||||
if err := s.addRoutes(req, anonymizer, archive); err != nil {
|
||||
log.Errorf("Failed to add routes to debug bundle: %v", err)
|
||||
return fmt.Errorf("add routes: %w", err)
|
||||
}
|
||||
|
||||
if err := s.addInterfaces(req, anonymizer, archive); err != nil {
|
||||
log.Errorf("Failed to add interfaces to debug bundle: %v", err)
|
||||
return fmt.Errorf("add interfaces: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.addNetworkMap(req, anonymizer, archive); err != nil {
|
||||
return fmt.Errorf("add network map: %w", err)
|
||||
}
|
||||
|
||||
if err := s.addStateFile(req, anonymizer, archive); err != nil {
|
||||
log.Errorf("Failed to add state file to debug bundle: %v", err)
|
||||
}
|
||||
|
||||
if err := s.addLogfile(req, anonymizer, archive); err != nil {
|
||||
return fmt.Errorf("add log file: %w", err)
|
||||
}
|
||||
@@ -270,16 +220,15 @@ func (s *Server) addCommonConfigFields(configContent *strings.Builder) {
|
||||
}
|
||||
|
||||
func (s *Server) addRoutes(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error {
|
||||
routes, err := systemops.GetRoutesFromTable()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get routes: %w", err)
|
||||
}
|
||||
|
||||
// TODO: get routes including nexthop
|
||||
routesContent := formatRoutes(routes, req.GetAnonymize(), anonymizer)
|
||||
routesReader := strings.NewReader(routesContent)
|
||||
if err := addFileToZip(archive, routesReader, "routes.txt"); err != nil {
|
||||
return fmt.Errorf("add routes file to zip: %w", err)
|
||||
if routes, err := systemops.GetRoutesFromTable(); err != nil {
|
||||
log.Errorf("Failed to get routes: %v", err)
|
||||
} else {
|
||||
// TODO: get routes including nexthop
|
||||
routesContent := formatRoutes(routes, req.GetAnonymize(), anonymizer)
|
||||
routesReader := strings.NewReader(routesContent)
|
||||
if err := addFileToZip(archive, routesReader, "routes.txt"); err != nil {
|
||||
return fmt.Errorf("add routes file to zip: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -299,106 +248,14 @@ func (s *Server) addInterfaces(req *proto.DebugBundleRequest, anonymizer *anonym
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) addNetworkMap(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error {
|
||||
networkMap, err := s.getLatestNetworkMap()
|
||||
func (s *Server) addLogfile(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) (err error) {
|
||||
logFile, err := os.Open(s.logFile)
|
||||
if err != nil {
|
||||
// Skip if network map is not available, but log it
|
||||
log.Debugf("skipping empty network map in debug bundle: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
if req.GetAnonymize() {
|
||||
if err := anonymizeNetworkMap(networkMap, anonymizer); err != nil {
|
||||
return fmt.Errorf("anonymize network map: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
options := protojson.MarshalOptions{
|
||||
EmitUnpopulated: true,
|
||||
UseProtoNames: true,
|
||||
Indent: " ",
|
||||
AllowPartial: true,
|
||||
}
|
||||
|
||||
jsonBytes, err := options.Marshal(networkMap)
|
||||
if err != nil {
|
||||
return fmt.Errorf("generate json: %w", err)
|
||||
}
|
||||
|
||||
if err := addFileToZip(archive, bytes.NewReader(jsonBytes), "network_map.json"); err != nil {
|
||||
return fmt.Errorf("add network map to zip: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) addStateFile(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error {
|
||||
path := statemanager.GetDefaultStatePath()
|
||||
if path == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("read state file: %w", err)
|
||||
}
|
||||
|
||||
if req.GetAnonymize() {
|
||||
var rawStates map[string]json.RawMessage
|
||||
if err := json.Unmarshal(data, &rawStates); err != nil {
|
||||
return fmt.Errorf("unmarshal states: %w", err)
|
||||
}
|
||||
|
||||
if err := anonymizeStateFile(&rawStates, anonymizer); err != nil {
|
||||
return fmt.Errorf("anonymize state file: %w", err)
|
||||
}
|
||||
|
||||
bs, err := json.MarshalIndent(rawStates, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal states: %w", err)
|
||||
}
|
||||
data = bs
|
||||
}
|
||||
|
||||
if err := addFileToZip(archive, bytes.NewReader(data), "state.json"); err != nil {
|
||||
return fmt.Errorf("add state file to zip: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) addLogfile(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error {
|
||||
logDir := filepath.Dir(s.logFile)
|
||||
|
||||
if err := s.addSingleLogfile(s.logFile, clientLogFile, req, anonymizer, archive); err != nil {
|
||||
return fmt.Errorf("add client log file to zip: %w", err)
|
||||
}
|
||||
|
||||
errLogPath := filepath.Join(logDir, errorLogFile)
|
||||
if err := s.addSingleLogfile(errLogPath, errorLogFile, req, anonymizer, archive); err != nil {
|
||||
log.Warnf("Failed to add %s to zip: %v", errorLogFile, err)
|
||||
}
|
||||
|
||||
stdoutLogPath := filepath.Join(logDir, stdoutLogFile)
|
||||
if err := s.addSingleLogfile(stdoutLogPath, stdoutLogFile, req, anonymizer, archive); err != nil {
|
||||
log.Warnf("Failed to add %s to zip: %v", stdoutLogFile, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// addSingleLogfile adds a single log file to the archive
|
||||
func (s *Server) addSingleLogfile(logPath, targetName string, req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error {
|
||||
logFile, err := os.Open(logPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open log file %s: %w", targetName, err)
|
||||
return fmt.Errorf("open log file: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := logFile.Close(); err != nil {
|
||||
log.Errorf("Failed to close log file %s: %v", targetName, err)
|
||||
log.Errorf("Failed to close original log file: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -407,55 +264,45 @@ func (s *Server) addSingleLogfile(logPath, targetName string, req *proto.DebugBu
|
||||
var writer *io.PipeWriter
|
||||
logReader, writer = io.Pipe()
|
||||
|
||||
go anonymizeLog(logFile, writer, anonymizer)
|
||||
go s.anonymize(logFile, writer, anonymizer)
|
||||
} else {
|
||||
logReader = logFile
|
||||
}
|
||||
|
||||
if err := addFileToZip(archive, logReader, targetName); err != nil {
|
||||
return fmt.Errorf("add %s to zip: %w", targetName, err)
|
||||
if err := addFileToZip(archive, logReader, "client.log"); err != nil {
|
||||
return fmt.Errorf("add log file to zip: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getLatestNetworkMap returns the latest network map from the engine if network map persistence is enabled
|
||||
func (s *Server) getLatestNetworkMap() (*mgmProto.NetworkMap, error) {
|
||||
if s.connectClient == nil {
|
||||
return nil, errors.New("connect client is not initialized")
|
||||
}
|
||||
func (s *Server) anonymize(reader io.Reader, writer *io.PipeWriter, anonymizer *anonymize.Anonymizer) {
|
||||
defer func() {
|
||||
// always nil
|
||||
_ = writer.Close()
|
||||
}()
|
||||
|
||||
engine := s.connectClient.Engine()
|
||||
if engine == nil {
|
||||
return nil, errors.New("engine is not initialized")
|
||||
scanner := bufio.NewScanner(reader)
|
||||
for scanner.Scan() {
|
||||
line := anonymizer.AnonymizeString(scanner.Text())
|
||||
if _, err := writer.Write([]byte(line + "\n")); err != nil {
|
||||
writer.CloseWithError(fmt.Errorf("anonymize write: %w", err))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
networkMap, err := engine.GetLatestNetworkMap()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get latest network map: %w", err)
|
||||
if err := scanner.Err(); err != nil {
|
||||
writer.CloseWithError(fmt.Errorf("anonymize scan: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
if networkMap == nil {
|
||||
return nil, errors.New("network map is not available")
|
||||
}
|
||||
|
||||
return networkMap, nil
|
||||
}
|
||||
|
||||
// GetLogLevel gets the current logging level for the server.
|
||||
func (s *Server) GetLogLevel(_ context.Context, _ *proto.GetLogLevelRequest) (*proto.GetLogLevelResponse, error) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
level := ParseLogLevel(log.GetLevel().String())
|
||||
return &proto.GetLogLevelResponse{Level: level}, nil
|
||||
}
|
||||
|
||||
// SetLogLevel sets the logging level for the server.
|
||||
func (s *Server) SetLogLevel(_ context.Context, req *proto.SetLogLevelRequest) (*proto.SetLogLevelResponse, error) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
level, err := log.ParseLevel(req.Level.String())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid log level: %w", err)
|
||||
@@ -466,20 +313,6 @@ func (s *Server) SetLogLevel(_ context.Context, req *proto.SetLogLevelRequest) (
|
||||
return &proto.SetLogLevelResponse{}, nil
|
||||
}
|
||||
|
||||
// SetNetworkMapPersistence sets the network map persistence for the server.
|
||||
func (s *Server) SetNetworkMapPersistence(_ context.Context, req *proto.SetNetworkMapPersistenceRequest) (*proto.SetNetworkMapPersistenceResponse, error) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
enabled := req.GetEnabled()
|
||||
s.persistNetworkMap = enabled
|
||||
if s.connectClient != nil {
|
||||
s.connectClient.SetNetworkMapPersistence(enabled)
|
||||
}
|
||||
|
||||
return &proto.SetNetworkMapPersistenceResponse{}, nil
|
||||
}
|
||||
|
||||
func addFileToZip(archive *zip.Writer, reader io.Reader, filename string) error {
|
||||
header := &zip.FileHeader{
|
||||
Name: filename,
|
||||
@@ -625,26 +458,6 @@ func formatInterfaces(interfaces []net.Interface, anonymize bool, anonymizer *an
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
func anonymizeLog(reader io.Reader, writer *io.PipeWriter, anonymizer *anonymize.Anonymizer) {
|
||||
defer func() {
|
||||
// always nil
|
||||
_ = writer.Close()
|
||||
}()
|
||||
|
||||
scanner := bufio.NewScanner(reader)
|
||||
for scanner.Scan() {
|
||||
line := anonymizer.AnonymizeString(scanner.Text())
|
||||
if _, err := writer.Write([]byte(line + "\n")); err != nil {
|
||||
writer.CloseWithError(fmt.Errorf("anonymize write: %w", err))
|
||||
return
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
writer.CloseWithError(fmt.Errorf("anonymize scan: %w", err))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func anonymizeNATExternalIPs(ips []string, anonymizer *anonymize.Anonymizer) []string {
|
||||
anonymizedIPs := make([]string, len(ips))
|
||||
for i, ip := range ips {
|
||||
@@ -671,248 +484,3 @@ func anonymizeNATExternalIPs(ips []string, anonymizer *anonymize.Anonymizer) []s
|
||||
}
|
||||
return anonymizedIPs
|
||||
}
|
||||
|
||||
func anonymizeNetworkMap(networkMap *mgmProto.NetworkMap, anonymizer *anonymize.Anonymizer) error {
|
||||
if networkMap.PeerConfig != nil {
|
||||
anonymizePeerConfig(networkMap.PeerConfig, anonymizer)
|
||||
}
|
||||
|
||||
for _, peer := range networkMap.RemotePeers {
|
||||
anonymizeRemotePeer(peer, anonymizer)
|
||||
}
|
||||
|
||||
for _, peer := range networkMap.OfflinePeers {
|
||||
anonymizeRemotePeer(peer, anonymizer)
|
||||
}
|
||||
|
||||
for _, r := range networkMap.Routes {
|
||||
anonymizeRoute(r, anonymizer)
|
||||
}
|
||||
|
||||
if networkMap.DNSConfig != nil {
|
||||
anonymizeDNSConfig(networkMap.DNSConfig, anonymizer)
|
||||
}
|
||||
|
||||
for _, rule := range networkMap.FirewallRules {
|
||||
anonymizeFirewallRule(rule, anonymizer)
|
||||
}
|
||||
|
||||
for _, rule := range networkMap.RoutesFirewallRules {
|
||||
anonymizeRouteFirewallRule(rule, anonymizer)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func anonymizePeerConfig(config *mgmProto.PeerConfig, anonymizer *anonymize.Anonymizer) {
|
||||
if config == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if addr, err := netip.ParseAddr(config.Address); err == nil {
|
||||
config.Address = anonymizer.AnonymizeIP(addr).String()
|
||||
}
|
||||
|
||||
if config.SshConfig != nil && len(config.SshConfig.SshPubKey) > 0 {
|
||||
config.SshConfig.SshPubKey = []byte("ssh-placeholder-key")
|
||||
}
|
||||
|
||||
config.Dns = anonymizer.AnonymizeString(config.Dns)
|
||||
config.Fqdn = anonymizer.AnonymizeDomain(config.Fqdn)
|
||||
}
|
||||
|
||||
func anonymizeRemotePeer(peer *mgmProto.RemotePeerConfig, anonymizer *anonymize.Anonymizer) {
|
||||
if peer == nil {
|
||||
return
|
||||
}
|
||||
|
||||
for i, ip := range peer.AllowedIps {
|
||||
// Try to parse as prefix first (CIDR)
|
||||
if prefix, err := netip.ParsePrefix(ip); err == nil {
|
||||
anonIP := anonymizer.AnonymizeIP(prefix.Addr())
|
||||
peer.AllowedIps[i] = fmt.Sprintf("%s/%d", anonIP, prefix.Bits())
|
||||
} else if addr, err := netip.ParseAddr(ip); err == nil {
|
||||
peer.AllowedIps[i] = anonymizer.AnonymizeIP(addr).String()
|
||||
}
|
||||
}
|
||||
|
||||
peer.Fqdn = anonymizer.AnonymizeDomain(peer.Fqdn)
|
||||
|
||||
if peer.SshConfig != nil && len(peer.SshConfig.SshPubKey) > 0 {
|
||||
peer.SshConfig.SshPubKey = []byte("ssh-placeholder-key")
|
||||
}
|
||||
}
|
||||
|
||||
func anonymizeRoute(route *mgmProto.Route, anonymizer *anonymize.Anonymizer) {
|
||||
if route == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if prefix, err := netip.ParsePrefix(route.Network); err == nil {
|
||||
anonIP := anonymizer.AnonymizeIP(prefix.Addr())
|
||||
route.Network = fmt.Sprintf("%s/%d", anonIP, prefix.Bits())
|
||||
}
|
||||
|
||||
for i, domain := range route.Domains {
|
||||
route.Domains[i] = anonymizer.AnonymizeDomain(domain)
|
||||
}
|
||||
|
||||
route.NetID = anonymizer.AnonymizeString(route.NetID)
|
||||
}
|
||||
|
||||
func anonymizeDNSConfig(config *mgmProto.DNSConfig, anonymizer *anonymize.Anonymizer) {
|
||||
if config == nil {
|
||||
return
|
||||
}
|
||||
|
||||
anonymizeNameServerGroups(config.NameServerGroups, anonymizer)
|
||||
anonymizeCustomZones(config.CustomZones, anonymizer)
|
||||
}
|
||||
|
||||
func anonymizeNameServerGroups(groups []*mgmProto.NameServerGroup, anonymizer *anonymize.Anonymizer) {
|
||||
for _, group := range groups {
|
||||
anonymizeServers(group.NameServers, anonymizer)
|
||||
anonymizeDomains(group.Domains, anonymizer)
|
||||
}
|
||||
}
|
||||
|
||||
func anonymizeServers(servers []*mgmProto.NameServer, anonymizer *anonymize.Anonymizer) {
|
||||
for _, server := range servers {
|
||||
if addr, err := netip.ParseAddr(server.IP); err == nil {
|
||||
server.IP = anonymizer.AnonymizeIP(addr).String()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func anonymizeDomains(domains []string, anonymizer *anonymize.Anonymizer) {
|
||||
for i, domain := range domains {
|
||||
domains[i] = anonymizer.AnonymizeDomain(domain)
|
||||
}
|
||||
}
|
||||
|
||||
func anonymizeCustomZones(zones []*mgmProto.CustomZone, anonymizer *anonymize.Anonymizer) {
|
||||
for _, zone := range zones {
|
||||
zone.Domain = anonymizer.AnonymizeDomain(zone.Domain)
|
||||
anonymizeRecords(zone.Records, anonymizer)
|
||||
}
|
||||
}
|
||||
|
||||
func anonymizeRecords(records []*mgmProto.SimpleRecord, anonymizer *anonymize.Anonymizer) {
|
||||
for _, record := range records {
|
||||
record.Name = anonymizer.AnonymizeDomain(record.Name)
|
||||
anonymizeRData(record, anonymizer)
|
||||
}
|
||||
}
|
||||
|
||||
func anonymizeRData(record *mgmProto.SimpleRecord, anonymizer *anonymize.Anonymizer) {
|
||||
switch record.Type {
|
||||
case 1, 28: // A or AAAA record
|
||||
if addr, err := netip.ParseAddr(record.RData); err == nil {
|
||||
record.RData = anonymizer.AnonymizeIP(addr).String()
|
||||
}
|
||||
default:
|
||||
record.RData = anonymizer.AnonymizeString(record.RData)
|
||||
}
|
||||
}
|
||||
|
||||
func anonymizeFirewallRule(rule *mgmProto.FirewallRule, anonymizer *anonymize.Anonymizer) {
|
||||
if rule == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if addr, err := netip.ParseAddr(rule.PeerIP); err == nil {
|
||||
rule.PeerIP = anonymizer.AnonymizeIP(addr).String()
|
||||
}
|
||||
}
|
||||
|
||||
func anonymizeRouteFirewallRule(rule *mgmProto.RouteFirewallRule, anonymizer *anonymize.Anonymizer) {
|
||||
if rule == nil {
|
||||
return
|
||||
}
|
||||
|
||||
for i, sourceRange := range rule.SourceRanges {
|
||||
if prefix, err := netip.ParsePrefix(sourceRange); err == nil {
|
||||
anonIP := anonymizer.AnonymizeIP(prefix.Addr())
|
||||
rule.SourceRanges[i] = fmt.Sprintf("%s/%d", anonIP, prefix.Bits())
|
||||
}
|
||||
}
|
||||
|
||||
if prefix, err := netip.ParsePrefix(rule.Destination); err == nil {
|
||||
anonIP := anonymizer.AnonymizeIP(prefix.Addr())
|
||||
rule.Destination = fmt.Sprintf("%s/%d", anonIP, prefix.Bits())
|
||||
}
|
||||
}
|
||||
|
||||
func anonymizeStateFile(rawStates *map[string]json.RawMessage, anonymizer *anonymize.Anonymizer) error {
|
||||
for name, rawState := range *rawStates {
|
||||
if string(rawState) == "null" {
|
||||
continue
|
||||
}
|
||||
|
||||
var state map[string]any
|
||||
if err := json.Unmarshal(rawState, &state); err != nil {
|
||||
return fmt.Errorf("unmarshal state %s: %w", name, err)
|
||||
}
|
||||
|
||||
state = anonymizeValue(state, anonymizer).(map[string]any)
|
||||
|
||||
bs, err := json.Marshal(state)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal state %s: %w", name, err)
|
||||
}
|
||||
|
||||
(*rawStates)[name] = bs
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func anonymizeValue(value any, anonymizer *anonymize.Anonymizer) any {
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
return anonymizeString(v, anonymizer)
|
||||
case map[string]any:
|
||||
return anonymizeMap(v, anonymizer)
|
||||
case []any:
|
||||
return anonymizeSlice(v, anonymizer)
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func anonymizeString(v string, anonymizer *anonymize.Anonymizer) string {
|
||||
if prefix, err := netip.ParsePrefix(v); err == nil {
|
||||
anonIP := anonymizer.AnonymizeIP(prefix.Addr())
|
||||
return fmt.Sprintf("%s/%d", anonIP, prefix.Bits())
|
||||
}
|
||||
if ip, err := netip.ParseAddr(v); err == nil {
|
||||
return anonymizer.AnonymizeIP(ip).String()
|
||||
}
|
||||
return anonymizer.AnonymizeString(v)
|
||||
}
|
||||
|
||||
func anonymizeMap(v map[string]any, anonymizer *anonymize.Anonymizer) map[string]any {
|
||||
result := make(map[string]any, len(v))
|
||||
for key, val := range v {
|
||||
newKey := anonymizeMapKey(key, anonymizer)
|
||||
result[newKey] = anonymizeValue(val, anonymizer)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func anonymizeMapKey(key string, anonymizer *anonymize.Anonymizer) string {
|
||||
if prefix, err := netip.ParsePrefix(key); err == nil {
|
||||
anonIP := anonymizer.AnonymizeIP(prefix.Addr())
|
||||
return fmt.Sprintf("%s/%d", anonIP, prefix.Bits())
|
||||
}
|
||||
if ip, err := netip.ParseAddr(key); err == nil {
|
||||
return anonymizer.AnonymizeIP(ip).String()
|
||||
}
|
||||
return key
|
||||
}
|
||||
|
||||
func anonymizeSlice(v []any, anonymizer *anonymize.Anonymizer) []any {
|
||||
for i, val := range v {
|
||||
v[i] = anonymizeValue(val, anonymizer)
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
@@ -1,430 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/anonymize"
|
||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||
)
|
||||
|
||||
func TestAnonymizeStateFile(t *testing.T) {
|
||||
testState := map[string]json.RawMessage{
|
||||
"null_state": json.RawMessage("null"),
|
||||
"test_state": mustMarshal(map[string]any{
|
||||
// Test simple fields
|
||||
"public_ip": "203.0.113.1",
|
||||
"private_ip": "192.168.1.1",
|
||||
"protected_ip": "100.64.0.1",
|
||||
"well_known_ip": "8.8.8.8",
|
||||
"ipv6_addr": "2001:db8::1",
|
||||
"private_ipv6": "fd00::1",
|
||||
"domain": "test.example.com",
|
||||
"uri": "stun:stun.example.com:3478",
|
||||
"uri_with_ip": "turn:203.0.113.1:3478",
|
||||
"netbird_domain": "device.netbird.cloud",
|
||||
|
||||
// Test CIDR ranges
|
||||
"public_cidr": "203.0.113.0/24",
|
||||
"private_cidr": "192.168.0.0/16",
|
||||
"protected_cidr": "100.64.0.0/10",
|
||||
"ipv6_cidr": "2001:db8::/32",
|
||||
"private_ipv6_cidr": "fd00::/8",
|
||||
|
||||
// Test nested structures
|
||||
"nested": map[string]any{
|
||||
"ip": "203.0.113.2",
|
||||
"domain": "nested.example.com",
|
||||
"more_nest": map[string]any{
|
||||
"ip": "203.0.113.3",
|
||||
"domain": "deep.example.com",
|
||||
},
|
||||
},
|
||||
|
||||
// Test arrays
|
||||
"string_array": []any{
|
||||
"203.0.113.4",
|
||||
"test1.example.com",
|
||||
"test2.example.com",
|
||||
},
|
||||
"object_array": []any{
|
||||
map[string]any{
|
||||
"ip": "203.0.113.5",
|
||||
"domain": "array1.example.com",
|
||||
},
|
||||
map[string]any{
|
||||
"ip": "203.0.113.6",
|
||||
"domain": "array2.example.com",
|
||||
},
|
||||
},
|
||||
|
||||
// Test multiple occurrences of same value
|
||||
"duplicate_ip": "203.0.113.1", // Same as public_ip
|
||||
"duplicate_domain": "test.example.com", // Same as domain
|
||||
|
||||
// Test URIs with various schemes
|
||||
"stun_uri": "stun:stun.example.com:3478",
|
||||
"turns_uri": "turns:turns.example.com:5349",
|
||||
"http_uri": "http://web.example.com:80",
|
||||
"https_uri": "https://secure.example.com:443",
|
||||
|
||||
// Test strings that might look like IPs but aren't
|
||||
"not_ip": "300.300.300.300",
|
||||
"partial_ip": "192.168",
|
||||
"ip_like_string": "1234.5678",
|
||||
|
||||
// Test mixed content strings
|
||||
"mixed_content": "Server at 203.0.113.1 (test.example.com) on port 80",
|
||||
|
||||
// Test empty and special values
|
||||
"empty_string": "",
|
||||
"null_value": nil,
|
||||
"numeric_value": 42,
|
||||
"boolean_value": true,
|
||||
}),
|
||||
"route_state": mustMarshal(map[string]any{
|
||||
"routes": []any{
|
||||
map[string]any{
|
||||
"network": "203.0.113.0/24",
|
||||
"gateway": "203.0.113.1",
|
||||
"domains": []any{
|
||||
"route1.example.com",
|
||||
"route2.example.com",
|
||||
},
|
||||
},
|
||||
map[string]any{
|
||||
"network": "2001:db8::/32",
|
||||
"gateway": "2001:db8::1",
|
||||
"domains": []any{
|
||||
"route3.example.com",
|
||||
"route4.example.com",
|
||||
},
|
||||
},
|
||||
},
|
||||
// Test map with IP/CIDR keys
|
||||
"refCountMap": map[string]any{
|
||||
"203.0.113.1/32": map[string]any{
|
||||
"Count": 1,
|
||||
"Out": map[string]any{
|
||||
"IP": "192.168.0.1",
|
||||
"Intf": map[string]any{
|
||||
"Name": "eth0",
|
||||
"Index": 1,
|
||||
},
|
||||
},
|
||||
},
|
||||
"2001:db8::1/128": map[string]any{
|
||||
"Count": 1,
|
||||
"Out": map[string]any{
|
||||
"IP": "fe80::1",
|
||||
"Intf": map[string]any{
|
||||
"Name": "eth0",
|
||||
"Index": 1,
|
||||
},
|
||||
},
|
||||
},
|
||||
"10.0.0.1/32": map[string]any{ // private IP should remain unchanged
|
||||
"Count": 1,
|
||||
"Out": map[string]any{
|
||||
"IP": "192.168.0.1",
|
||||
},
|
||||
},
|
||||
},
|
||||
}),
|
||||
}
|
||||
|
||||
anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
|
||||
|
||||
// Pre-seed the domains we need to verify in the test assertions
|
||||
anonymizer.AnonymizeDomain("test.example.com")
|
||||
anonymizer.AnonymizeDomain("nested.example.com")
|
||||
anonymizer.AnonymizeDomain("deep.example.com")
|
||||
anonymizer.AnonymizeDomain("array1.example.com")
|
||||
|
||||
err := anonymizeStateFile(&testState, anonymizer)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Helper function to unmarshal and get nested values
|
||||
var state map[string]any
|
||||
err = json.Unmarshal(testState["test_state"], &state)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test null state remains unchanged
|
||||
require.Equal(t, "null", string(testState["null_state"]))
|
||||
|
||||
// Basic assertions
|
||||
assert.NotEqual(t, "203.0.113.1", state["public_ip"])
|
||||
assert.Equal(t, "192.168.1.1", state["private_ip"]) // Private IP unchanged
|
||||
assert.Equal(t, "100.64.0.1", state["protected_ip"]) // Protected IP unchanged
|
||||
assert.Equal(t, "8.8.8.8", state["well_known_ip"]) // Well-known IP unchanged
|
||||
assert.NotEqual(t, "2001:db8::1", state["ipv6_addr"])
|
||||
assert.Equal(t, "fd00::1", state["private_ipv6"]) // Private IPv6 unchanged
|
||||
assert.NotEqual(t, "test.example.com", state["domain"])
|
||||
assert.True(t, strings.HasSuffix(state["domain"].(string), ".domain"))
|
||||
assert.Equal(t, "device.netbird.cloud", state["netbird_domain"]) // Netbird domain unchanged
|
||||
|
||||
// CIDR ranges
|
||||
assert.NotEqual(t, "203.0.113.0/24", state["public_cidr"])
|
||||
assert.Contains(t, state["public_cidr"], "/24") // Prefix preserved
|
||||
assert.Equal(t, "192.168.0.0/16", state["private_cidr"]) // Private CIDR unchanged
|
||||
assert.Equal(t, "100.64.0.0/10", state["protected_cidr"]) // Protected CIDR unchanged
|
||||
assert.NotEqual(t, "2001:db8::/32", state["ipv6_cidr"])
|
||||
assert.Contains(t, state["ipv6_cidr"], "/32") // IPv6 prefix preserved
|
||||
|
||||
// Nested structures
|
||||
nested := state["nested"].(map[string]any)
|
||||
assert.NotEqual(t, "203.0.113.2", nested["ip"])
|
||||
assert.NotEqual(t, "nested.example.com", nested["domain"])
|
||||
moreNest := nested["more_nest"].(map[string]any)
|
||||
assert.NotEqual(t, "203.0.113.3", moreNest["ip"])
|
||||
assert.NotEqual(t, "deep.example.com", moreNest["domain"])
|
||||
|
||||
// Arrays
|
||||
strArray := state["string_array"].([]any)
|
||||
assert.NotEqual(t, "203.0.113.4", strArray[0])
|
||||
assert.NotEqual(t, "test1.example.com", strArray[1])
|
||||
assert.True(t, strings.HasSuffix(strArray[1].(string), ".domain"))
|
||||
|
||||
objArray := state["object_array"].([]any)
|
||||
firstObj := objArray[0].(map[string]any)
|
||||
assert.NotEqual(t, "203.0.113.5", firstObj["ip"])
|
||||
assert.NotEqual(t, "array1.example.com", firstObj["domain"])
|
||||
|
||||
// Duplicate values should be anonymized consistently
|
||||
assert.Equal(t, state["public_ip"], state["duplicate_ip"])
|
||||
assert.Equal(t, state["domain"], state["duplicate_domain"])
|
||||
|
||||
// URIs
|
||||
assert.NotContains(t, state["stun_uri"], "stun.example.com")
|
||||
assert.NotContains(t, state["turns_uri"], "turns.example.com")
|
||||
assert.NotContains(t, state["http_uri"], "web.example.com")
|
||||
assert.NotContains(t, state["https_uri"], "secure.example.com")
|
||||
|
||||
// Non-IP strings should remain unchanged
|
||||
assert.Equal(t, "300.300.300.300", state["not_ip"])
|
||||
assert.Equal(t, "192.168", state["partial_ip"])
|
||||
assert.Equal(t, "1234.5678", state["ip_like_string"])
|
||||
|
||||
// Mixed content should have IPs and domains replaced
|
||||
mixedContent := state["mixed_content"].(string)
|
||||
assert.NotContains(t, mixedContent, "203.0.113.1")
|
||||
assert.NotContains(t, mixedContent, "test.example.com")
|
||||
assert.Contains(t, mixedContent, "Server at ")
|
||||
assert.Contains(t, mixedContent, " on port 80")
|
||||
|
||||
// Special values should remain unchanged
|
||||
assert.Equal(t, "", state["empty_string"])
|
||||
assert.Nil(t, state["null_value"])
|
||||
assert.Equal(t, float64(42), state["numeric_value"])
|
||||
assert.Equal(t, true, state["boolean_value"])
|
||||
|
||||
// Check route state
|
||||
var routeState map[string]any
|
||||
err = json.Unmarshal(testState["route_state"], &routeState)
|
||||
require.NoError(t, err)
|
||||
|
||||
routes := routeState["routes"].([]any)
|
||||
route1 := routes[0].(map[string]any)
|
||||
assert.NotEqual(t, "203.0.113.0/24", route1["network"])
|
||||
assert.Contains(t, route1["network"], "/24")
|
||||
assert.NotEqual(t, "203.0.113.1", route1["gateway"])
|
||||
domains := route1["domains"].([]any)
|
||||
assert.True(t, strings.HasSuffix(domains[0].(string), ".domain"))
|
||||
assert.True(t, strings.HasSuffix(domains[1].(string), ".domain"))
|
||||
|
||||
// Check map keys are anonymized
|
||||
refCountMap := routeState["refCountMap"].(map[string]any)
|
||||
hasPublicIPKey := false
|
||||
hasIPv6Key := false
|
||||
hasPrivateIPKey := false
|
||||
for key := range refCountMap {
|
||||
if strings.Contains(key, "203.0.113.1") {
|
||||
hasPublicIPKey = true
|
||||
}
|
||||
if strings.Contains(key, "2001:db8::1") {
|
||||
hasIPv6Key = true
|
||||
}
|
||||
if key == "10.0.0.1/32" {
|
||||
hasPrivateIPKey = true
|
||||
}
|
||||
}
|
||||
assert.False(t, hasPublicIPKey, "public IP in key should be anonymized")
|
||||
assert.False(t, hasIPv6Key, "IPv6 in key should be anonymized")
|
||||
assert.True(t, hasPrivateIPKey, "private IP in key should remain unchanged")
|
||||
}
|
||||
|
||||
func mustMarshal(v any) json.RawMessage {
|
||||
data, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
func TestAnonymizeNetworkMap(t *testing.T) {
|
||||
networkMap := &mgmProto.NetworkMap{
|
||||
PeerConfig: &mgmProto.PeerConfig{
|
||||
Address: "203.0.113.5",
|
||||
Dns: "1.2.3.4",
|
||||
Fqdn: "peer1.corp.example.com",
|
||||
SshConfig: &mgmProto.SSHConfig{
|
||||
SshPubKey: []byte("ssh-rsa AAAAB3NzaC1..."),
|
||||
},
|
||||
},
|
||||
RemotePeers: []*mgmProto.RemotePeerConfig{
|
||||
{
|
||||
AllowedIps: []string{
|
||||
"203.0.113.1/32",
|
||||
"2001:db8:1234::1/128",
|
||||
"192.168.1.1/32",
|
||||
"100.64.0.1/32",
|
||||
"10.0.0.1/32",
|
||||
},
|
||||
Fqdn: "peer2.corp.example.com",
|
||||
SshConfig: &mgmProto.SSHConfig{
|
||||
SshPubKey: []byte("ssh-rsa AAAAB3NzaC2..."),
|
||||
},
|
||||
},
|
||||
},
|
||||
Routes: []*mgmProto.Route{
|
||||
{
|
||||
Network: "197.51.100.0/24",
|
||||
Domains: []string{"prod.example.com", "staging.example.com"},
|
||||
NetID: "net-123abc",
|
||||
},
|
||||
},
|
||||
DNSConfig: &mgmProto.DNSConfig{
|
||||
NameServerGroups: []*mgmProto.NameServerGroup{
|
||||
{
|
||||
NameServers: []*mgmProto.NameServer{
|
||||
{IP: "8.8.8.8"},
|
||||
{IP: "1.1.1.1"},
|
||||
{IP: "203.0.113.53"},
|
||||
},
|
||||
Domains: []string{"example.com", "internal.example.com"},
|
||||
},
|
||||
},
|
||||
CustomZones: []*mgmProto.CustomZone{
|
||||
{
|
||||
Domain: "custom.example.com",
|
||||
Records: []*mgmProto.SimpleRecord{
|
||||
{
|
||||
Name: "www.custom.example.com",
|
||||
Type: 1,
|
||||
RData: "203.0.113.10",
|
||||
},
|
||||
{
|
||||
Name: "internal.custom.example.com",
|
||||
Type: 1,
|
||||
RData: "192.168.1.10",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Create anonymizer with test addresses
|
||||
anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
|
||||
|
||||
// Anonymize the network map
|
||||
err := anonymizeNetworkMap(networkMap, anonymizer)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test PeerConfig anonymization
|
||||
peerCfg := networkMap.PeerConfig
|
||||
require.NotEqual(t, "203.0.113.5", peerCfg.Address)
|
||||
|
||||
// Verify DNS and FQDN are properly anonymized
|
||||
require.NotEqual(t, "1.2.3.4", peerCfg.Dns)
|
||||
require.NotEqual(t, "peer1.corp.example.com", peerCfg.Fqdn)
|
||||
require.True(t, strings.HasSuffix(peerCfg.Fqdn, ".domain"))
|
||||
|
||||
// Verify SSH key is replaced
|
||||
require.Equal(t, []byte("ssh-placeholder-key"), peerCfg.SshConfig.SshPubKey)
|
||||
|
||||
// Test RemotePeers anonymization
|
||||
remotePeer := networkMap.RemotePeers[0]
|
||||
|
||||
// Verify FQDN is anonymized
|
||||
require.NotEqual(t, "peer2.corp.example.com", remotePeer.Fqdn)
|
||||
require.True(t, strings.HasSuffix(remotePeer.Fqdn, ".domain"))
|
||||
|
||||
// Check that public IPs are anonymized but private IPs are preserved
|
||||
for _, allowedIP := range remotePeer.AllowedIps {
|
||||
ip, _, err := net.ParseCIDR(allowedIP)
|
||||
require.NoError(t, err)
|
||||
|
||||
if ip.IsPrivate() || isInCGNATRange(ip) {
|
||||
require.Contains(t, []string{
|
||||
"192.168.1.1/32",
|
||||
"100.64.0.1/32",
|
||||
"10.0.0.1/32",
|
||||
}, allowedIP)
|
||||
} else {
|
||||
require.NotContains(t, []string{
|
||||
"203.0.113.1/32",
|
||||
"2001:db8:1234::1/128",
|
||||
}, allowedIP)
|
||||
}
|
||||
}
|
||||
|
||||
// Test Routes anonymization
|
||||
route := networkMap.Routes[0]
|
||||
require.NotEqual(t, "197.51.100.0/24", route.Network)
|
||||
for _, domain := range route.Domains {
|
||||
require.True(t, strings.HasSuffix(domain, ".domain"))
|
||||
require.NotContains(t, domain, "example.com")
|
||||
}
|
||||
|
||||
// Test DNS config anonymization
|
||||
dnsConfig := networkMap.DNSConfig
|
||||
nameServerGroup := dnsConfig.NameServerGroups[0]
|
||||
|
||||
// Verify well-known DNS servers are preserved
|
||||
require.Equal(t, "8.8.8.8", nameServerGroup.NameServers[0].IP)
|
||||
require.Equal(t, "1.1.1.1", nameServerGroup.NameServers[1].IP)
|
||||
|
||||
// Verify public DNS server is anonymized
|
||||
require.NotEqual(t, "203.0.113.53", nameServerGroup.NameServers[2].IP)
|
||||
|
||||
// Verify domains are anonymized
|
||||
for _, domain := range nameServerGroup.Domains {
|
||||
require.True(t, strings.HasSuffix(domain, ".domain"))
|
||||
require.NotContains(t, domain, "example.com")
|
||||
}
|
||||
|
||||
// Test CustomZones anonymization
|
||||
customZone := dnsConfig.CustomZones[0]
|
||||
require.True(t, strings.HasSuffix(customZone.Domain, ".domain"))
|
||||
require.NotContains(t, customZone.Domain, "example.com")
|
||||
|
||||
// Verify records are properly anonymized
|
||||
for _, record := range customZone.Records {
|
||||
require.True(t, strings.HasSuffix(record.Name, ".domain"))
|
||||
require.NotContains(t, record.Name, "example.com")
|
||||
|
||||
ip := net.ParseIP(record.RData)
|
||||
if ip != nil {
|
||||
if !ip.IsPrivate() {
|
||||
require.NotEqual(t, "203.0.113.10", record.RData)
|
||||
} else {
|
||||
require.Equal(t, "192.168.1.10", record.RData)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to check if IP is in CGNAT range
|
||||
func isInCGNATRange(ip net.IP) bool {
|
||||
cgnat := net.IPNet{
|
||||
IP: net.ParseIP("100.64.0.0"),
|
||||
Mask: net.CIDRMask(10, 32),
|
||||
}
|
||||
return cgnat.Contains(ip)
|
||||
}
|
||||
@@ -1,7 +0,0 @@
|
||||
//go:build !windows
|
||||
|
||||
package server
|
||||
|
||||
func handlePanicLog() error {
|
||||
return nil
|
||||
}
|
||||
@@ -1,83 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"syscall"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
const (
|
||||
windowsPanicLogEnvVar = "NB_WINDOWS_PANIC_LOG"
|
||||
// STD_ERROR_HANDLE ((DWORD)-12) = 4294967284
|
||||
stdErrorHandle = ^uintptr(11)
|
||||
)
|
||||
|
||||
var (
|
||||
kernel32 = syscall.NewLazyDLL("kernel32.dll")
|
||||
|
||||
// https://learn.microsoft.com/en-us/windows/console/setstdhandle
|
||||
setStdHandleFn = kernel32.NewProc("SetStdHandle")
|
||||
)
|
||||
|
||||
func handlePanicLog() error {
|
||||
logPath := os.Getenv(windowsPanicLogEnvVar)
|
||||
if logPath == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Ensure the directory exists
|
||||
logDir := filepath.Dir(logPath)
|
||||
if err := os.MkdirAll(logDir, 0750); err != nil {
|
||||
return fmt.Errorf("create panic log directory: %w", err)
|
||||
}
|
||||
if err := util.EnforcePermission(logPath); err != nil {
|
||||
return fmt.Errorf("enforce permission on panic log file: %w", err)
|
||||
}
|
||||
|
||||
// Open log file with append mode
|
||||
f, err := os.OpenFile(logPath, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0644)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open panic log file: %w", err)
|
||||
}
|
||||
|
||||
// Redirect stderr to the file
|
||||
if err = redirectStderr(f); err != nil {
|
||||
if closeErr := f.Close(); closeErr != nil {
|
||||
log.Warnf("failed to close file after redirect error: %v", closeErr)
|
||||
}
|
||||
return fmt.Errorf("redirect stderr: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("successfully configured panic logging to: %s", logPath)
|
||||
return nil
|
||||
}
|
||||
|
||||
// redirectStderr redirects stderr to the provided file
|
||||
func redirectStderr(f *os.File) error {
|
||||
// Get the current process's stderr handle
|
||||
if err := setStdHandle(f); err != nil {
|
||||
return fmt.Errorf("failed to set stderr handle: %w", err)
|
||||
}
|
||||
|
||||
// Also set os.Stderr for Go's standard library
|
||||
os.Stderr = f
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func setStdHandle(f *os.File) error {
|
||||
handle := f.Fd()
|
||||
r0, _, e1 := setStdHandleFn.Call(stdErrorHandle, handle)
|
||||
if r0 == 0 {
|
||||
if e1 != nil {
|
||||
return e1
|
||||
}
|
||||
return syscall.EINVAL
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -68,8 +68,6 @@ type Server struct {
|
||||
relayProbe *internal.Probe
|
||||
wgProbe *internal.Probe
|
||||
lastProbe time.Time
|
||||
|
||||
persistNetworkMap bool
|
||||
}
|
||||
|
||||
type oauthAuthFlow struct {
|
||||
@@ -99,10 +97,6 @@ func (s *Server) Start() error {
|
||||
defer s.mutex.Unlock()
|
||||
state := internal.CtxGetState(s.rootCtx)
|
||||
|
||||
if err := handlePanicLog(); err != nil {
|
||||
log.Warnf("failed to redirect stderr: %v", err)
|
||||
}
|
||||
|
||||
if err := restoreResidualState(s.rootCtx); err != nil {
|
||||
log.Warnf(errRestoreResidualState, err)
|
||||
}
|
||||
@@ -198,7 +192,6 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, config *internal.Conf
|
||||
runOperation := func() error {
|
||||
log.Tracef("running client connection")
|
||||
s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder)
|
||||
s.connectClient.SetNetworkMapPersistence(s.persistNetworkMap)
|
||||
|
||||
probes := internal.ProbeHolder{
|
||||
MgmProbe: s.mgmProbe,
|
||||
@@ -629,8 +622,6 @@ func (s *Server) Down(ctx context.Context, _ *proto.DownRequest) (*proto.DownRes
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
s.oauthAuthFlow = oauthAuthFlow{}
|
||||
|
||||
if s.actCancel == nil {
|
||||
return nil, fmt.Errorf("service is not up")
|
||||
}
|
||||
|
||||
@@ -5,112 +5,12 @@ import (
|
||||
"fmt"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
// ListStates returns a list of all saved states
|
||||
func (s *Server) ListStates(_ context.Context, _ *proto.ListStatesRequest) (*proto.ListStatesResponse, error) {
|
||||
mgr := statemanager.New(statemanager.GetDefaultStatePath())
|
||||
|
||||
stateNames, err := mgr.GetSavedStateNames()
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get saved state names: %v", err)
|
||||
}
|
||||
|
||||
states := make([]*proto.State, 0, len(stateNames))
|
||||
for _, name := range stateNames {
|
||||
states = append(states, &proto.State{
|
||||
Name: name,
|
||||
})
|
||||
}
|
||||
|
||||
return &proto.ListStatesResponse{
|
||||
States: states,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CleanState handles cleaning of states (performing cleanup operations)
|
||||
func (s *Server) CleanState(ctx context.Context, req *proto.CleanStateRequest) (*proto.CleanStateResponse, error) {
|
||||
if s.connectClient.Status() == internal.StatusConnected || s.connectClient.Status() == internal.StatusConnecting {
|
||||
return nil, status.Errorf(codes.FailedPrecondition, "cannot clean state while connecting or connected, run 'netbird down' first.")
|
||||
}
|
||||
|
||||
if req.All {
|
||||
// Reuse existing cleanup logic for all states
|
||||
if err := restoreResidualState(ctx); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to clean all states: %v", err)
|
||||
}
|
||||
|
||||
// Get count of cleaned states
|
||||
mgr := statemanager.New(statemanager.GetDefaultStatePath())
|
||||
stateNames, err := mgr.GetSavedStateNames()
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get state count: %v", err)
|
||||
}
|
||||
|
||||
return &proto.CleanStateResponse{
|
||||
CleanedStates: int32(len(stateNames)),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Handle single state cleanup
|
||||
mgr := statemanager.New(statemanager.GetDefaultStatePath())
|
||||
registerStates(mgr)
|
||||
|
||||
if err := mgr.CleanupStateByName(req.StateName); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to clean state %s: %v", req.StateName, err)
|
||||
}
|
||||
|
||||
if err := mgr.PersistState(ctx); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to persist state changes: %v", err)
|
||||
}
|
||||
|
||||
return &proto.CleanStateResponse{
|
||||
CleanedStates: 1,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// DeleteState handles deletion of states without cleanup
|
||||
func (s *Server) DeleteState(ctx context.Context, req *proto.DeleteStateRequest) (*proto.DeleteStateResponse, error) {
|
||||
if s.connectClient.Status() == internal.StatusConnected || s.connectClient.Status() == internal.StatusConnecting {
|
||||
return nil, status.Errorf(codes.FailedPrecondition, "cannot clean state while connecting or connected, run 'netbird down' first.")
|
||||
}
|
||||
|
||||
mgr := statemanager.New(statemanager.GetDefaultStatePath())
|
||||
|
||||
var count int
|
||||
var err error
|
||||
|
||||
if req.All {
|
||||
count, err = mgr.DeleteAllStates()
|
||||
} else {
|
||||
err = mgr.DeleteStateByName(req.StateName)
|
||||
if err == nil {
|
||||
count = 1
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to delete state: %v", err)
|
||||
}
|
||||
|
||||
// Persist the changes
|
||||
if err := mgr.PersistState(ctx); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to persist state changes: %v", err)
|
||||
}
|
||||
|
||||
return &proto.DeleteStateResponse{
|
||||
DeletedStates: int32(count),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// restoreResidualState checks if the client was not shut down in a clean way and restores residual if required.
|
||||
// restoreResidualConfig checks if the client was not shut down in a clean way and restores residual state if required.
|
||||
// Otherwise, we might not be able to connect to the management server to retrieve new config.
|
||||
func restoreResidualState(ctx context.Context) error {
|
||||
path := statemanager.GetDefaultStatePath()
|
||||
@@ -124,7 +24,6 @@ func restoreResidualState(ctx context.Context) error {
|
||||
registerStates(mgr)
|
||||
|
||||
var merr *multierror.Error
|
||||
|
||||
if err := mgr.PerformCleanup(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("perform cleanup: %w", err))
|
||||
}
|
||||
|
||||
@@ -61,14 +61,6 @@ type Info struct {
|
||||
Files []File // for posture checks
|
||||
}
|
||||
|
||||
// StaticInfo is an object that contains machine information that does not change
|
||||
type StaticInfo struct {
|
||||
SystemSerialNumber string
|
||||
SystemProductName string
|
||||
SystemManufacturer string
|
||||
Environment Environment
|
||||
}
|
||||
|
||||
// extractUserAgent extracts Netbird's agent (client) name and version from the outgoing context
|
||||
func extractUserAgent(ctx context.Context) string {
|
||||
md, hasMeta := metadata.FromOutgoingContext(ctx)
|
||||
|
||||
@@ -10,12 +10,13 @@ import (
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/system/detect_cloud"
|
||||
"github.com/netbirdio/netbird/client/system/detect_platform"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
@@ -40,10 +41,11 @@ func GetInfo(ctx context.Context) *Info {
|
||||
log.Warnf("failed to discover network addresses: %s", err)
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
si := updateStaticInfo()
|
||||
if time.Since(start) > 1*time.Second {
|
||||
log.Warnf("updateStaticInfo took %s", time.Since(start))
|
||||
serialNum, prodName, manufacturer := sysInfo()
|
||||
|
||||
env := Environment{
|
||||
Cloud: detect_cloud.Detect(ctx),
|
||||
Platform: detect_platform.Detect(ctx),
|
||||
}
|
||||
|
||||
gio := &Info{
|
||||
@@ -55,10 +57,10 @@ func GetInfo(ctx context.Context) *Info {
|
||||
CPUs: runtime.NumCPU(),
|
||||
KernelVersion: release,
|
||||
NetworkAddresses: addrs,
|
||||
SystemSerialNumber: si.SystemSerialNumber,
|
||||
SystemProductName: si.SystemProductName,
|
||||
SystemManufacturer: si.SystemManufacturer,
|
||||
Environment: si.Environment,
|
||||
SystemSerialNumber: serialNum,
|
||||
SystemProductName: prodName,
|
||||
SystemManufacturer: manufacturer,
|
||||
Environment: env,
|
||||
}
|
||||
|
||||
systemHostname, _ := os.Hostname()
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
//go:build !android
|
||||
// +build !android
|
||||
|
||||
package system
|
||||
|
||||
@@ -15,13 +16,30 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/zcalusic/sysinfo"
|
||||
|
||||
"github.com/netbirdio/netbird/client/system/detect_cloud"
|
||||
"github.com/netbirdio/netbird/client/system/detect_platform"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
var (
|
||||
// it is override in tests
|
||||
getSystemInfo = defaultSysInfoImplementation
|
||||
)
|
||||
type SysInfoGetter interface {
|
||||
GetSysInfo() SysInfo
|
||||
}
|
||||
|
||||
type SysInfoWrapper struct {
|
||||
si sysinfo.SysInfo
|
||||
}
|
||||
|
||||
func (s SysInfoWrapper) GetSysInfo() SysInfo {
|
||||
s.si.GetSysInfo()
|
||||
return SysInfo{
|
||||
ChassisSerial: s.si.Chassis.Serial,
|
||||
ProductSerial: s.si.Product.Serial,
|
||||
BoardSerial: s.si.Board.Serial,
|
||||
ProductName: s.si.Product.Name,
|
||||
BoardName: s.si.Board.Name,
|
||||
ProductVendor: s.si.Product.Vendor,
|
||||
}
|
||||
}
|
||||
|
||||
// GetInfo retrieves and parses the system information
|
||||
func GetInfo(ctx context.Context) *Info {
|
||||
@@ -47,10 +65,12 @@ func GetInfo(ctx context.Context) *Info {
|
||||
log.Warnf("failed to discover network addresses: %s", err)
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
si := updateStaticInfo()
|
||||
if time.Since(start) > 1*time.Second {
|
||||
log.Warnf("updateStaticInfo took %s", time.Since(start))
|
||||
si := SysInfoWrapper{}
|
||||
serialNum, prodName, manufacturer := sysInfo(si.GetSysInfo())
|
||||
|
||||
env := Environment{
|
||||
Cloud: detect_cloud.Detect(ctx),
|
||||
Platform: detect_platform.Detect(ctx),
|
||||
}
|
||||
|
||||
gio := &Info{
|
||||
@@ -65,10 +85,10 @@ func GetInfo(ctx context.Context) *Info {
|
||||
UIVersion: extractUserAgent(ctx),
|
||||
KernelVersion: osInfo[1],
|
||||
NetworkAddresses: addrs,
|
||||
SystemSerialNumber: si.SystemSerialNumber,
|
||||
SystemProductName: si.SystemProductName,
|
||||
SystemManufacturer: si.SystemManufacturer,
|
||||
Environment: si.Environment,
|
||||
SystemSerialNumber: serialNum,
|
||||
SystemProductName: prodName,
|
||||
SystemManufacturer: manufacturer,
|
||||
Environment: env,
|
||||
}
|
||||
|
||||
return gio
|
||||
@@ -88,9 +108,9 @@ func _getInfo() string {
|
||||
return out.String()
|
||||
}
|
||||
|
||||
func sysInfo() (string, string, string) {
|
||||
func sysInfo(si SysInfo) (string, string, string) {
|
||||
isascii := regexp.MustCompile("^[[:ascii:]]+$")
|
||||
si := getSystemInfo()
|
||||
|
||||
serials := []string{si.ChassisSerial, si.ProductSerial}
|
||||
serial := ""
|
||||
|
||||
@@ -121,16 +141,3 @@ func sysInfo() (string, string, string) {
|
||||
}
|
||||
return serial, name, manufacturer
|
||||
}
|
||||
|
||||
func defaultSysInfoImplementation() SysInfo {
|
||||
si := sysinfo.SysInfo{}
|
||||
si.GetSysInfo()
|
||||
return SysInfo{
|
||||
ChassisSerial: si.Chassis.Serial,
|
||||
ProductSerial: si.Product.Serial,
|
||||
BoardSerial: si.Board.Serial,
|
||||
ProductName: si.Product.Name,
|
||||
BoardName: si.Board.Name,
|
||||
ProductVendor: si.Product.Vendor,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,12 +6,13 @@ import (
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/yusufpapurcu/wmi"
|
||||
"golang.org/x/sys/windows/registry"
|
||||
|
||||
"github.com/netbirdio/netbird/client/system/detect_cloud"
|
||||
"github.com/netbirdio/netbird/client/system/detect_platform"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
@@ -41,10 +42,24 @@ func GetInfo(ctx context.Context) *Info {
|
||||
log.Warnf("failed to discover network addresses: %s", err)
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
si := updateStaticInfo()
|
||||
if time.Since(start) > 1*time.Second {
|
||||
log.Warnf("updateStaticInfo took %s", time.Since(start))
|
||||
serialNum, err := sysNumber()
|
||||
if err != nil {
|
||||
log.Warnf("failed to get system serial number: %s", err)
|
||||
}
|
||||
|
||||
prodName, err := sysProductName()
|
||||
if err != nil {
|
||||
log.Warnf("failed to get system product name: %s", err)
|
||||
}
|
||||
|
||||
manufacturer, err := sysManufacturer()
|
||||
if err != nil {
|
||||
log.Warnf("failed to get system manufacturer: %s", err)
|
||||
}
|
||||
|
||||
env := Environment{
|
||||
Cloud: detect_cloud.Detect(ctx),
|
||||
Platform: detect_platform.Detect(ctx),
|
||||
}
|
||||
|
||||
gio := &Info{
|
||||
@@ -56,10 +71,10 @@ func GetInfo(ctx context.Context) *Info {
|
||||
CPUs: runtime.NumCPU(),
|
||||
KernelVersion: buildVersion,
|
||||
NetworkAddresses: addrs,
|
||||
SystemSerialNumber: si.SystemSerialNumber,
|
||||
SystemProductName: si.SystemProductName,
|
||||
SystemManufacturer: si.SystemManufacturer,
|
||||
Environment: si.Environment,
|
||||
SystemSerialNumber: serialNum,
|
||||
SystemProductName: prodName,
|
||||
SystemManufacturer: manufacturer,
|
||||
Environment: env,
|
||||
}
|
||||
|
||||
systemHostname, _ := os.Hostname()
|
||||
@@ -70,26 +85,6 @@ func GetInfo(ctx context.Context) *Info {
|
||||
return gio
|
||||
}
|
||||
|
||||
func sysInfo() (serialNumber string, productName string, manufacturer string) {
|
||||
var err error
|
||||
serialNumber, err = sysNumber()
|
||||
if err != nil {
|
||||
log.Warnf("failed to get system serial number: %s", err)
|
||||
}
|
||||
|
||||
productName, err = sysProductName()
|
||||
if err != nil {
|
||||
log.Warnf("failed to get system product name: %s", err)
|
||||
}
|
||||
|
||||
manufacturer, err = sysManufacturer()
|
||||
if err != nil {
|
||||
log.Warnf("failed to get system manufacturer: %s", err)
|
||||
}
|
||||
|
||||
return serialNumber, productName, manufacturer
|
||||
}
|
||||
|
||||
func getOSNameAndVersion() (string, string) {
|
||||
var dst []Win32_OperatingSystem
|
||||
query := wmi.CreateQuery(&dst, "")
|
||||
|
||||
@@ -1,46 +0,0 @@
|
||||
//go:build (linux && !android) || windows || (darwin && !ios)
|
||||
|
||||
package system
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/client/system/detect_cloud"
|
||||
"github.com/netbirdio/netbird/client/system/detect_platform"
|
||||
)
|
||||
|
||||
var (
|
||||
staticInfo StaticInfo
|
||||
once sync.Once
|
||||
)
|
||||
|
||||
func init() {
|
||||
go func() {
|
||||
_ = updateStaticInfo()
|
||||
}()
|
||||
}
|
||||
|
||||
func updateStaticInfo() StaticInfo {
|
||||
once.Do(func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(3)
|
||||
go func() {
|
||||
staticInfo.SystemSerialNumber, staticInfo.SystemProductName, staticInfo.SystemManufacturer = sysInfo()
|
||||
wg.Done()
|
||||
}()
|
||||
go func() {
|
||||
staticInfo.Environment.Cloud = detect_cloud.Detect(ctx)
|
||||
wg.Done()
|
||||
}()
|
||||
go func() {
|
||||
staticInfo.Environment.Platform = detect_platform.Detect(ctx)
|
||||
wg.Done()
|
||||
}()
|
||||
wg.Wait()
|
||||
})
|
||||
return staticInfo
|
||||
}
|
||||
@@ -183,10 +183,7 @@ func Test_sysInfo(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
getSystemInfo = func() SysInfo {
|
||||
return tt.sysInfo
|
||||
}
|
||||
gotSerialNum, gotProdName, gotManufacturer := sysInfo()
|
||||
gotSerialNum, gotProdName, gotManufacturer := sysInfo(tt.sysInfo)
|
||||
if gotSerialNum != tt.wantSerialNum {
|
||||
t.Errorf("sysInfo() gotSerialNum = %v, want %v", gotSerialNum, tt.wantSerialNum)
|
||||
}
|
||||
|
||||
126
funding.json
126
funding.json
@@ -1,126 +0,0 @@
|
||||
{
|
||||
"version": "v1.0.0",
|
||||
"entity": {
|
||||
"type": "organisation",
|
||||
"role": "owner",
|
||||
"name": "NetBird GmbH",
|
||||
"email": "hello@netbird.io",
|
||||
"phone": "",
|
||||
"description": "NetBird GmbH is a Berlin-based software company specializing in the development of open-source network security solutions. Network security is utterly complex and expensive, accessible only to companies with multi-million dollar IT budgets. In contrast, there are millions of companies left behind. Our mission is to create an advanced network and cybersecurity platform that is both easy-to-use and affordable for teams of all sizes and budgets. By leveraging the open-source strategy and technological advancements, NetBird aims to set the industry standard for connecting and securing IT infrastructure.",
|
||||
"webpageUrl": {
|
||||
"url": "https://github.com/netbirdio"
|
||||
}
|
||||
},
|
||||
"projects": [
|
||||
{
|
||||
"guid": "netbird",
|
||||
"name": "NetBird",
|
||||
"description": "NetBird is a configuration-free peer-to-peer private network and a centralized access control system combined in a single open-source platform. It makes it easy to create secure WireGuard-based private networks for your organization or home.",
|
||||
"webpageUrl": {
|
||||
"url": "https://github.com/netbirdio/netbird"
|
||||
},
|
||||
"repositoryUrl": {
|
||||
"url": "https://github.com/netbirdio/netbird"
|
||||
},
|
||||
"licenses": [
|
||||
"BSD-3"
|
||||
],
|
||||
"tags": [
|
||||
"network-security",
|
||||
"vpn",
|
||||
"developer-tools",
|
||||
"ztna",
|
||||
"zero-trust",
|
||||
"remote-access",
|
||||
"wireguard",
|
||||
"peer-to-peer",
|
||||
"private-networking",
|
||||
"software-defined-networking"
|
||||
]
|
||||
}
|
||||
],
|
||||
"funding": {
|
||||
"channels": [
|
||||
{
|
||||
"guid": "github-sponsors",
|
||||
"type": "payment-provider",
|
||||
"address": "https://github.com/sponsors/netbirdio",
|
||||
"description": ""
|
||||
},
|
||||
{
|
||||
"guid": "bank-transfer",
|
||||
"type": "bank",
|
||||
"address": "",
|
||||
"description": "Contact us at hello@netbird.io for bank transfer details."
|
||||
}
|
||||
],
|
||||
"plans": [
|
||||
{
|
||||
"guid": "support-yearly",
|
||||
"status": "active",
|
||||
"name": "Support Open Source Development and Maintenance - Yearly",
|
||||
"description": "This will help us partially cover the yearly cost of maintaining the open-source NetBird project.",
|
||||
"amount": 100000,
|
||||
"currency": "USD",
|
||||
"frequency": "yearly",
|
||||
"channels": [
|
||||
"github-sponsors",
|
||||
"bank-transfer"
|
||||
]
|
||||
},
|
||||
{
|
||||
"guid": "support-one-time-year",
|
||||
"status": "active",
|
||||
"name": "Support Open Source Development and Maintenance - One Year",
|
||||
"description": "This will help us partially cover the yearly cost of maintaining the open-source NetBird project.",
|
||||
"amount": 100000,
|
||||
"currency": "USD",
|
||||
"frequency": "one-time",
|
||||
"channels": [
|
||||
"github-sponsors",
|
||||
"bank-transfer"
|
||||
]
|
||||
},
|
||||
{
|
||||
"guid": "support-one-time-monthly",
|
||||
"status": "active",
|
||||
"name": "Support Open Source Development and Maintenance - Monthly",
|
||||
"description": "This will help us partially cover the monthly cost of maintaining the open-source NetBird project.",
|
||||
"amount": 10000,
|
||||
"currency": "USD",
|
||||
"frequency": "monthly",
|
||||
"channels": [
|
||||
"github-sponsors",
|
||||
"bank-transfer"
|
||||
]
|
||||
},
|
||||
{
|
||||
"guid": "support-monthly",
|
||||
"status": "active",
|
||||
"name": "Support Open Source Development and Maintenance - One Month",
|
||||
"description": "This will help us partially cover the monthly cost of maintaining the open-source NetBird project.",
|
||||
"amount": 10000,
|
||||
"currency": "USD",
|
||||
"frequency": "monthly",
|
||||
"channels": [
|
||||
"github-sponsors",
|
||||
"bank-transfer"
|
||||
]
|
||||
},
|
||||
{
|
||||
"guid": "goodwill",
|
||||
"status": "active",
|
||||
"name": "Goodwill Plan",
|
||||
"description": "Pay anything you wish to show your goodwill for the project.",
|
||||
"amount": 0,
|
||||
"currency": "USD",
|
||||
"frequency": "monthly",
|
||||
"channels": [
|
||||
"github-sponsors",
|
||||
"bank-transfer"
|
||||
]
|
||||
}
|
||||
],
|
||||
"history": null
|
||||
}
|
||||
}
|
||||
15
go.mod
15
go.mod
@@ -25,7 +25,7 @@ require (
|
||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
|
||||
golang.zx2c4.com/wireguard/windows v0.5.3
|
||||
google.golang.org/grpc v1.64.1
|
||||
google.golang.org/protobuf v1.34.2
|
||||
google.golang.org/protobuf v1.34.1
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.0.0
|
||||
)
|
||||
|
||||
@@ -60,7 +60,7 @@ require (
|
||||
github.com/miekg/dns v1.1.59
|
||||
github.com/mitchellh/hashstructure/v2 v2.0.2
|
||||
github.com/nadoo/ipset v0.5.0
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20241106153857-de8e2beb5254
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d
|
||||
github.com/okta/okta-sdk-golang/v2 v2.18.0
|
||||
github.com/oschwald/maxminddb-golang v1.12.0
|
||||
@@ -71,6 +71,7 @@ require (
|
||||
github.com/pion/transport/v3 v3.0.1
|
||||
github.com/pion/turn/v3 v3.0.1
|
||||
github.com/prometheus/client_golang v1.19.1
|
||||
github.com/r3labs/diff/v3 v3.0.1
|
||||
github.com/rs/xid v1.3.0
|
||||
github.com/shirou/gopsutil/v3 v3.24.4
|
||||
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
|
||||
@@ -155,7 +156,7 @@ require (
|
||||
github.com/go-text/typesetting v0.1.0 // indirect
|
||||
github.com/gogo/protobuf v1.3.2 // indirect
|
||||
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
|
||||
github.com/google/btree v1.1.2 // indirect
|
||||
github.com/google/btree v1.0.1 // indirect
|
||||
github.com/google/s2a-go v0.1.7 // indirect
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect
|
||||
github.com/googleapis/gax-go/v2 v2.12.3 // indirect
|
||||
@@ -210,6 +211,8 @@ require (
|
||||
github.com/tklauser/go-sysconf v0.3.14 // indirect
|
||||
github.com/tklauser/numcpus v0.8.0 // indirect
|
||||
github.com/vishvananda/netns v0.0.4 // indirect
|
||||
github.com/vmihailenco/msgpack/v5 v5.3.5 // indirect
|
||||
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
|
||||
github.com/yuin/goldmark v1.7.1 // indirect
|
||||
github.com/zeebo/blake3 v0.2.3 // indirect
|
||||
go.opencensus.io v0.24.0 // indirect
|
||||
@@ -224,11 +227,11 @@ require (
|
||||
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20240509183442-62759503f434 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20240515191416-fc5f0ca64291 // indirect
|
||||
gopkg.in/square/go-jose.v2 v2.6.0 // indirect
|
||||
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect
|
||||
gopkg.in/tomb.v2 v2.0.0-20161208151619-d5d1b5820637 // indirect
|
||||
gvisor.dev/gvisor v0.0.0-20231020174304-b8a429915ff1 // indirect
|
||||
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 // indirect
|
||||
k8s.io/apimachinery v0.26.2 // indirect
|
||||
)
|
||||
|
||||
@@ -236,7 +239,7 @@ replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-2024
|
||||
|
||||
replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949
|
||||
|
||||
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20241125150134-f9cdce5e32e9
|
||||
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed
|
||||
|
||||
replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6
|
||||
|
||||
|
||||
30
go.sum
30
go.sum
@@ -297,8 +297,8 @@ github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek
|
||||
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
|
||||
github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
|
||||
github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
|
||||
github.com/google/btree v1.1.2 h1:xf4v41cLI2Z6FxbKm+8Bu+m8ifhj15JuZ9sa0jZCMUU=
|
||||
github.com/google/btree v1.1.2/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4=
|
||||
github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4=
|
||||
github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA=
|
||||
github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
|
||||
github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
|
||||
github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
|
||||
@@ -521,14 +521,14 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
|
||||
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
|
||||
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c=
|
||||
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20241106153857-de8e2beb5254 h1:L8mNd3tBxMdnQNxMNJ+/EiwHwizNOMy8/nHLVGNfjpg=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20241106153857-de8e2beb5254/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd h1:phKq1S1Y/lnqEhP5Qknta733+rPX16dRDHM7hKkot9c=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y=
|
||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
|
||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28=
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
|
||||
github.com/netbirdio/wireguard-go v0.0.0-20241125150134-f9cdce5e32e9 h1:Pu/7EukijT09ynHUOzQYW7cC3M/BKU8O4qyN/TvTGoY=
|
||||
github.com/netbirdio/wireguard-go v0.0.0-20241125150134-f9cdce5e32e9/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
|
||||
github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed h1:t0UADZUJDaaZgfKrt8JUPrOLL9Mg/ryjP85RAH53qgs=
|
||||
github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
|
||||
github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM=
|
||||
github.com/nicksnyder/go-i18n/v2 v2.4.0/go.mod h1:nxYSZE9M0bf3Y70gPQjN9ha7XNHX7gMc814+6wVyEI4=
|
||||
github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A=
|
||||
@@ -605,6 +605,8 @@ github.com/prometheus/common v0.53.0 h1:U2pL9w9nmJwJDa4qqLQ3ZaePJ6ZTwt7cMD3AG3+a
|
||||
github.com/prometheus/common v0.53.0/go.mod h1:BrxBKv3FWBIGXw89Mg1AeBq7FSyRzXWI3l3e7W3RN5U=
|
||||
github.com/prometheus/procfs v0.15.0 h1:A82kmvXJq2jTu5YUhSGNlYoxh85zLnKgPz4bMZgI5Ek=
|
||||
github.com/prometheus/procfs v0.15.0/go.mod h1:Y0RJ/Y5g5wJpkTisOtqwDSo4HwhGmLB4VQSw2sQJLHk=
|
||||
github.com/r3labs/diff/v3 v3.0.1 h1:CBKqf3XmNRHXKmdU7mZP1w7TV0pDyVCis1AUHtA4Xtg=
|
||||
github.com/r3labs/diff/v3 v3.0.1/go.mod h1:f1S9bourRbiM66NskseyUdo0fTmEE0qKrikYJX63dgo=
|
||||
github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ=
|
||||
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
|
||||
github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M=
|
||||
@@ -697,6 +699,10 @@ github.com/vishvananda/netlink v1.2.1-beta.2/go.mod h1:twkDnbuQxJYemMlGd4JFIcuhg
|
||||
github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0=
|
||||
github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
|
||||
github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
|
||||
github.com/vmihailenco/msgpack/v5 v5.3.5 h1:5gO0H1iULLWGhs2H5tbAHIZTV8/cYafcFOr9znI5mJU=
|
||||
github.com/vmihailenco/msgpack/v5 v5.3.5/go.mod h1:7xyJ9e+0+9SaZT0Wt1RGleJXzli6Q/V5KbhBonMG9jc=
|
||||
github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
|
||||
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
|
||||
github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||
github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||
@@ -1151,8 +1157,8 @@ google.golang.org/genproto v0.0.0-20210402141018-6c239bbf2bb1/go.mod h1:9lPAdzaE
|
||||
google.golang.org/genproto v0.0.0-20210602131652-f16073e35f0c/go.mod h1:UODoCrxHCcBojKKwX1terBiRUaqAsFqJiF615XL43r0=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20240509183442-62759503f434 h1:OpXbo8JnN8+jZGPrL4SSfaDjSCjupr8lXyBAbexEm/U=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20240509183442-62759503f434/go.mod h1:FfiGhwUm6CJviekPrc0oJ+7h29e+DmWU6UtjX0ZvI7Y=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 h1:pPJltXNxVzT4pK9yD8vR9X75DaWYYmLGMsEvBfFQZzQ=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1/go.mod h1:UqMtugtsSgubUsoxbuAoiCXvqvErP7Gf0so0mK9tHxU=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20240515191416-fc5f0ca64291 h1:AgADTJarZTBqgjiUzRgfaBchgYB3/WFTC80GPwsMcRI=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20240515191416-fc5f0ca64291/go.mod h1:EfXuqaE1J41VCDicxHzUDm+8rk+7ZdXzHV0IhO/I6s0=
|
||||
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
|
||||
google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38=
|
||||
google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM=
|
||||
@@ -1189,8 +1195,8 @@ google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp0
|
||||
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
|
||||
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
|
||||
google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
|
||||
google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg=
|
||||
google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw=
|
||||
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=
|
||||
google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
@@ -1232,8 +1238,8 @@ gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde h1:9DShaph9qhkIYw7QF91I/ynrr4
|
||||
gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
|
||||
gotest.tools/v3 v3.5.0 h1:Ljk6PdHdOhAb5aDMWXjDLMMhph+BpztA4v1QdqEW2eY=
|
||||
gotest.tools/v3 v3.5.0/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU=
|
||||
gvisor.dev/gvisor v0.0.0-20231020174304-b8a429915ff1 h1:qDCwdCWECGnwQSQC01Dpnp09fRHxJs9PbktotUqG+hs=
|
||||
gvisor.dev/gvisor v0.0.0-20231020174304-b8a429915ff1/go.mod h1:8hmigyCdYtw5xJGfQDJzSH5Ju8XEIDBnpyi8+O6GRt8=
|
||||
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 h1:TbRPT0HtzFP3Cno1zZo7yPzEEnfu8EjLfl6IU9VfqkQ=
|
||||
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259/go.mod h1:AVgIgHMwK63XvmAzWG9vLQ41YnVHN0du0tEC46fI7yY=
|
||||
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
|
||||
honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
|
||||
honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
|
||||
|
||||
@@ -530,7 +530,7 @@ renderCaddyfile() {
|
||||
{
|
||||
debug
|
||||
servers :80,:443 {
|
||||
protocols h1 h2c h3
|
||||
protocols h1 h2c
|
||||
}
|
||||
}
|
||||
|
||||
@@ -788,7 +788,6 @@ services:
|
||||
networks: [ netbird ]
|
||||
ports:
|
||||
- '443:443'
|
||||
- '443:443/udp'
|
||||
- '80:80'
|
||||
- '8080:8080'
|
||||
volumes:
|
||||
|
||||
@@ -110,10 +110,11 @@ type AccountManager interface {
|
||||
SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error
|
||||
DeleteGroup(ctx context.Context, accountId, userId, groupID string) error
|
||||
DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error
|
||||
ListGroups(ctx context.Context, accountId string) ([]*nbgroup.Group, error)
|
||||
GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error
|
||||
GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error
|
||||
GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error)
|
||||
SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) (*Policy, error)
|
||||
SavePolicy(ctx context.Context, accountID, userID string, policy *Policy, isUpdate bool) error
|
||||
DeletePolicy(ctx context.Context, accountID, policyID, userID string) error
|
||||
ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error)
|
||||
GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error)
|
||||
@@ -139,7 +140,7 @@ type AccountManager interface {
|
||||
HasConnectedChannel(peerID string) bool
|
||||
GetExternalCacheManager() ExternalCacheManager
|
||||
GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error)
|
||||
SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error)
|
||||
SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error
|
||||
DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error
|
||||
ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error)
|
||||
GetIdpManager() idp.Manager
|
||||
@@ -965,9 +966,7 @@ func (am *DefaultAccountManager) getJWTGroupsChanges(user *User, groups []*nbgro
|
||||
}
|
||||
|
||||
// UserGroupsAddToPeers adds groups to all peers of user
|
||||
func (a *Account) UserGroupsAddToPeers(userID string, groups ...string) map[string][]string {
|
||||
groupUpdates := make(map[string][]string)
|
||||
|
||||
func (a *Account) UserGroupsAddToPeers(userID string, groups ...string) {
|
||||
userPeers := make(map[string]struct{})
|
||||
for pid, peer := range a.Peers {
|
||||
if peer.UserID == userID {
|
||||
@@ -981,8 +980,6 @@ func (a *Account) UserGroupsAddToPeers(userID string, groups ...string) map[stri
|
||||
continue
|
||||
}
|
||||
|
||||
oldPeers := group.Peers
|
||||
|
||||
groupPeers := make(map[string]struct{})
|
||||
for _, pid := range group.Peers {
|
||||
groupPeers[pid] = struct{}{}
|
||||
@@ -996,25 +993,16 @@ func (a *Account) UserGroupsAddToPeers(userID string, groups ...string) map[stri
|
||||
for pid := range groupPeers {
|
||||
group.Peers = append(group.Peers, pid)
|
||||
}
|
||||
|
||||
groupUpdates[gid] = difference(group.Peers, oldPeers)
|
||||
}
|
||||
|
||||
return groupUpdates
|
||||
}
|
||||
|
||||
// UserGroupsRemoveFromPeers removes groups from all peers of user
|
||||
func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) map[string][]string {
|
||||
groupUpdates := make(map[string][]string)
|
||||
|
||||
func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) {
|
||||
for _, gid := range groups {
|
||||
group, ok := a.Groups[gid]
|
||||
if !ok || group.Name == "All" {
|
||||
continue
|
||||
}
|
||||
|
||||
oldPeers := group.Peers
|
||||
|
||||
update := make([]string, 0, len(group.Peers))
|
||||
for _, pid := range group.Peers {
|
||||
peer, ok := a.Peers[pid]
|
||||
@@ -1026,10 +1014,7 @@ func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) map
|
||||
}
|
||||
}
|
||||
group.Peers = update
|
||||
groupUpdates[gid] = difference(oldPeers, group.Peers)
|
||||
}
|
||||
|
||||
return groupUpdates
|
||||
}
|
||||
|
||||
// BuildManager creates a new DefaultAccountManager with a provided Store
|
||||
@@ -1191,11 +1176,6 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = am.handleGroupsPropagationSettings(ctx, oldSettings, newSettings, userID, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("groups propagation failed: %w", err)
|
||||
}
|
||||
|
||||
updatedAccount := account.UpdateSettings(newSettings)
|
||||
|
||||
err = am.Store.SaveAccount(ctx, account)
|
||||
@@ -1206,39 +1186,21 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
||||
return updatedAccount, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) handleGroupsPropagationSettings(ctx context.Context, oldSettings, newSettings *Settings, userID, accountID string) error {
|
||||
if oldSettings.GroupsPropagationEnabled != newSettings.GroupsPropagationEnabled {
|
||||
if newSettings.GroupsPropagationEnabled {
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.UserGroupPropagationEnabled, nil)
|
||||
// Todo: retroactively add user groups to all peers
|
||||
} else {
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.UserGroupPropagationDisabled, nil)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, account *Account, oldSettings, newSettings *Settings, userID, accountID string) error {
|
||||
|
||||
if newSettings.PeerInactivityExpirationEnabled {
|
||||
if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration {
|
||||
oldSettings.PeerInactivityExpiration = newSettings.PeerInactivityExpiration
|
||||
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerInactivityExpirationDurationUpdated, nil)
|
||||
if oldSettings.PeerInactivityExpirationEnabled != newSettings.PeerInactivityExpirationEnabled {
|
||||
event := activity.AccountPeerInactivityExpirationEnabled
|
||||
if !newSettings.PeerInactivityExpirationEnabled {
|
||||
event = activity.AccountPeerInactivityExpirationDisabled
|
||||
am.peerInactivityExpiry.Cancel(ctx, []string{accountID})
|
||||
} else {
|
||||
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
|
||||
}
|
||||
} else {
|
||||
if oldSettings.PeerInactivityExpirationEnabled != newSettings.PeerInactivityExpirationEnabled {
|
||||
event := activity.AccountPeerInactivityExpirationEnabled
|
||||
if !newSettings.PeerInactivityExpirationEnabled {
|
||||
event = activity.AccountPeerInactivityExpirationDisabled
|
||||
am.peerInactivityExpiry.Cancel(ctx, []string{accountID})
|
||||
} else {
|
||||
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
|
||||
}
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, event, nil)
|
||||
}
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, event, nil)
|
||||
}
|
||||
|
||||
if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration {
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerInactivityExpirationDurationUpdated, nil)
|
||||
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -1287,7 +1249,7 @@ func (am *DefaultAccountManager) peerInactivityExpirationJob(ctx context.Context
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
log.Errorf("failed getting account %s expiring peers", accountID)
|
||||
log.Errorf("failed getting account %s expiring peers", account.Id)
|
||||
return account.GetNextInactivePeerExpiration()
|
||||
}
|
||||
|
||||
@@ -1473,7 +1435,7 @@ func isNil(i idp.Manager) bool {
|
||||
// addAccountIDToIDPAppMeta update user's app metadata in idp manager
|
||||
func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, accountID string) error {
|
||||
if !isNil(am.idpManager) {
|
||||
accountUsers, err := am.Store.GetAccountUsers(ctx, LockingStrengthShare, accountID)
|
||||
accountUsers, err := am.Store.GetAccountUsers(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -2067,7 +2029,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
|
||||
return fmt.Errorf("error getting user: %w", err)
|
||||
}
|
||||
|
||||
groups, err := transaction.GetAccountGroups(ctx, LockingStrengthShare, accountID)
|
||||
groups, err := transaction.GetAccountGroups(ctx, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting account groups: %w", err)
|
||||
}
|
||||
@@ -2097,7 +2059,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
|
||||
|
||||
// Propagate changes to peers if group propagation is enabled
|
||||
if settings.GroupsPropagationEnabled {
|
||||
groups, err = transaction.GetAccountGroups(ctx, LockingStrengthShare, accountID)
|
||||
groups, err = transaction.GetAccountGroups(ctx, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting account groups: %w", err)
|
||||
}
|
||||
@@ -2121,7 +2083,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
|
||||
return fmt.Errorf("error saving groups: %w", err)
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
|
||||
return fmt.Errorf("error incrementing network serial: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -2139,7 +2101,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
|
||||
}
|
||||
|
||||
for _, g := range addNewGroups {
|
||||
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, g)
|
||||
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID)
|
||||
} else {
|
||||
@@ -2152,7 +2114,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
|
||||
}
|
||||
|
||||
for _, g := range removeOldGroups {
|
||||
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, g)
|
||||
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID)
|
||||
} else {
|
||||
@@ -2165,19 +2127,14 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
|
||||
}
|
||||
|
||||
if settings.GroupsPropagationEnabled {
|
||||
removedGroupAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, accountID, removeOldGroups)
|
||||
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("error getting account: %w", err)
|
||||
}
|
||||
|
||||
newGroupsAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, accountID, addNewGroups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if removedGroupAffectsPeers || newGroupsAffectsPeers {
|
||||
if areGroupChangesAffectPeers(account, addNewGroups) || areGroupChangesAffectPeers(account, removeOldGroups) {
|
||||
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId)
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
am.updateAccountPeers(ctx, account)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2333,12 +2290,12 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, status.NewGetAccountError(err)
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
peer, netMap, postureChecks, err := am.SyncPeer(ctx, PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, account)
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("error syncing peer: %w", err)
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
err = am.MarkPeerConnected(ctx, peerPubKey, true, realIP, account)
|
||||
@@ -2357,12 +2314,12 @@ func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, account
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return status.NewGetAccountError(err)
|
||||
return err
|
||||
}
|
||||
|
||||
err = am.MarkPeerConnected(ctx, peerPubKey, false, nil, account)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed marking peer as disconnected %s %v", peerPubKey, err)
|
||||
log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -2378,9 +2335,6 @@ func (am *DefaultAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey st
|
||||
unlock := am.Store.AcquireReadLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
unlockPeer := am.Store.AcquireWriteLockByUID(ctx, peerPubKey)
|
||||
defer unlockPeer()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -2444,7 +2398,12 @@ func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(ctx context.Context,
|
||||
|
||||
func (am *DefaultAccountManager) onPeersInvalidated(ctx context.Context, accountID string) {
|
||||
log.WithContext(ctx).Debugf("validated peers has been invalidated for account %s", accountID)
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
updatedAccount, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get account %s: %v", accountID, err)
|
||||
return
|
||||
}
|
||||
am.updateAccountPeers(ctx, updatedAccount)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) {
|
||||
|
||||
@@ -6,17 +6,13 @@ import (
|
||||
b64 "encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
@@ -33,18 +29,14 @@ import (
|
||||
)
|
||||
|
||||
type MocIntegratedValidator struct {
|
||||
ValidatePeerFunc func(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error)
|
||||
}
|
||||
|
||||
func (a MocIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error) {
|
||||
if a.ValidatePeerFunc != nil {
|
||||
return a.ValidatePeerFunc(context.Background(), update, peer, userID, accountID, dnsDomain, peersGroup, extraSettings)
|
||||
}
|
||||
return update, false, nil
|
||||
func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, error) {
|
||||
return update, nil
|
||||
}
|
||||
func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*group.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) {
|
||||
validatedPeers := make(map[string]struct{})
|
||||
@@ -986,110 +978,6 @@ func TestAccountManager_DeleteAccount(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkTest_GetAccountWithclaims(b *testing.B) {
|
||||
claims := jwtclaims.AuthorizationClaims{
|
||||
Domain: "example.com",
|
||||
UserId: "pvt-domain-user",
|
||||
DomainCategory: PrivateCategory,
|
||||
}
|
||||
|
||||
publicClaims := jwtclaims.AuthorizationClaims{
|
||||
Domain: "test.com",
|
||||
UserId: "public-domain-user",
|
||||
DomainCategory: PublicCategory,
|
||||
}
|
||||
|
||||
am, err := createManager(b)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
return
|
||||
}
|
||||
id, err := am.getAccountIDWithAuthorizationClaims(context.Background(), claims)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
pid, err := am.getAccountIDWithAuthorizationClaims(context.Background(), publicClaims)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
users := genUsers("priv", 100)
|
||||
|
||||
acc, err := am.Store.GetAccount(context.Background(), id)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
acc.Users = users
|
||||
|
||||
err = am.Store.SaveAccount(context.Background(), acc)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
userP := genUsers("pub", 100)
|
||||
|
||||
pacc, err := am.Store.GetAccount(context.Background(), pid)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
pacc.Users = userP
|
||||
|
||||
err = am.Store.SaveAccount(context.Background(), pacc)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
b.Run("public without account ID", func(b *testing.B) {
|
||||
// b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := am.getAccountIDWithAuthorizationClaims(context.Background(), publicClaims)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("private without account ID", func(b *testing.B) {
|
||||
// b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := am.getAccountIDWithAuthorizationClaims(context.Background(), claims)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("private with account ID", func(b *testing.B) {
|
||||
claims.AccountId = id
|
||||
// b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := am.getAccountIDWithAuthorizationClaims(context.Background(), claims)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func genUsers(p string, n int) map[string]*User {
|
||||
users := map[string]*User{}
|
||||
now := time.Now()
|
||||
for i := 0; i < n; i++ {
|
||||
users[fmt.Sprintf("%s-%d", p, i)] = &User{
|
||||
Id: fmt.Sprintf("%s-%d", p, i),
|
||||
Role: UserRoleAdmin,
|
||||
LastLogin: now,
|
||||
CreatedAt: now,
|
||||
Issued: "api",
|
||||
AutoGroups: []string{"one", "two", "three", "four", "five", "six", "seven", "eight", "nine", "ten"},
|
||||
}
|
||||
}
|
||||
return users
|
||||
}
|
||||
|
||||
func TestAccountManager_AddPeer(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
if err != nil {
|
||||
@@ -1242,7 +1130,8 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
_, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
|
||||
policy := Policy{
|
||||
ID: "policy",
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
@@ -1253,7 +1142,8 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
@@ -1322,6 +1212,19 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
|
||||
policy := Policy{
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupA"},
|
||||
Bidirectional: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
@@ -1334,19 +1237,7 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
|
||||
}
|
||||
}()
|
||||
|
||||
_, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupA"},
|
||||
Bidirectional: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil {
|
||||
t.Errorf("delete default rule: %v", err)
|
||||
return
|
||||
}
|
||||
@@ -1367,7 +1258,7 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
_, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
|
||||
policy := Policy{
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
@@ -1378,8 +1269,9 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
}
|
||||
|
||||
if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil {
|
||||
t.Errorf("save policy: %v", err)
|
||||
return
|
||||
}
|
||||
@@ -1413,20 +1305,13 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
|
||||
err := manager.SaveGroup(context.Background(), account.Id, userID, &group.Group{
|
||||
group := group.Group{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
|
||||
})
|
||||
|
||||
require.NoError(t, err, "failed to save group")
|
||||
|
||||
if err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID); err != nil {
|
||||
t.Errorf("delete default rule: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
policy, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
|
||||
policy := Policy{
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
@@ -1437,8 +1322,14 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
}
|
||||
|
||||
if err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID); err != nil {
|
||||
t.Errorf("delete default rule: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil {
|
||||
t.Errorf("save policy: %v", err)
|
||||
return
|
||||
}
|
||||
@@ -1461,7 +1352,7 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
if err := manager.DeleteGroup(context.Background(), account.Id, userID, "groupA"); err != nil {
|
||||
if err := manager.DeleteGroup(context.Background(), account.Id, "", group.ID); err != nil {
|
||||
t.Errorf("delete group: %v", err)
|
||||
return
|
||||
}
|
||||
@@ -2715,7 +2606,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
assert.NoError(t, err, "unable to get user")
|
||||
assert.Len(t, user.AutoGroups, 0)
|
||||
|
||||
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "accountID", "group1")
|
||||
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "group1", "accountID")
|
||||
assert.NoError(t, err, "unable to get group")
|
||||
assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued")
|
||||
})
|
||||
@@ -2735,7 +2626,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
assert.NoError(t, err, "unable to get user")
|
||||
assert.Len(t, user.AutoGroups, 1)
|
||||
|
||||
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "accountID", "group1")
|
||||
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "group1", "accountID")
|
||||
assert.NoError(t, err, "unable to get group")
|
||||
assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued")
|
||||
})
|
||||
@@ -2774,7 +2665,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
err = manager.syncJWTGroups(context.Background(), "accountID", claims)
|
||||
assert.NoError(t, err, "unable to sync jwt groups")
|
||||
|
||||
groups, err := manager.Store.GetAccountGroups(context.Background(), LockingStrengthShare, "accountID")
|
||||
groups, err := manager.Store.GetAccountGroups(context.Background(), "accountID")
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, groups, 3, "new group3 should be added")
|
||||
|
||||
@@ -2986,218 +2877,3 @@ func peerShouldReceiveUpdate(t *testing.T, updateMessage <-chan *UpdateMessage)
|
||||
t.Error("Timed out waiting for update message")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSyncAndMarkPeer(b *testing.B) {
|
||||
benchCases := []struct {
|
||||
name string
|
||||
peers int
|
||||
groups int
|
||||
// We need different expectations for CI/CD and local runs because of the different performance characteristics
|
||||
minMsPerOpLocal float64
|
||||
maxMsPerOpLocal float64
|
||||
minMsPerOpCICD float64
|
||||
maxMsPerOpCICD float64
|
||||
}{
|
||||
{"Small", 50, 5, 1, 3, 4, 10},
|
||||
{"Medium", 500, 100, 7, 13, 10, 60},
|
||||
{"Large", 5000, 200, 65, 80, 60, 170},
|
||||
{"Small single", 50, 10, 1, 3, 4, 60},
|
||||
{"Medium single", 500, 10, 7, 13, 10, 26},
|
||||
{"Large 5", 5000, 15, 65, 80, 60, 170},
|
||||
}
|
||||
|
||||
log.SetOutput(io.Discard)
|
||||
defer log.SetOutput(os.Stderr)
|
||||
for _, bc := range benchCases {
|
||||
b.Run(bc.name, func(b *testing.B) {
|
||||
manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups)
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to setup test account manager: %v", err)
|
||||
}
|
||||
ctx := context.Background()
|
||||
account, err := manager.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to get account: %v", err)
|
||||
}
|
||||
peerChannels := make(map[string]chan *UpdateMessage)
|
||||
for peerID := range account.Peers {
|
||||
peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize)
|
||||
}
|
||||
manager.peersUpdateManager.peerChannels = peerChannels
|
||||
|
||||
b.ResetTimer()
|
||||
start := time.Now()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _, _, err := manager.SyncAndMarkPeer(context.Background(), account.Id, account.Peers["peer-1"].Key, nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)}, net.IP{1, 1, 1, 1})
|
||||
assert.NoError(b, err)
|
||||
}
|
||||
|
||||
duration := time.Since(start)
|
||||
msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6
|
||||
b.ReportMetric(msPerOp, "ms/op")
|
||||
|
||||
minExpected := bc.minMsPerOpLocal
|
||||
maxExpected := bc.maxMsPerOpLocal
|
||||
if os.Getenv("CI") == "true" {
|
||||
minExpected = bc.minMsPerOpCICD
|
||||
maxExpected = bc.maxMsPerOpCICD
|
||||
}
|
||||
|
||||
if msPerOp < minExpected {
|
||||
b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, minExpected)
|
||||
}
|
||||
|
||||
if msPerOp > maxExpected {
|
||||
b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, maxExpected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkLoginPeer_ExistingPeer(b *testing.B) {
|
||||
benchCases := []struct {
|
||||
name string
|
||||
peers int
|
||||
groups int
|
||||
// We need different expectations for CI/CD and local runs because of the different performance characteristics
|
||||
minMsPerOpLocal float64
|
||||
maxMsPerOpLocal float64
|
||||
minMsPerOpCICD float64
|
||||
maxMsPerOpCICD float64
|
||||
}{
|
||||
{"Small", 50, 5, 102, 110, 102, 120},
|
||||
{"Medium", 500, 100, 105, 140, 105, 170},
|
||||
{"Large", 5000, 200, 160, 200, 160, 270},
|
||||
{"Small single", 50, 10, 102, 110, 102, 120},
|
||||
{"Medium single", 500, 10, 105, 140, 105, 170},
|
||||
{"Large 5", 5000, 15, 160, 200, 160, 270},
|
||||
}
|
||||
|
||||
log.SetOutput(io.Discard)
|
||||
defer log.SetOutput(os.Stderr)
|
||||
for _, bc := range benchCases {
|
||||
b.Run(bc.name, func(b *testing.B) {
|
||||
manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups)
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to setup test account manager: %v", err)
|
||||
}
|
||||
ctx := context.Background()
|
||||
account, err := manager.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to get account: %v", err)
|
||||
}
|
||||
peerChannels := make(map[string]chan *UpdateMessage)
|
||||
for peerID := range account.Peers {
|
||||
peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize)
|
||||
}
|
||||
manager.peersUpdateManager.peerChannels = peerChannels
|
||||
|
||||
b.ResetTimer()
|
||||
start := time.Now()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _, _, err := manager.LoginPeer(context.Background(), PeerLogin{
|
||||
WireGuardPubKey: account.Peers["peer-1"].Key,
|
||||
SSHKey: "someKey",
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)},
|
||||
UserID: "regular_user",
|
||||
SetupKey: "",
|
||||
ConnectionIP: net.IP{1, 1, 1, 1},
|
||||
})
|
||||
assert.NoError(b, err)
|
||||
}
|
||||
|
||||
duration := time.Since(start)
|
||||
msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6
|
||||
b.ReportMetric(msPerOp, "ms/op")
|
||||
|
||||
minExpected := bc.minMsPerOpLocal
|
||||
maxExpected := bc.maxMsPerOpLocal
|
||||
if os.Getenv("CI") == "true" {
|
||||
minExpected = bc.minMsPerOpCICD
|
||||
maxExpected = bc.maxMsPerOpCICD
|
||||
}
|
||||
|
||||
if msPerOp < minExpected {
|
||||
b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, minExpected)
|
||||
}
|
||||
|
||||
if msPerOp > maxExpected {
|
||||
b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, maxExpected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkLoginPeer_NewPeer(b *testing.B) {
|
||||
benchCases := []struct {
|
||||
name string
|
||||
peers int
|
||||
groups int
|
||||
// We need different expectations for CI/CD and local runs because of the different performance characteristics
|
||||
minMsPerOpLocal float64
|
||||
maxMsPerOpLocal float64
|
||||
minMsPerOpCICD float64
|
||||
maxMsPerOpCICD float64
|
||||
}{
|
||||
{"Small", 50, 5, 107, 120, 107, 140},
|
||||
{"Medium", 500, 100, 105, 140, 105, 170},
|
||||
{"Large", 5000, 200, 180, 220, 180, 320},
|
||||
{"Small single", 50, 10, 107, 120, 105, 140},
|
||||
{"Medium single", 500, 10, 105, 140, 105, 170},
|
||||
{"Large 5", 5000, 15, 180, 220, 180, 320},
|
||||
}
|
||||
|
||||
log.SetOutput(io.Discard)
|
||||
defer log.SetOutput(os.Stderr)
|
||||
for _, bc := range benchCases {
|
||||
b.Run(bc.name, func(b *testing.B) {
|
||||
manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups)
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to setup test account manager: %v", err)
|
||||
}
|
||||
ctx := context.Background()
|
||||
account, err := manager.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to get account: %v", err)
|
||||
}
|
||||
peerChannels := make(map[string]chan *UpdateMessage)
|
||||
for peerID := range account.Peers {
|
||||
peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize)
|
||||
}
|
||||
manager.peersUpdateManager.peerChannels = peerChannels
|
||||
|
||||
b.ResetTimer()
|
||||
start := time.Now()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _, _, err := manager.LoginPeer(context.Background(), PeerLogin{
|
||||
WireGuardPubKey: "some-new-key" + strconv.Itoa(i),
|
||||
SSHKey: "someKey",
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)},
|
||||
UserID: "regular_user",
|
||||
SetupKey: "",
|
||||
ConnectionIP: net.IP{1, 1, 1, 1},
|
||||
})
|
||||
assert.NoError(b, err)
|
||||
}
|
||||
|
||||
duration := time.Since(start)
|
||||
msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6
|
||||
b.ReportMetric(msPerOp, "ms/op")
|
||||
|
||||
minExpected := bc.minMsPerOpLocal
|
||||
maxExpected := bc.maxMsPerOpLocal
|
||||
if os.Getenv("CI") == "true" {
|
||||
minExpected = bc.minMsPerOpCICD
|
||||
maxExpected = bc.maxMsPerOpCICD
|
||||
}
|
||||
|
||||
if msPerOp < minExpected {
|
||||
b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, minExpected)
|
||||
}
|
||||
|
||||
if msPerOp > maxExpected {
|
||||
b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, maxExpected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -148,9 +148,6 @@ const (
|
||||
AccountPeerInactivityExpirationDurationUpdated Activity = 67
|
||||
|
||||
SetupKeyDeleted Activity = 68
|
||||
|
||||
UserGroupPropagationEnabled Activity = 69
|
||||
UserGroupPropagationDisabled Activity = 70
|
||||
)
|
||||
|
||||
var activityMap = map[Activity]Code{
|
||||
@@ -225,9 +222,6 @@ var activityMap = map[Activity]Code{
|
||||
AccountPeerInactivityExpirationDisabled: {"Account peer inactivity expiration disabled", "account.peer.inactivity.expiration.disable"},
|
||||
AccountPeerInactivityExpirationDurationUpdated: {"Account peer inactivity expiration duration updated", "account.peer.inactivity.expiration.update"},
|
||||
SetupKeyDeleted: {"Setup key deleted", "setupkey.delete"},
|
||||
|
||||
UserGroupPropagationEnabled: {"User group propagation enabled", "account.setting.group.propagation.enable"},
|
||||
UserGroupPropagationDisabled: {"User group propagation disabled", "account.setting.group.propagation.disable"},
|
||||
}
|
||||
|
||||
// StringCode returns a string code of the activity
|
||||
|
||||
82
management/server/differs/netip.go
Normal file
82
management/server/differs/netip.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package differs
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"reflect"
|
||||
|
||||
"github.com/r3labs/diff/v3"
|
||||
)
|
||||
|
||||
// NetIPAddr is a custom differ for netip.Addr
|
||||
type NetIPAddr struct {
|
||||
DiffFunc func(path []string, a, b reflect.Value, p interface{}) error
|
||||
}
|
||||
|
||||
func (differ NetIPAddr) Match(a, b reflect.Value) bool {
|
||||
return diff.AreType(a, b, reflect.TypeOf(netip.Addr{}))
|
||||
}
|
||||
|
||||
func (differ NetIPAddr) Diff(_ diff.DiffType, _ diff.DiffFunc, cl *diff.Changelog, path []string, a, b reflect.Value, _ interface{}) error {
|
||||
if a.Kind() == reflect.Invalid {
|
||||
cl.Add(diff.CREATE, path, nil, b.Interface())
|
||||
return nil
|
||||
}
|
||||
|
||||
if b.Kind() == reflect.Invalid {
|
||||
cl.Add(diff.DELETE, path, a.Interface(), nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
fromAddr, ok1 := a.Interface().(netip.Addr)
|
||||
toAddr, ok2 := b.Interface().(netip.Addr)
|
||||
if !ok1 || !ok2 {
|
||||
return fmt.Errorf("invalid type for netip.Addr")
|
||||
}
|
||||
|
||||
if fromAddr.String() != toAddr.String() {
|
||||
cl.Add(diff.UPDATE, path, fromAddr.String(), toAddr.String())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (differ NetIPAddr) InsertParentDiffer(dfunc func(path []string, a, b reflect.Value, p interface{}) error) {
|
||||
differ.DiffFunc = dfunc //nolint
|
||||
}
|
||||
|
||||
// NetIPPrefix is a custom differ for netip.Prefix
|
||||
type NetIPPrefix struct {
|
||||
DiffFunc func(path []string, a, b reflect.Value, p interface{}) error
|
||||
}
|
||||
|
||||
func (differ NetIPPrefix) Match(a, b reflect.Value) bool {
|
||||
return diff.AreType(a, b, reflect.TypeOf(netip.Prefix{}))
|
||||
}
|
||||
|
||||
func (differ NetIPPrefix) Diff(_ diff.DiffType, _ diff.DiffFunc, cl *diff.Changelog, path []string, a, b reflect.Value, _ interface{}) error {
|
||||
if a.Kind() == reflect.Invalid {
|
||||
cl.Add(diff.CREATE, path, nil, b.Interface())
|
||||
return nil
|
||||
}
|
||||
if b.Kind() == reflect.Invalid {
|
||||
cl.Add(diff.DELETE, path, a.Interface(), nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
fromPrefix, ok1 := a.Interface().(netip.Prefix)
|
||||
toPrefix, ok2 := b.Interface().(netip.Prefix)
|
||||
if !ok1 || !ok2 {
|
||||
return fmt.Errorf("invalid type for netip.Addr")
|
||||
}
|
||||
|
||||
if fromPrefix.String() != toPrefix.String() {
|
||||
cl.Add(diff.UPDATE, path, fromPrefix.String(), toPrefix.String())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (differ NetIPPrefix) InsertParentDiffer(dfunc func(path []string, a, b reflect.Value, p interface{}) error) {
|
||||
differ.DiffFunc = dfunc //nolint
|
||||
}
|
||||
@@ -3,7 +3,6 @@ package server
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
||||
@@ -86,12 +85,8 @@ func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID s
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if user.AccountID != accountID {
|
||||
return nil, status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
if user.IsRegularUser() {
|
||||
return nil, status.NewAdminPermissionError()
|
||||
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
|
||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view DNS settings")
|
||||
}
|
||||
|
||||
return am.Store.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID)
|
||||
@@ -99,137 +94,64 @@ func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID s
|
||||
|
||||
// SaveDNSSettings validates a user role and updates the account's DNS settings
|
||||
func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *DNSSettings) error {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
user, err := account.FindUser(userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !user.HasAdminPower() {
|
||||
return status.Errorf(status.PermissionDenied, "only users with admin power are allowed to update DNS settings")
|
||||
}
|
||||
|
||||
if dnsSettingsToSave == nil {
|
||||
return status.Errorf(status.InvalidArgument, "the dns settings provided are nil")
|
||||
}
|
||||
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if user.AccountID != accountID {
|
||||
return status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
if !user.HasAdminPower() {
|
||||
return status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
var updateAccountPeers bool
|
||||
var eventsToStore []func()
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
if err = validateDNSSettings(ctx, transaction, accountID, dnsSettingsToSave); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
oldSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthUpdate, accountID)
|
||||
if len(dnsSettingsToSave.DisabledManagementGroups) != 0 {
|
||||
err = validateGroups(dnsSettingsToSave.DisabledManagementGroups, account.Groups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
addedGroups := difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups)
|
||||
removedGroups := difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups)
|
||||
oldSettings := account.DNSSettings.Copy()
|
||||
account.DNSSettings = dnsSettingsToSave.Copy()
|
||||
|
||||
updateAccountPeers, err = areDNSSettingChangesAffectPeers(ctx, transaction, accountID, addedGroups, removedGroups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
addedGroups := difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups)
|
||||
removedGroups := difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups)
|
||||
|
||||
events := am.prepareDNSSettingsEvents(ctx, transaction, accountID, userID, addedGroups, removedGroups)
|
||||
eventsToStore = append(eventsToStore, events...)
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.SaveDNSSettings(ctx, LockingStrengthUpdate, accountID, dnsSettingsToSave)
|
||||
})
|
||||
if err != nil {
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, storeEvent := range eventsToStore {
|
||||
storeEvent()
|
||||
for _, id := range addedGroups {
|
||||
group := account.GetGroup(id)
|
||||
meta := map[string]any{"group": group.Name, "group_id": group.ID}
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupAddedToDisabledManagementGroups, meta)
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
for _, id := range removedGroups {
|
||||
group := account.GetGroup(id)
|
||||
meta := map[string]any{"group": group.Name, "group_id": group.ID}
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta)
|
||||
}
|
||||
|
||||
if anyGroupHasPeers(account, addedGroups) || anyGroupHasPeers(account, removedGroups) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// prepareDNSSettingsEvents prepares a list of event functions to be stored.
|
||||
func (am *DefaultAccountManager) prepareDNSSettingsEvents(ctx context.Context, transaction Store, accountID, userID string, addedGroups, removedGroups []string) []func() {
|
||||
var eventsToStore []func()
|
||||
|
||||
modifiedGroups := slices.Concat(addedGroups, removedGroups)
|
||||
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, modifiedGroups)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to get groups for dns settings events: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, groupID := range addedGroups {
|
||||
group, ok := groups[groupID]
|
||||
if !ok {
|
||||
log.WithContext(ctx).Debugf("skipped adding group: %s GroupAddedToDisabledManagementGroups activity", groupID)
|
||||
continue
|
||||
}
|
||||
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
meta := map[string]any{"group": group.Name, "group_id": group.ID}
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupAddedToDisabledManagementGroups, meta)
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
for _, groupID := range removedGroups {
|
||||
group, ok := groups[groupID]
|
||||
if !ok {
|
||||
log.WithContext(ctx).Debugf("skipped adding group: %s GroupRemovedFromDisabledManagementGroups activity", groupID)
|
||||
continue
|
||||
}
|
||||
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
meta := map[string]any{"group": group.Name, "group_id": group.ID}
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta)
|
||||
})
|
||||
}
|
||||
|
||||
return eventsToStore
|
||||
}
|
||||
|
||||
// areDNSSettingChangesAffectPeers checks if the DNS settings changes affect any peers.
|
||||
func areDNSSettingChangesAffectPeers(ctx context.Context, transaction Store, accountID string, addedGroups, removedGroups []string) (bool, error) {
|
||||
hasPeers, err := anyGroupHasPeers(ctx, transaction, accountID, addedGroups)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if hasPeers {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
return anyGroupHasPeers(ctx, transaction, accountID, removedGroups)
|
||||
}
|
||||
|
||||
// validateDNSSettings validates the DNS settings.
|
||||
func validateDNSSettings(ctx context.Context, transaction Store, accountID string, settings *DNSSettings) error {
|
||||
if len(settings.DisabledManagementGroups) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, settings.DisabledManagementGroups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return validateGroups(settings.DisabledManagementGroups, groups)
|
||||
}
|
||||
|
||||
// toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache
|
||||
func toProtocolDNSConfig(update nbdns.Config, cache *DNSConfigCache) *proto.DNSConfig {
|
||||
protoUpdate := &proto.DNSConfig{
|
||||
|
||||
@@ -8,10 +8,9 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
@@ -522,64 +521,23 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
// Creating DNS settings with groups that have no peers should not update account peers or send peer update
|
||||
t.Run("creating dns setting with unused groups", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
_, err = manager.CreateNameServerGroup(
|
||||
context.Background(), account.Id, "ns-group", "ns-group", []dns.NameServer{{
|
||||
IP: netip.MustParseAddr(peer1.IP.String()),
|
||||
NSType: dns.UDPNameServerType,
|
||||
Port: dns.DefaultDNSPort,
|
||||
}},
|
||||
[]string{"groupB"},
|
||||
true, []string{}, true, userID, false,
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
err = manager.SaveGroup(context.Background(), account.Id, userID, &group.Group{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Creating DNS settings with groups that have peers should update account peers and send peer update
|
||||
t.Run("creating dns setting with used groups", func(t *testing.T) {
|
||||
err = manager.SaveGroup(context.Background(), account.Id, userID, &group.Group{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
_, err = manager.CreateNameServerGroup(
|
||||
context.Background(), account.Id, "ns-group-1", "ns-group-1", []dns.NameServer{{
|
||||
IP: netip.MustParseAddr(peer1.IP.String()),
|
||||
NSType: dns.UDPNameServerType,
|
||||
Port: dns.DefaultDNSPort,
|
||||
}},
|
||||
[]string{"groupA"},
|
||||
true, []string{}, true, userID, false,
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
_, err = manager.CreateNameServerGroup(
|
||||
context.Background(), account.Id, "ns-group-1", "ns-group-1", []dns.NameServer{{
|
||||
IP: netip.MustParseAddr(peer1.IP.String()),
|
||||
NSType: dns.UDPNameServerType,
|
||||
Port: dns.DefaultDNSPort,
|
||||
}},
|
||||
[]string{"groupA"},
|
||||
true, []string{}, true, userID, false,
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Saving DNS settings with groups that have peers should update account peers and send peer update
|
||||
t.Run("saving dns setting with used groups", func(t *testing.T) {
|
||||
@@ -601,6 +559,27 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
// Saving unchanged DNS settings with used groups should update account peers and not send peer update
|
||||
// since there is no change in the network map
|
||||
t.Run("saving unchanged dns setting with used groups", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{
|
||||
DisabledManagementGroups: []string{"groupA", "groupB"},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// Removing group with no peers from DNS settings should not trigger updates to account peers or send peer updates
|
||||
t.Run("removing group with no peers from dns settings", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
|
||||
@@ -223,7 +223,7 @@ func restore(ctx context.Context, file string) (*FileStore, error) {
|
||||
// It is recommended to call it with locking FileStore.mux
|
||||
func (s *FileStore) persist(ctx context.Context, file string) error {
|
||||
start := time.Now()
|
||||
err := util.WriteJson(context.Background(), file, s)
|
||||
err := util.WriteJson(file, s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -6,11 +6,10 @@ import (
|
||||
"fmt"
|
||||
"slices"
|
||||
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
@@ -28,17 +27,18 @@ func (e *GroupLinkError) Error() string {
|
||||
|
||||
// CheckGroupPermissions validates if a user has the necessary permissions to view groups
|
||||
func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, accountID, userID string) error {
|
||||
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if user.AccountID != accountID {
|
||||
return status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
if user.IsRegularUser() {
|
||||
return status.NewAdminPermissionError()
|
||||
if (!user.IsAdminOrServiceUser() && settings.RegularUsersViewBlocked) || user.AccountID != accountID {
|
||||
return status.Errorf(status.PermissionDenied, "groups are blocked for users")
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -49,7 +49,8 @@ func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupI
|
||||
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID)
|
||||
|
||||
return am.Store.GetGroupByID(ctx, LockingStrengthShare, groupID, accountID)
|
||||
}
|
||||
|
||||
// GetAllGroups returns all groups in an account
|
||||
@@ -57,12 +58,13 @@ func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, us
|
||||
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
|
||||
|
||||
return am.Store.GetAccountGroups(ctx, accountID)
|
||||
}
|
||||
|
||||
// GetGroupByName filters all groups in an account by name and returns the one with the most peers
|
||||
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) {
|
||||
return am.Store.GetGroupByName(ctx, LockingStrengthShare, accountID, groupName)
|
||||
return am.Store.GetGroupByName(ctx, LockingStrengthShare, groupName, accountID)
|
||||
}
|
||||
|
||||
// SaveGroup object of the peers
|
||||
@@ -75,74 +77,79 @@ func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userI
|
||||
// SaveGroups adds new groups to the account.
|
||||
// Note: This function does not acquire the global lock.
|
||||
// It is the caller's responsibility to ensure proper locking is in place before invoking this method.
|
||||
func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, groups []*nbgroup.Group) error {
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error {
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if user.AccountID != accountID {
|
||||
return status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
if user.IsRegularUser() {
|
||||
return status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
var eventsToStore []func()
|
||||
var groupsToSave []*nbgroup.Group
|
||||
var updateAccountPeers bool
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
groupIDs := make([]string, 0, len(groups))
|
||||
for _, newGroup := range groups {
|
||||
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
|
||||
return err
|
||||
for _, newGroup := range newGroups {
|
||||
if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI {
|
||||
return status.Errorf(status.InvalidArgument, "%s group without ID set", newGroup.Issued)
|
||||
}
|
||||
|
||||
if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI {
|
||||
existingGroup, err := account.FindGroupByName(newGroup.Name)
|
||||
if err != nil {
|
||||
s, ok := status.FromError(err)
|
||||
if !ok || s.ErrorType != status.NotFound {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
newGroup.AccountID = accountID
|
||||
groupsToSave = append(groupsToSave, newGroup)
|
||||
groupIDs = append(groupIDs, newGroup.ID)
|
||||
// Avoid duplicate groups only for the API issued groups.
|
||||
// Integration or JWT groups can be duplicated as they are coming from the IdP that we don't have control of.
|
||||
if existingGroup != nil {
|
||||
return status.Errorf(status.AlreadyExists, "group with name %s already exists", newGroup.Name)
|
||||
}
|
||||
|
||||
events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup)
|
||||
eventsToStore = append(eventsToStore, events...)
|
||||
newGroup.ID = xid.New().String()
|
||||
}
|
||||
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, groupIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
for _, peerID := range newGroup.Peers {
|
||||
if account.Peers[peerID] == nil {
|
||||
return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID)
|
||||
}
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
oldGroup := account.Groups[newGroup.ID]
|
||||
account.Groups[newGroup.ID] = newGroup
|
||||
|
||||
return transaction.SaveGroups(ctx, LockingStrengthUpdate, groupsToSave)
|
||||
})
|
||||
if err != nil {
|
||||
events := am.prepareGroupEvents(ctx, userID, accountID, newGroup, oldGroup, account)
|
||||
eventsToStore = append(eventsToStore, events...)
|
||||
}
|
||||
|
||||
newGroupIDs := make([]string, 0, len(newGroups))
|
||||
for _, newGroup := range newGroups {
|
||||
newGroupIDs = append(newGroupIDs, newGroup.ID)
|
||||
}
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if areGroupChangesAffectPeers(account, newGroupIDs) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
}
|
||||
|
||||
for _, storeEvent := range eventsToStore {
|
||||
storeEvent()
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// prepareGroupEvents prepares a list of event functions to be stored.
|
||||
func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transaction Store, accountID, userID string, newGroup *nbgroup.Group) []func() {
|
||||
func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID string, accountID string, newGroup, oldGroup *nbgroup.Group, account *Account) []func() {
|
||||
var eventsToStore []func()
|
||||
|
||||
addedPeers := make([]string, 0)
|
||||
removedPeers := make([]string, 0)
|
||||
|
||||
oldGroup, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, newGroup.ID)
|
||||
if err == nil && oldGroup != nil {
|
||||
if oldGroup != nil {
|
||||
addedPeers = difference(newGroup.Peers, oldGroup.Peers)
|
||||
removedPeers = difference(oldGroup.Peers, newGroup.Peers)
|
||||
} else {
|
||||
@@ -152,42 +159,35 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transac
|
||||
})
|
||||
}
|
||||
|
||||
modifiedPeers := slices.Concat(addedPeers, removedPeers)
|
||||
peers, err := transaction.GetPeersByIDs(ctx, LockingStrengthShare, accountID, modifiedPeers)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to get peers for group events: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, peerID := range addedPeers {
|
||||
peer, ok := peers[peerID]
|
||||
if !ok {
|
||||
log.WithContext(ctx).Debugf("skipped adding peer: %s GroupAddedToPeer activity: peer not found in store", peerID)
|
||||
for _, p := range addedPeers {
|
||||
peer := account.Peers[p]
|
||||
if peer == nil {
|
||||
log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID)
|
||||
continue
|
||||
}
|
||||
|
||||
peerCopy := peer // copy to avoid closure issues
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
meta := map[string]any{
|
||||
"group": newGroup.Name, "group_id": newGroup.ID,
|
||||
"peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()),
|
||||
}
|
||||
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupAddedToPeer, meta)
|
||||
am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupAddedToPeer,
|
||||
map[string]any{
|
||||
"group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peerCopy.IP.String(),
|
||||
"peer_fqdn": peerCopy.FQDN(am.GetDNSDomain()),
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
for _, peerID := range removedPeers {
|
||||
peer, ok := peers[peerID]
|
||||
if !ok {
|
||||
log.WithContext(ctx).Debugf("skipped adding peer: %s GroupRemovedFromPeer activity: peer not found in store", peerID)
|
||||
for _, p := range removedPeers {
|
||||
peer := account.Peers[p]
|
||||
if peer == nil {
|
||||
log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID)
|
||||
continue
|
||||
}
|
||||
|
||||
peerCopy := peer // copy to avoid closure issues
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
meta := map[string]any{
|
||||
"group": newGroup.Name, "group_id": newGroup.ID,
|
||||
"peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()),
|
||||
}
|
||||
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupRemovedFromPeer, meta)
|
||||
am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupRemovedFromPeer,
|
||||
map[string]any{
|
||||
"group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peerCopy.IP.String(),
|
||||
"peer_fqdn": peerCopy.FQDN(am.GetDNSDomain()),
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
@@ -210,10 +210,42 @@ func difference(a, b []string) []string {
|
||||
}
|
||||
|
||||
// DeleteGroup object of the peers.
|
||||
func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, userID, groupID string) error {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, userId, groupID string) error {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountId)
|
||||
defer unlock()
|
||||
return am.DeleteGroups(ctx, accountID, userID, []string{groupID})
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
group, ok := account.Groups[groupID]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
allGroup, err := account.GetGroupAll()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if allGroup.ID == groupID {
|
||||
return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed")
|
||||
}
|
||||
|
||||
if err = validateDeleteGroup(account, group, userId); err != nil {
|
||||
return err
|
||||
}
|
||||
delete(account.Groups, groupID)
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, userId, groupID, accountId, activity.GroupDeleted, group.EventMeta())
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteGroups deletes groups from an account.
|
||||
@@ -222,94 +254,93 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, use
|
||||
//
|
||||
// If an error occurs while deleting a group, the function skips it and continues deleting other groups.
|
||||
// Errors are collected and returned at the end.
|
||||
func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, userID string, groupIDs []string) error {
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error {
|
||||
account, err := am.Store.GetAccount(ctx, accountId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if user.AccountID != accountID {
|
||||
return status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
if user.IsRegularUser() {
|
||||
return status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
var allErrors error
|
||||
var groupIDsToDelete []string
|
||||
var deletedGroups []*nbgroup.Group
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
for _, groupID := range groupIDs {
|
||||
group, err := transaction.GetGroupByID(ctx, LockingStrengthUpdate, accountID, groupID)
|
||||
if err != nil {
|
||||
allErrors = errors.Join(allErrors, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if err := validateDeleteGroup(ctx, transaction, group, userID); err != nil {
|
||||
allErrors = errors.Join(allErrors, err)
|
||||
continue
|
||||
}
|
||||
|
||||
groupIDsToDelete = append(groupIDsToDelete, groupID)
|
||||
deletedGroups = append(deletedGroups, group)
|
||||
deletedGroups := make([]*nbgroup.Group, 0, len(groupIDs))
|
||||
for _, groupID := range groupIDs {
|
||||
group, ok := account.Groups[groupID]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||
return err
|
||||
if err := validateDeleteGroup(account, group, userId); err != nil {
|
||||
allErrors = errors.Join(allErrors, fmt.Errorf("failed to delete group %s: %w", groupID, err))
|
||||
continue
|
||||
}
|
||||
|
||||
return transaction.DeleteGroups(ctx, LockingStrengthUpdate, accountID, groupIDsToDelete)
|
||||
})
|
||||
if err != nil {
|
||||
delete(account.Groups, groupID)
|
||||
deletedGroups = append(deletedGroups, group)
|
||||
}
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, group := range deletedGroups {
|
||||
am.StoreEvent(ctx, userID, group.ID, accountID, activity.GroupDeleted, group.EventMeta())
|
||||
for _, g := range deletedGroups {
|
||||
am.StoreEvent(ctx, userId, g.ID, accountId, activity.GroupDeleted, g.EventMeta())
|
||||
}
|
||||
|
||||
return allErrors
|
||||
}
|
||||
|
||||
// ListGroups objects of the peers
|
||||
func (am *DefaultAccountManager) ListGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
groups := make([]*nbgroup.Group, 0, len(account.Groups))
|
||||
for _, item := range account.Groups {
|
||||
groups = append(groups, item)
|
||||
}
|
||||
|
||||
return groups, nil
|
||||
}
|
||||
|
||||
// GroupAddPeer appends peer to the group
|
||||
func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
var group *nbgroup.Group
|
||||
var updateAccountPeers bool
|
||||
var err error
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, accountID, groupID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if updated := group.AddPeer(peerID); !updated {
|
||||
return nil
|
||||
}
|
||||
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.SaveGroup(ctx, LockingStrengthUpdate, group)
|
||||
})
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
group, ok := account.Groups[groupID]
|
||||
if !ok {
|
||||
return status.Errorf(status.NotFound, "group with ID %s not found", groupID)
|
||||
}
|
||||
|
||||
add := true
|
||||
for _, itemID := range group.Peers {
|
||||
if itemID == peerID {
|
||||
add = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if add {
|
||||
group.Peers = append(group.Peers, peerID)
|
||||
}
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if areGroupChangesAffectPeers(account, []string{group.ID}) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -320,162 +351,90 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
var group *nbgroup.Group
|
||||
var updateAccountPeers bool
|
||||
var err error
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, accountID, groupID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if updated := group.RemovePeer(peerID); !updated {
|
||||
return nil
|
||||
}
|
||||
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.SaveGroup(ctx, LockingStrengthUpdate, group)
|
||||
})
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
group, ok := account.Groups[groupID]
|
||||
if !ok {
|
||||
return status.Errorf(status.NotFound, "group with ID %s not found", groupID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateNewGroup validates the new group for existence and required fields.
|
||||
func validateNewGroup(ctx context.Context, transaction Store, accountID string, newGroup *nbgroup.Group) error {
|
||||
if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI {
|
||||
return status.Errorf(status.InvalidArgument, "%s group without ID set", newGroup.Issued)
|
||||
}
|
||||
|
||||
if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI {
|
||||
existingGroup, err := transaction.GetGroupByName(ctx, LockingStrengthShare, accountID, newGroup.Name)
|
||||
if err != nil {
|
||||
if s, ok := status.FromError(err); !ok || s.Type() != status.NotFound {
|
||||
account.Network.IncSerial()
|
||||
for i, itemID := range group.Peers {
|
||||
if itemID == peerID {
|
||||
group.Peers = append(group.Peers[:i], group.Peers[i+1:]...)
|
||||
if err := am.Store.SaveAccount(ctx, account); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Prevent duplicate groups for API-issued groups.
|
||||
// Integration or JWT groups can be duplicated as they are coming from the IdP that we don't have control of.
|
||||
if existingGroup != nil {
|
||||
return status.Errorf(status.AlreadyExists, "group with name %s already exists", newGroup.Name)
|
||||
}
|
||||
|
||||
newGroup.ID = xid.New().String()
|
||||
}
|
||||
|
||||
for _, peerID := range newGroup.Peers {
|
||||
_, err := transaction.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID)
|
||||
if err != nil {
|
||||
return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID)
|
||||
}
|
||||
if areGroupChangesAffectPeers(account, []string{group.ID}) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateDeleteGroup(ctx context.Context, transaction Store, group *nbgroup.Group, userID string) error {
|
||||
func validateDeleteGroup(account *Account, group *nbgroup.Group, userID string) error {
|
||||
// disable a deleting integration group if the initiator is not an admin service user
|
||||
if group.Issued == nbgroup.GroupIssuedIntegration {
|
||||
executingUser, err := transaction.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
executingUser := account.Users[userID]
|
||||
if executingUser == nil {
|
||||
return status.Errorf(status.NotFound, "user not found")
|
||||
}
|
||||
if executingUser.Role != UserRoleAdmin || !executingUser.IsServiceUser {
|
||||
return status.Errorf(status.PermissionDenied, "only service users with admin power can delete integration group")
|
||||
}
|
||||
}
|
||||
|
||||
if group.IsGroupAll() {
|
||||
return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed")
|
||||
}
|
||||
|
||||
if isLinked, linkedRoute := isGroupLinkedToRoute(ctx, transaction, group.AccountID, group.ID); isLinked {
|
||||
if isLinked, linkedRoute := isGroupLinkedToRoute(account.Routes, group.ID); isLinked {
|
||||
return &GroupLinkError{"route", string(linkedRoute.NetID)}
|
||||
}
|
||||
|
||||
if isLinked, linkedDns := isGroupLinkedToDns(ctx, transaction, group.AccountID, group.ID); isLinked {
|
||||
if isLinked, linkedDns := isGroupLinkedToDns(account.NameServerGroups, group.ID); isLinked {
|
||||
return &GroupLinkError{"name server groups", linkedDns.Name}
|
||||
}
|
||||
|
||||
if isLinked, linkedPolicy := isGroupLinkedToPolicy(ctx, transaction, group.AccountID, group.ID); isLinked {
|
||||
if isLinked, linkedPolicy := isGroupLinkedToPolicy(account.Policies, group.ID); isLinked {
|
||||
return &GroupLinkError{"policy", linkedPolicy.Name}
|
||||
}
|
||||
|
||||
if isLinked, linkedSetupKey := isGroupLinkedToSetupKey(ctx, transaction, group.AccountID, group.ID); isLinked {
|
||||
if isLinked, linkedSetupKey := isGroupLinkedToSetupKey(account.SetupKeys, group.ID); isLinked {
|
||||
return &GroupLinkError{"setup key", linkedSetupKey.Name}
|
||||
}
|
||||
|
||||
if isLinked, linkedUser := isGroupLinkedToUser(ctx, transaction, group.AccountID, group.ID); isLinked {
|
||||
if isLinked, linkedUser := isGroupLinkedToUser(account.Users, group.ID); isLinked {
|
||||
return &GroupLinkError{"user", linkedUser.Id}
|
||||
}
|
||||
|
||||
return checkGroupLinkedToSettings(ctx, transaction, group)
|
||||
}
|
||||
|
||||
// checkGroupLinkedToSettings verifies if a group is linked to any settings in the account.
|
||||
func checkGroupLinkedToSettings(ctx context.Context, transaction Store, group *nbgroup.Group) error {
|
||||
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthShare, group.AccountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if slices.Contains(dnsSettings.DisabledManagementGroups, group.ID) {
|
||||
if slices.Contains(account.DNSSettings.DisabledManagementGroups, group.ID) {
|
||||
return &GroupLinkError{"disabled DNS management groups", group.Name}
|
||||
}
|
||||
|
||||
settings, err := transaction.GetAccountSettings(ctx, LockingStrengthShare, group.AccountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if settings.Extra != nil && slices.Contains(settings.Extra.IntegratedValidatorGroups, group.ID) {
|
||||
return &GroupLinkError{"integrated validator", group.Name}
|
||||
if account.Settings.Extra != nil {
|
||||
if slices.Contains(account.Settings.Extra.IntegratedValidatorGroups, group.ID) {
|
||||
return &GroupLinkError{"integrated validator", group.Name}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// isGroupLinkedToRoute checks if a group is linked to any route in the account.
|
||||
func isGroupLinkedToRoute(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *route.Route) {
|
||||
routes, err := transaction.GetAccountRoutes(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error retrieving routes while checking group linkage: %v", err)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func isGroupLinkedToRoute(routes map[route.ID]*route.Route, groupID string) (bool, *route.Route) {
|
||||
for _, r := range routes {
|
||||
if slices.Contains(r.Groups, groupID) || slices.Contains(r.PeerGroups, groupID) {
|
||||
return true, r
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// isGroupLinkedToPolicy checks if a group is linked to any policy in the account.
|
||||
func isGroupLinkedToPolicy(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *Policy) {
|
||||
policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error retrieving policies while checking group linkage: %v", err)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func isGroupLinkedToPolicy(policies []*Policy, groupID string) (bool, *Policy) {
|
||||
for _, policy := range policies {
|
||||
for _, rule := range policy.Rules {
|
||||
if slices.Contains(rule.Sources, groupID) || slices.Contains(rule.Destinations, groupID) {
|
||||
@@ -487,13 +446,7 @@ func isGroupLinkedToPolicy(ctx context.Context, transaction Store, accountID str
|
||||
}
|
||||
|
||||
// isGroupLinkedToDns checks if a group is linked to any nameserver group in the account.
|
||||
func isGroupLinkedToDns(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *nbdns.NameServerGroup) {
|
||||
nameServerGroups, err := transaction.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error retrieving name server groups while checking group linkage: %v", err)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func isGroupLinkedToDns(nameServerGroups map[string]*nbdns.NameServerGroup, groupID string) (bool, *nbdns.NameServerGroup) {
|
||||
for _, dns := range nameServerGroups {
|
||||
for _, g := range dns.Groups {
|
||||
if g == groupID {
|
||||
@@ -501,18 +454,11 @@ func isGroupLinkedToDns(ctx context.Context, transaction Store, accountID string
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// isGroupLinkedToSetupKey checks if a group is linked to any setup key in the account.
|
||||
func isGroupLinkedToSetupKey(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *SetupKey) {
|
||||
setupKeys, err := transaction.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error retrieving setup keys while checking group linkage: %v", err)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func isGroupLinkedToSetupKey(setupKeys map[string]*SetupKey, groupID string) (bool, *SetupKey) {
|
||||
for _, setupKey := range setupKeys {
|
||||
if slices.Contains(setupKey.AutoGroups, groupID) {
|
||||
return true, setupKey
|
||||
@@ -522,13 +468,7 @@ func isGroupLinkedToSetupKey(ctx context.Context, transaction Store, accountID s
|
||||
}
|
||||
|
||||
// isGroupLinkedToUser checks if a group is linked to any user in the account.
|
||||
func isGroupLinkedToUser(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *User) {
|
||||
users, err := transaction.GetAccountUsers(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error retrieving users while checking group linkage: %v", err)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func isGroupLinkedToUser(users map[string]*User, groupID string) (bool, *User) {
|
||||
for _, user := range users {
|
||||
if slices.Contains(user.AutoGroups, groupID) {
|
||||
return true, user
|
||||
@@ -537,36 +477,8 @@ func isGroupLinkedToUser(ctx context.Context, transaction Store, accountID strin
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// areGroupChangesAffectPeers checks if any changes to the specified groups will affect peers.
|
||||
func areGroupChangesAffectPeers(ctx context.Context, transaction Store, accountID string, groupIDs []string) (bool, error) {
|
||||
if len(groupIDs) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
for _, groupID := range groupIDs {
|
||||
if slices.Contains(dnsSettings.DisabledManagementGroups, groupID) {
|
||||
return true, nil
|
||||
}
|
||||
if linked, _ := isGroupLinkedToDns(ctx, transaction, accountID, groupID); linked {
|
||||
return true, nil
|
||||
}
|
||||
if linked, _ := isGroupLinkedToPolicy(ctx, transaction, accountID, groupID); linked {
|
||||
return true, nil
|
||||
}
|
||||
if linked, _ := isGroupLinkedToRoute(ctx, transaction, accountID, groupID); linked {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) anyGroupHasPeers(account *Account, groupIDs []string) bool {
|
||||
// anyGroupHasPeers checks if any of the given groups in the account have peers.
|
||||
func anyGroupHasPeers(account *Account, groupIDs []string) bool {
|
||||
for _, groupID := range groupIDs {
|
||||
if group, exists := account.Groups[groupID]; exists && group.HasPeers() {
|
||||
return true
|
||||
@@ -575,18 +487,21 @@ func (am *DefaultAccountManager) anyGroupHasPeers(account *Account, groupIDs []s
|
||||
return false
|
||||
}
|
||||
|
||||
// anyGroupHasPeers checks if any of the given groups in the account have peers.
|
||||
func anyGroupHasPeers(ctx context.Context, transaction Store, accountID string, groupIDs []string) (bool, error) {
|
||||
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, groupIDs)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
for _, group := range groups {
|
||||
if group.HasPeers() {
|
||||
return true, nil
|
||||
func areGroupChangesAffectPeers(account *Account, groupIDs []string) bool {
|
||||
for _, groupID := range groupIDs {
|
||||
if slices.Contains(account.DNSSettings.DisabledManagementGroups, groupID) {
|
||||
return true
|
||||
}
|
||||
if linked, _ := isGroupLinkedToDns(account.NameServerGroups, groupID); linked {
|
||||
return true
|
||||
}
|
||||
if linked, _ := isGroupLinkedToPolicy(account.Policies, groupID); linked {
|
||||
return true
|
||||
}
|
||||
if linked, _ := isGroupLinkedToRoute(account.Routes, groupID); linked {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -49,35 +49,3 @@ func (g *Group) Copy() *Group {
|
||||
func (g *Group) HasPeers() bool {
|
||||
return len(g.Peers) > 0
|
||||
}
|
||||
|
||||
// IsGroupAll checks if the group is a default "All" group.
|
||||
func (g *Group) IsGroupAll() bool {
|
||||
return g.Name == "All"
|
||||
}
|
||||
|
||||
// AddPeer adds peerID to Peers if not present, returning true if added.
|
||||
func (g *Group) AddPeer(peerID string) bool {
|
||||
if peerID == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, itemID := range g.Peers {
|
||||
if itemID == peerID {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
g.Peers = append(g.Peers, peerID)
|
||||
return true
|
||||
}
|
||||
|
||||
// RemovePeer removes peerID from Peers if present, returning true if removed.
|
||||
func (g *Group) RemovePeer(peerID string) bool {
|
||||
for i, itemID := range g.Peers {
|
||||
if itemID == peerID {
|
||||
g.Peers = append(g.Peers[:i], g.Peers[i+1:]...)
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -1,90 +0,0 @@
|
||||
package group
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestAddPeer(t *testing.T) {
|
||||
t.Run("add new peer to empty slice", func(t *testing.T) {
|
||||
group := &Group{Peers: []string{}}
|
||||
peerID := "peer1"
|
||||
assert.True(t, group.AddPeer(peerID))
|
||||
assert.Contains(t, group.Peers, peerID)
|
||||
})
|
||||
|
||||
t.Run("add new peer to nil slice", func(t *testing.T) {
|
||||
group := &Group{Peers: nil}
|
||||
peerID := "peer1"
|
||||
assert.True(t, group.AddPeer(peerID))
|
||||
assert.Contains(t, group.Peers, peerID)
|
||||
})
|
||||
|
||||
t.Run("add new peer to non-empty slice", func(t *testing.T) {
|
||||
group := &Group{Peers: []string{"peer1", "peer2"}}
|
||||
peerID := "peer3"
|
||||
assert.True(t, group.AddPeer(peerID))
|
||||
assert.Contains(t, group.Peers, peerID)
|
||||
})
|
||||
|
||||
t.Run("add duplicate peer", func(t *testing.T) {
|
||||
group := &Group{Peers: []string{"peer1", "peer2"}}
|
||||
peerID := "peer1"
|
||||
assert.False(t, group.AddPeer(peerID))
|
||||
assert.Equal(t, 2, len(group.Peers))
|
||||
})
|
||||
|
||||
t.Run("add empty peer", func(t *testing.T) {
|
||||
group := &Group{Peers: []string{"peer1", "peer2"}}
|
||||
peerID := ""
|
||||
assert.False(t, group.AddPeer(peerID))
|
||||
assert.Equal(t, 2, len(group.Peers))
|
||||
})
|
||||
}
|
||||
|
||||
func TestRemovePeer(t *testing.T) {
|
||||
t.Run("remove existing peer from slice", func(t *testing.T) {
|
||||
group := &Group{Peers: []string{"peer1", "peer2", "peer3"}}
|
||||
peerID := "peer2"
|
||||
assert.True(t, group.RemovePeer(peerID))
|
||||
assert.NotContains(t, group.Peers, peerID)
|
||||
assert.Equal(t, 2, len(group.Peers))
|
||||
})
|
||||
|
||||
t.Run("remove peer from empty slice", func(t *testing.T) {
|
||||
group := &Group{Peers: []string{}}
|
||||
peerID := "peer1"
|
||||
assert.False(t, group.RemovePeer(peerID))
|
||||
assert.Equal(t, 0, len(group.Peers))
|
||||
})
|
||||
|
||||
t.Run("remove peer from nil slice", func(t *testing.T) {
|
||||
group := &Group{Peers: nil}
|
||||
peerID := "peer1"
|
||||
assert.False(t, group.RemovePeer(peerID))
|
||||
assert.Nil(t, group.Peers)
|
||||
})
|
||||
|
||||
t.Run("remove non-existent peer", func(t *testing.T) {
|
||||
group := &Group{Peers: []string{"peer1", "peer2"}}
|
||||
peerID := "peer3"
|
||||
assert.False(t, group.RemovePeer(peerID))
|
||||
assert.Equal(t, 2, len(group.Peers))
|
||||
})
|
||||
|
||||
t.Run("remove peer from single-item slice", func(t *testing.T) {
|
||||
group := &Group{Peers: []string{"peer1"}}
|
||||
peerID := "peer1"
|
||||
assert.True(t, group.RemovePeer(peerID))
|
||||
assert.Equal(t, 0, len(group.Peers))
|
||||
assert.NotContains(t, group.Peers, peerID)
|
||||
})
|
||||
|
||||
t.Run("remove empty peer", func(t *testing.T) {
|
||||
group := &Group{Peers: []string{"peer1", "peer2"}}
|
||||
peerID := ""
|
||||
assert.False(t, group.RemovePeer(peerID))
|
||||
assert.Equal(t, 2, len(group.Peers))
|
||||
})
|
||||
}
|
||||
@@ -8,13 +8,12 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -208,7 +207,7 @@ func TestDefaultAccountManager_DeleteGroups(t *testing.T) {
|
||||
{
|
||||
name: "delete non-existent group",
|
||||
groupIDs: []string{"non-existent-group"},
|
||||
expectedReasons: []string{"group: non-existent-group not found"},
|
||||
expectedDeleted: []string{"non-existent-group"},
|
||||
},
|
||||
{
|
||||
name: "delete multiple groups with mixed results",
|
||||
@@ -500,7 +499,8 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
})
|
||||
|
||||
// adding a group to policy
|
||||
_, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
|
||||
err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
|
||||
ID: "policy",
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
@@ -511,7 +511,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
})
|
||||
}, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Saving a group linked to policy should update account peers and send peer update
|
||||
@@ -536,6 +536,29 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
// Saving an unchanged group should trigger account peers update and not send peer update
|
||||
// since there is no change in the network map
|
||||
t.Run("saving unchanged group", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer2.ID},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
// adding peer to a used group should update account peers and send peer update
|
||||
t.Run("adding peer to linked group", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
pb "github.com/golang/protobuf/proto" // nolint
|
||||
@@ -39,7 +38,6 @@ type GRPCServer struct {
|
||||
jwtClaimsExtractor *jwtclaims.ClaimsExtractor
|
||||
appMetrics telemetry.AppMetrics
|
||||
ephemeralManager *EphemeralManager
|
||||
peerLocks sync.Map
|
||||
}
|
||||
|
||||
// NewServer creates a new Management server
|
||||
@@ -150,13 +148,6 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
||||
// nolint:staticcheck
|
||||
ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String())
|
||||
|
||||
unlock := s.acquirePeerLockByUID(ctx, peerKey.String())
|
||||
defer func() {
|
||||
if unlock != nil {
|
||||
unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
accountID, err := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String())
|
||||
if err != nil {
|
||||
// nolint:staticcheck
|
||||
@@ -180,7 +171,6 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
||||
|
||||
peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), extractPeerMeta(ctx, syncReq.GetMeta()), realIP)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("error while syncing peer %s: %v", peerKey.String(), err)
|
||||
return mapError(ctx, err)
|
||||
}
|
||||
|
||||
@@ -200,15 +190,11 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
||||
s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart))
|
||||
}
|
||||
|
||||
unlock()
|
||||
unlock = nil
|
||||
|
||||
return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv)
|
||||
}
|
||||
|
||||
// handleUpdates sends updates to the connected peer until the updates channel is closed.
|
||||
func (s *GRPCServer) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *UpdateMessage, srv proto.ManagementService_SyncServer) error {
|
||||
log.WithContext(ctx).Tracef("starting to handle updates for peer %s", peerKey.String())
|
||||
for {
|
||||
select {
|
||||
// condition when there are some updates
|
||||
@@ -259,18 +245,10 @@ func (s *GRPCServer) sendUpdate(ctx context.Context, accountID string, peerKey w
|
||||
}
|
||||
|
||||
func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer) {
|
||||
unlock := s.acquirePeerLockByUID(ctx, peer.Key)
|
||||
defer unlock()
|
||||
|
||||
err := s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to disconnect peer %s properly: %v", peer.Key, err)
|
||||
}
|
||||
s.peersUpdateManager.CloseChannel(ctx, peer.ID)
|
||||
s.secretsManager.CancelRefresh(peer.ID)
|
||||
_ = s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key)
|
||||
s.ephemeralManager.OnPeerDisconnected(ctx, peer)
|
||||
|
||||
log.WithContext(ctx).Tracef("peer %s has been disconnected", peer.Key)
|
||||
}
|
||||
|
||||
func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string, error) {
|
||||
@@ -296,24 +274,6 @@ func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string
|
||||
return claims.UserId, nil
|
||||
}
|
||||
|
||||
func (s *GRPCServer) acquirePeerLockByUID(ctx context.Context, uniqueID string) (unlock func()) {
|
||||
log.WithContext(ctx).Tracef("acquiring peer lock for ID %s", uniqueID)
|
||||
|
||||
start := time.Now()
|
||||
value, _ := s.peerLocks.LoadOrStore(uniqueID, &sync.RWMutex{})
|
||||
mtx := value.(*sync.RWMutex)
|
||||
mtx.Lock()
|
||||
log.WithContext(ctx).Tracef("acquired peer lock for ID %s in %v", uniqueID, time.Since(start))
|
||||
start = time.Now()
|
||||
|
||||
unlock = func() {
|
||||
mtx.Unlock()
|
||||
log.WithContext(ctx).Tracef("released peer lock for ID %s in %v", uniqueID, time.Since(start))
|
||||
}
|
||||
|
||||
return unlock
|
||||
}
|
||||
|
||||
// maps internal internalStatus.Error to gRPC status.Error
|
||||
func mapError(ctx context.Context, err error) error {
|
||||
if e, ok := internalStatus.FromError(err); ok {
|
||||
|
||||
@@ -439,13 +439,17 @@ components:
|
||||
example: 5
|
||||
required:
|
||||
- accessible_peers_count
|
||||
SetupKeyBase:
|
||||
SetupKey:
|
||||
type: object
|
||||
properties:
|
||||
id:
|
||||
description: Setup Key ID
|
||||
type: string
|
||||
example: 2531583362
|
||||
key:
|
||||
description: Setup Key value
|
||||
type: string
|
||||
example: A616097E-FCF0-48FA-9354-CA4A61142761
|
||||
name:
|
||||
description: Setup key name identifier
|
||||
type: string
|
||||
@@ -514,31 +518,22 @@ components:
|
||||
- updated_at
|
||||
- usage_limit
|
||||
- ephemeral
|
||||
SetupKeyClear:
|
||||
allOf:
|
||||
- $ref: '#/components/schemas/SetupKeyBase'
|
||||
- type: object
|
||||
properties:
|
||||
key:
|
||||
description: Setup Key as plain text
|
||||
type: string
|
||||
example: A616097E-FCF0-48FA-9354-CA4A61142761
|
||||
required:
|
||||
- key
|
||||
SetupKey:
|
||||
allOf:
|
||||
- $ref: '#/components/schemas/SetupKeyBase'
|
||||
- type: object
|
||||
properties:
|
||||
key:
|
||||
description: Setup Key as secret
|
||||
type: string
|
||||
example: A6160****
|
||||
required:
|
||||
- key
|
||||
SetupKeyRequest:
|
||||
type: object
|
||||
properties:
|
||||
name:
|
||||
description: Setup Key name
|
||||
type: string
|
||||
example: Default key
|
||||
type:
|
||||
description: Setup key type, one-off for single time usage and reusable
|
||||
type: string
|
||||
example: reusable
|
||||
expires_in:
|
||||
description: Expiration time in seconds, 0 will mean the key never expires
|
||||
type: integer
|
||||
minimum: 0
|
||||
example: 86400
|
||||
revoked:
|
||||
description: Setup key revocation status
|
||||
type: boolean
|
||||
@@ -549,9 +544,21 @@ components:
|
||||
items:
|
||||
type: string
|
||||
example: "ch8i4ug6lnn4g9hqv7m0"
|
||||
usage_limit:
|
||||
description: A number of times this key can be used. The value of 0 indicates the unlimited usage.
|
||||
type: integer
|
||||
example: 0
|
||||
ephemeral:
|
||||
description: Indicate that the peer will be ephemeral or not
|
||||
type: boolean
|
||||
example: true
|
||||
required:
|
||||
- name
|
||||
- type
|
||||
- expires_in
|
||||
- revoked
|
||||
- auto_groups
|
||||
- usage_limit
|
||||
CreateSetupKeyRequest:
|
||||
type: object
|
||||
properties:
|
||||
@@ -1936,7 +1943,7 @@ paths:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/SetupKeyClear'
|
||||
$ref: '#/components/schemas/SetupKey'
|
||||
'400':
|
||||
"$ref": "#/components/responses/bad_request"
|
||||
'401':
|
||||
|
||||
@@ -1062,94 +1062,7 @@ type SetupKey struct {
|
||||
// Id Setup Key ID
|
||||
Id string `json:"id"`
|
||||
|
||||
// Key Setup Key as secret
|
||||
Key string `json:"key"`
|
||||
|
||||
// LastUsed Setup key last usage date
|
||||
LastUsed time.Time `json:"last_used"`
|
||||
|
||||
// Name Setup key name identifier
|
||||
Name string `json:"name"`
|
||||
|
||||
// Revoked Setup key revocation status
|
||||
Revoked bool `json:"revoked"`
|
||||
|
||||
// State Setup key status, "valid", "overused","expired" or "revoked"
|
||||
State string `json:"state"`
|
||||
|
||||
// Type Setup key type, one-off for single time usage and reusable
|
||||
Type string `json:"type"`
|
||||
|
||||
// UpdatedAt Setup key last update date
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
|
||||
// UsageLimit A number of times this key can be used. The value of 0 indicates the unlimited usage.
|
||||
UsageLimit int `json:"usage_limit"`
|
||||
|
||||
// UsedTimes Usage count of setup key
|
||||
UsedTimes int `json:"used_times"`
|
||||
|
||||
// Valid Setup key validity status
|
||||
Valid bool `json:"valid"`
|
||||
}
|
||||
|
||||
// SetupKeyBase defines model for SetupKeyBase.
|
||||
type SetupKeyBase struct {
|
||||
// AutoGroups List of group IDs to auto-assign to peers registered with this key
|
||||
AutoGroups []string `json:"auto_groups"`
|
||||
|
||||
// Ephemeral Indicate that the peer will be ephemeral or not
|
||||
Ephemeral bool `json:"ephemeral"`
|
||||
|
||||
// Expires Setup Key expiration date
|
||||
Expires time.Time `json:"expires"`
|
||||
|
||||
// Id Setup Key ID
|
||||
Id string `json:"id"`
|
||||
|
||||
// LastUsed Setup key last usage date
|
||||
LastUsed time.Time `json:"last_used"`
|
||||
|
||||
// Name Setup key name identifier
|
||||
Name string `json:"name"`
|
||||
|
||||
// Revoked Setup key revocation status
|
||||
Revoked bool `json:"revoked"`
|
||||
|
||||
// State Setup key status, "valid", "overused","expired" or "revoked"
|
||||
State string `json:"state"`
|
||||
|
||||
// Type Setup key type, one-off for single time usage and reusable
|
||||
Type string `json:"type"`
|
||||
|
||||
// UpdatedAt Setup key last update date
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
|
||||
// UsageLimit A number of times this key can be used. The value of 0 indicates the unlimited usage.
|
||||
UsageLimit int `json:"usage_limit"`
|
||||
|
||||
// UsedTimes Usage count of setup key
|
||||
UsedTimes int `json:"used_times"`
|
||||
|
||||
// Valid Setup key validity status
|
||||
Valid bool `json:"valid"`
|
||||
}
|
||||
|
||||
// SetupKeyClear defines model for SetupKeyClear.
|
||||
type SetupKeyClear struct {
|
||||
// AutoGroups List of group IDs to auto-assign to peers registered with this key
|
||||
AutoGroups []string `json:"auto_groups"`
|
||||
|
||||
// Ephemeral Indicate that the peer will be ephemeral or not
|
||||
Ephemeral bool `json:"ephemeral"`
|
||||
|
||||
// Expires Setup Key expiration date
|
||||
Expires time.Time `json:"expires"`
|
||||
|
||||
// Id Setup Key ID
|
||||
Id string `json:"id"`
|
||||
|
||||
// Key Setup Key as plain text
|
||||
// Key Setup Key value
|
||||
Key string `json:"key"`
|
||||
|
||||
// LastUsed Setup key last usage date
|
||||
@@ -1185,8 +1098,23 @@ type SetupKeyRequest struct {
|
||||
// AutoGroups List of group IDs to auto-assign to peers registered with this key
|
||||
AutoGroups []string `json:"auto_groups"`
|
||||
|
||||
// Ephemeral Indicate that the peer will be ephemeral or not
|
||||
Ephemeral *bool `json:"ephemeral,omitempty"`
|
||||
|
||||
// ExpiresIn Expiration time in seconds, 0 will mean the key never expires
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
|
||||
// Name Setup Key name
|
||||
Name string `json:"name"`
|
||||
|
||||
// Revoked Setup key revocation status
|
||||
Revoked bool `json:"revoked"`
|
||||
|
||||
// Type Setup key type, one-off for single time usage and reusable
|
||||
Type string `json:"type"`
|
||||
|
||||
// UsageLimit A number of times this key can be used. The value of 0 indicates the unlimited usage.
|
||||
UsageLimit int `json:"usage_limit"`
|
||||
}
|
||||
|
||||
// User defines model for User.
|
||||
|
||||
@@ -184,26 +184,14 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
dnsDomain := h.accountManager.GetDNSDomain()
|
||||
|
||||
peers, err := h.accountManager.GetPeers(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
groupsMap := map[string]*nbgroup.Group{}
|
||||
groups, _ := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
|
||||
for _, group := range groups {
|
||||
groupsMap[group.ID] = group
|
||||
}
|
||||
|
||||
respBody := make([]*api.PeerBatch, 0, len(peers))
|
||||
for _, peer := range peers {
|
||||
respBody := make([]*api.PeerBatch, 0, len(account.Peers))
|
||||
for _, peer := range account.Peers {
|
||||
peerToReturn, err := h.checkPeerStatus(peer)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
groupMinimumInfo := toGroupsInfo(groupsMap, peer.ID)
|
||||
groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID)
|
||||
|
||||
respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, 0))
|
||||
}
|
||||
@@ -316,7 +304,7 @@ func peerToAccessiblePeer(peer *nbpeer.Peer, dnsDomain string) api.AccessiblePee
|
||||
}
|
||||
|
||||
func toGroupsInfo(groups map[string]*nbgroup.Group, peerID string) []api.GroupMinimum {
|
||||
groupsInfo := []api.GroupMinimum{}
|
||||
var groupsInfo []api.GroupMinimum
|
||||
groupsChecked := make(map[string]struct{})
|
||||
for _, group := range groups {
|
||||
_, ok := groupsChecked[group.ID]
|
||||
|
||||
@@ -6,8 +6,10 @@ import (
|
||||
"strconv"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
"github.com/rs/xid"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
@@ -120,22 +122,21 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID
|
||||
return
|
||||
}
|
||||
|
||||
policy := &server.Policy{
|
||||
isUpdate := policyID != ""
|
||||
|
||||
if policyID == "" {
|
||||
policyID = xid.New().String()
|
||||
}
|
||||
|
||||
policy := server.Policy{
|
||||
ID: policyID,
|
||||
AccountID: accountID,
|
||||
Name: req.Name,
|
||||
Enabled: req.Enabled,
|
||||
Description: req.Description,
|
||||
}
|
||||
for _, rule := range req.Rules {
|
||||
var ruleID string
|
||||
if rule.Id != nil {
|
||||
ruleID = *rule.Id
|
||||
}
|
||||
|
||||
pr := server.PolicyRule{
|
||||
ID: ruleID,
|
||||
PolicyID: policyID,
|
||||
ID: policyID, // TODO: when policy can contain multiple rules, need refactor
|
||||
Name: rule.Name,
|
||||
Destinations: rule.Destinations,
|
||||
Sources: rule.Sources,
|
||||
@@ -224,8 +225,7 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID
|
||||
policy.SourcePostureChecks = *req.SourcePostureChecks
|
||||
}
|
||||
|
||||
policy, err := h.accountManager.SavePolicy(r.Context(), accountID, userID, policy)
|
||||
if err != nil {
|
||||
if err := h.accountManager.SavePolicy(r.Context(), accountID, userID, &policy, isUpdate); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
@@ -236,7 +236,7 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID
|
||||
return
|
||||
}
|
||||
|
||||
resp := toPolicyResponse(allGroups, policy)
|
||||
resp := toPolicyResponse(allGroups, &policy)
|
||||
if len(resp.Rules) == 0 {
|
||||
util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w)
|
||||
return
|
||||
|
||||
@@ -38,12 +38,12 @@ func initPoliciesTestData(policies ...*server.Policy) *Policies {
|
||||
}
|
||||
return policy, nil
|
||||
},
|
||||
SavePolicyFunc: func(_ context.Context, _, _ string, policy *server.Policy) (*server.Policy, error) {
|
||||
SavePolicyFunc: func(_ context.Context, _, _ string, policy *server.Policy, _ bool) error {
|
||||
if !strings.HasPrefix(policy.ID, "id-") {
|
||||
policy.ID = "id-was-set"
|
||||
policy.Rules[0].ID = "id-was-set"
|
||||
}
|
||||
return policy, nil
|
||||
return nil
|
||||
},
|
||||
GetAllGroupsFunc: func(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error) {
|
||||
return []*nbgroup.Group{{ID: "F"}, {ID: "G"}}, nil
|
||||
|
||||
@@ -169,8 +169,7 @@ func (p *PostureChecksHandler) savePostureChecks(w http.ResponseWriter, r *http.
|
||||
return
|
||||
}
|
||||
|
||||
postureChecks, err = p.accountManager.SavePostureChecks(r.Context(), accountID, userID, postureChecks)
|
||||
if err != nil {
|
||||
if err := p.accountManager.SavePostureChecks(r.Context(), accountID, userID, postureChecks); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -40,15 +40,15 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH
|
||||
}
|
||||
return p, nil
|
||||
},
|
||||
SavePostureChecksFunc: func(_ context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) {
|
||||
SavePostureChecksFunc: func(_ context.Context, accountID, userID string, postureChecks *posture.Checks) error {
|
||||
postureChecks.ID = "postureCheck"
|
||||
testPostureChecks[postureChecks.ID] = postureChecks
|
||||
|
||||
if err := postureChecks.Validate(); err != nil {
|
||||
return nil, status.Errorf(status.InvalidArgument, err.Error()) //nolint
|
||||
return status.Errorf(status.InvalidArgument, err.Error()) //nolint
|
||||
}
|
||||
|
||||
return postureChecks, nil
|
||||
return nil
|
||||
},
|
||||
DeletePostureChecksFunc: func(_ context.Context, accountID, postureChecksID, userID string) error {
|
||||
_, ok := testPostureChecks[postureChecksID]
|
||||
|
||||
@@ -149,7 +149,7 @@ func (h *RoutesHandler) validateRoute(req api.PostApiRoutesJSONRequestBody) erro
|
||||
}
|
||||
|
||||
if req.Peer == nil && req.PeerGroups == nil {
|
||||
return status.Errorf(status.InvalidArgument, "either 'peer' or 'peer_groups' should be provided")
|
||||
return status.Errorf(status.InvalidArgument, "either 'peer' or 'peers_group' should be provided")
|
||||
}
|
||||
|
||||
if req.Peer != nil && req.PeerGroups != nil {
|
||||
|
||||
@@ -137,6 +137,11 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request
|
||||
return
|
||||
}
|
||||
|
||||
if req.Name == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "setup key name field is invalid: %s", req.Name), w)
|
||||
return
|
||||
}
|
||||
|
||||
if req.AutoGroups == nil {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "setup key AutoGroups field is invalid"), w)
|
||||
return
|
||||
@@ -145,6 +150,7 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request
|
||||
newKey := &server.SetupKey{}
|
||||
newKey.AutoGroups = req.AutoGroups
|
||||
newKey.Revoked = req.Revoked
|
||||
newKey.Name = req.Name
|
||||
newKey.Id = keyID
|
||||
|
||||
newKey, err = h.accountManager.SaveSetupKey(r.Context(), accountID, newKey, userID)
|
||||
|
||||
@@ -52,23 +52,26 @@ func (am *DefaultAccountManager) UpdateIntegratedValidatorGroups(ctx context.Con
|
||||
return am.Store.SaveAccount(ctx, a)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID string, groupIDs []string) (bool, error) {
|
||||
if len(groupIDs) == 0 {
|
||||
func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error) {
|
||||
if len(groups) == 0 {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
err := am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
for _, groupID := range groupIDs {
|
||||
_, err := transaction.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
accountsGroups, err := am.ListGroups(ctx, accountId)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
for _, group := range groups {
|
||||
var found bool
|
||||
for _, accountGroup := range accountsGroups {
|
||||
if accountGroup.ID == group {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
// IntegratedValidator interface exists to avoid the circle dependencies
|
||||
type IntegratedValidator interface {
|
||||
ValidateExtraSettings(ctx context.Context, newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error
|
||||
ValidatePeer(ctx context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error)
|
||||
ValidatePeer(ctx context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, error)
|
||||
PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer
|
||||
IsNotValidPeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error)
|
||||
GetValidatedPeers(accountID string, groups map[string]*nbgroup.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error)
|
||||
|
||||
@@ -77,8 +77,6 @@ type JWTValidator struct {
|
||||
options Options
|
||||
}
|
||||
|
||||
var keyNotFound = errors.New("unable to find appropriate key")
|
||||
|
||||
// NewJWTValidator constructor
|
||||
func NewJWTValidator(ctx context.Context, issuer string, audienceList []string, keysLocation string, idpSignkeyRefreshEnabled bool) (*JWTValidator, error) {
|
||||
keys, err := getPemKeys(ctx, keysLocation)
|
||||
@@ -126,18 +124,12 @@ func NewJWTValidator(ctx context.Context, issuer string, audienceList []string,
|
||||
}
|
||||
|
||||
publicKey, err := getPublicKey(ctx, token, keys)
|
||||
if err == nil {
|
||||
return publicKey, nil
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("getPublicKey error: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
msg := fmt.Sprintf("getPublicKey error: %s", err)
|
||||
if errors.Is(err, keyNotFound) && !idpSignkeyRefreshEnabled {
|
||||
msg = fmt.Sprintf("getPublicKey error: %s. You can enable key refresh by setting HttpServerConfig.IdpSignKeyRefreshEnabled to true in your management.json file and restart the service", err)
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Error(msg)
|
||||
|
||||
return nil, err
|
||||
return publicKey, nil
|
||||
},
|
||||
EnableAuthOnOptions: false,
|
||||
}
|
||||
@@ -237,7 +229,7 @@ func getPublicKey(ctx context.Context, token *jwt.Token, jwks *Jwks) (interface{
|
||||
log.WithContext(ctx).Debugf("Key Type: %s not yet supported, please raise ticket!", jwks.Keys[k].Kty)
|
||||
}
|
||||
|
||||
return nil, keyNotFound
|
||||
return nil, errors.New("unable to find appropriate key")
|
||||
}
|
||||
|
||||
func getPublicKeyFromECDSA(jwk JSONWebKey) (publicKey *ecdsa.PublicKey, err error) {
|
||||
@@ -318,3 +310,4 @@ func getMaxAgeFromCacheHeader(ctx context.Context, cacheControl string) int {
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user