Compare commits

..

1 Commits

Author SHA1 Message Date
Maycon Santos
840b07c784 add todos 2024-07-05 11:15:28 +02:00
137 changed files with 1567 additions and 5277 deletions

View File

@@ -1,8 +0,0 @@
root = true
[*]
end_of_line = lf
insert_final_newline = true
[*.go]
indent_style = tab

View File

@@ -31,14 +31,9 @@ Please specify whether you use NetBird Cloud or self-host NetBird's control plan
`netbird version`
**NetBird status -dA output:**
**NetBird status -d output:**
If applicable, add the `netbird status -dA' command output.
**Do you face any client issues on desktop?**
Please provide the file created by `netbird debug for 1m -AS`.
We advise reviewing the anonymized files for any remaining PII.
If applicable, add the `netbird status -d' command output.
**Screenshots**

View File

@@ -13,7 +13,7 @@ concurrency:
jobs:
test:
runs-on: ubuntu-22.04
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Test in FreeBSD
@@ -21,26 +21,19 @@ jobs:
uses: vmactions/freebsd-vm@v1
with:
usesh: true
copyback: false
release: "14.1"
prepare: |
pkg install -y go
pkg install -y curl
pkg install -y git
# -x - to print all executed commands
# -e - to faile on first error
run: |
set -e -x
time go build -o netbird client/main.go
# check all component except management, since we do not support management server on freebsd
time go test -timeout 1m -failfast ./base62/...
# NOTE: without -p1 `client/internal/dns` will fail becasue of `listen udp4 :33100: bind: address already in use`
time go test -timeout 8m -failfast -p 1 ./client/...
time go test -timeout 1m -failfast ./dns/...
time go test -timeout 1m -failfast ./encryption/...
time go test -timeout 1m -failfast ./formatter/...
time go test -timeout 1m -failfast ./iface/...
time go test -timeout 1m -failfast ./route/...
time go test -timeout 1m -failfast ./sharedsock/...
time go test -timeout 1m -failfast ./signal/...
time go test -timeout 1m -failfast ./util/...
time go test -timeout 1m -failfast ./version/...
set -x
curl -o go.tar.gz https://go.dev/dl/go1.21.11.freebsd-amd64.tar.gz -L
tar zxf go.tar.gz
mv go /usr/local/go
ln -s /usr/local/go/bin/go /usr/local/bin/go
go mod tidy
go test -timeout 5m -p 1 ./iface/...
go test -timeout 5m -p 1 ./client/...
cd client
go build .
cd ..

View File

@@ -6,15 +6,12 @@ on:
- 'v*'
branches:
- main
- validate-icon
pull_request:
env:
SIGN_PIPE_VER: "v0.0.12"
SIGN_PIPE_VER: "v0.0.11"
GORELEASER_VER: "v1.14.1"
PRODUCT_NAME: "NetBird"
COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)"
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
@@ -26,13 +23,6 @@ jobs:
env:
flags: ""
steps:
- name: Parse semver string
id: semver_parser
uses: booxmedialtd/ws-action-parse-semver@v1
with:
input_string: ${{ (startsWith(github.ref, 'refs/tags/v') && github.ref) || 'refs/tags/v0.0.0' }}
version_extractor_regex: '\/v(.*)$'
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV
-
@@ -78,11 +68,18 @@ jobs:
- name: Install OS build dependencies
run: sudo apt update && sudo apt install -y -q gcc-arm-linux-gnueabihf gcc-aarch64-linux-gnu
- name: Install goversioninfo
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
- name: Generate windows syso amd64
run: goversioninfo -icon client/ui/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_amd64.syso
- name: Run GoReleaser
- name: Install rsrc
run: go install github.com/akavel/rsrc@v0.10.2
- name: Generate windows rsrc amd64
run: rsrc -arch amd64 -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_amd64.syso
- name: Generate windows rsrc arm64
run: rsrc -arch arm64 -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_arm64.syso
- name: Generate windows rsrc arm
run: rsrc -arch arm -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_arm.syso
- name: Generate windows rsrc 386
run: rsrc -arch 386 -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_386.syso
-
name: Run GoReleaser
uses: goreleaser/goreleaser-action@v4
with:
version: ${{ env.GORELEASER_VER }}
@@ -94,28 +91,28 @@ jobs:
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
-
name: upload non tags for debug purposes
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v3
with:
name: release
path: dist/
retention-days: 3
-
name: upload linux packages
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v3
with:
name: linux-packages
path: dist/netbird_linux**
retention-days: 3
-
name: upload windows packages
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v3
with:
name: windows-packages
path: dist/netbird_windows**
retention-days: 3
-
name: upload macos packages
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v3
with:
name: macos-packages
path: dist/netbird_darwin**
@@ -124,13 +121,6 @@ jobs:
release_ui:
runs-on: ubuntu-latest
steps:
- name: Parse semver string
id: semver_parser
uses: booxmedialtd/ws-action-parse-semver@v1
with:
input_string: ${{ (startsWith(github.ref, 'refs/tags/v') && github.ref) || 'refs/tags/v0.0.0' }}
version_extractor_regex: '\/v(.*)$'
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV
- name: Checkout
@@ -161,11 +151,10 @@ jobs:
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libappindicator3-dev gir1.2-appindicator3-0.1 libxxf86vm-dev gcc-mingw-w64-x86-64
- name: Install goversioninfo
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
- name: Generate windows syso amd64
run: goversioninfo -64 -icon client/ui/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_amd64.syso
- name: Install rsrc
run: go install github.com/akavel/rsrc@v0.10.2
- name: Generate windows rsrc
run: rsrc -arch amd64 -ico client/ui/netbird.ico -manifest client/ui/manifest.xml -o client/ui/resources_windows_amd64.syso
- name: Run GoReleaser
uses: goreleaser/goreleaser-action@v4
with:
@@ -177,7 +166,7 @@ jobs:
UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
- name: upload non tags for debug purposes
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v3
with:
name: release-ui
path: dist/
@@ -226,7 +215,7 @@ jobs:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
-
name: upload non tags for debug purposes
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v3
with:
name: release-ui-darwin
path: dist/

View File

@@ -151,10 +151,10 @@ jobs:
- name: run docker compose up
working-directory: infrastructure_files/artifacts
run: |
docker compose up -d
docker-compose up -d
sleep 5
docker compose ps
docker compose logs --tail=20
docker-compose ps
docker-compose logs --tail=20
- name: test running containers
run: |
@@ -207,7 +207,7 @@ jobs:
- name: Postgres run cleanup
run: |
docker compose down --volumes --rmi all
docker-compose down --volumes --rmi all
rm -rf docker-compose.yml Caddyfile zitadel.env dashboard.env machinekey/zitadel-admin-sa.token turnserver.conf management.json zdb.env
- name: run script with Zitadel CockroachDB

View File

@@ -11,6 +11,8 @@ builds:
- amd64
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
tags:
- legacy_appindicator
mod_timestamp: '{{ .CommitTimestamp }}'
- id: netbird-ui-windows

View File

@@ -10,12 +10,10 @@
<img width="234" src="docs/media/logo-full.png"/>
</p>
<p>
<a href="https://img.shields.io/badge/license-BSD--3-blue)">
<img src="https://sonarcloud.io/api/project_badges/measure?project=netbirdio_netbird&metric=alert_status" />
</a>
<a href="https://github.com/netbirdio/netbird/blob/main/LICENSE">
<img src="https://img.shields.io/badge/license-BSD--3-blue" />
</a>
<a href="https://www.codacy.com/gh/netbirdio/netbird/dashboard?utm_source=github.com&amp;utm_medium=referral&amp;utm_content=netbirdio/netbird&amp;utm_campaign=Badge_Grade"><img src="https://app.codacy.com/project/badge/Grade/e3013d046aec44cdb7462c8673b00976"/></a>
<br>
<a href="https://join.slack.com/t/netbirdio/shared_invite/zt-vrahf41g-ik1v7fV8du6t0RwxSrJ96A">
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>

View File

@@ -84,7 +84,7 @@ func (a *Auth) SaveConfigIfSSOSupported(listener SSOListener) {
func (a *Auth) saveConfigIfSSOSupported() (bool, error) {
supportsSSO := true
err := a.withBackOff(a.ctx, func() (err error) {
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL, nil)
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
s, ok := gstatus.FromError(err)

View File

@@ -178,21 +178,6 @@ func (a *Anonymizer) AnonymizeDNSLogLine(logEntry string) string {
})
}
// AnonymizeRoute anonymizes a route string by replacing IP addresses with anonymized versions and
// domain names with random strings.
func (a *Anonymizer) AnonymizeRoute(route string) string {
prefix, err := netip.ParsePrefix(route)
if err == nil {
ip := a.AnonymizeIPString(prefix.Addr().String())
return fmt.Sprintf("%s/%d", ip, prefix.Bits())
}
domains := strings.Split(route, ", ")
for i, domain := range domains {
domains[i] = a.AnonymizeDomain(domain)
}
return strings.Join(domains, ", ")
}
func isWellKnown(addr netip.Addr) bool {
wellKnown := []string{
"8.8.8.8", "8.8.4.4", // Google DNS IPv4

View File

@@ -5,7 +5,6 @@ import (
"fmt"
"time"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"google.golang.org/grpc/status"
@@ -14,8 +13,6 @@ import (
"github.com/netbirdio/netbird/client/server"
)
const errCloseConnection = "Failed to close connection: %v"
var debugCmd = &cobra.Command{
Use: "debug",
Short: "Debugging commands",
@@ -66,17 +63,12 @@ func debugBundle(cmd *cobra.Command, _ []string) error {
if err != nil {
return err
}
defer func() {
if err := conn.Close(); err != nil {
log.Errorf(errCloseConnection, err)
}
}()
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
Anonymize: anonymizeFlag,
Status: getStatusOutput(cmd),
SystemInfo: debugSystemInfoFlag,
Anonymize: anonymizeFlag,
Status: getStatusOutput(cmd),
})
if err != nil {
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
@@ -92,11 +84,7 @@ func setLogLevel(cmd *cobra.Command, args []string) error {
if err != nil {
return err
}
defer func() {
if err := conn.Close(); err != nil {
log.Errorf(errCloseConnection, err)
}
}()
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
level := server.ParseLogLevel(args[0])
@@ -125,11 +113,7 @@ func runForDuration(cmd *cobra.Command, args []string) error {
if err != nil {
return err
}
defer func() {
if err := conn.Close(); err != nil {
log.Errorf(errCloseConnection, err)
}
}()
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
@@ -138,20 +122,17 @@ func runForDuration(cmd *cobra.Command, args []string) error {
return fmt.Errorf("failed to get status: %v", status.Convert(err).Message())
}
stateWasDown := stat.Status != string(internal.StatusConnected) && stat.Status != string(internal.StatusConnecting)
restoreUp := stat.Status == string(internal.StatusConnected) || stat.Status == string(internal.StatusConnecting)
initialLogLevel, err := client.GetLogLevel(cmd.Context(), &proto.GetLogLevelRequest{})
if err != nil {
return fmt.Errorf("failed to get log level: %v", status.Convert(err).Message())
}
if stateWasDown {
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
return fmt.Errorf("failed to up: %v", status.Convert(err).Message())
}
cmd.Println("Netbird up")
time.Sleep(time.Second * 10)
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
}
cmd.Println("Netbird down")
initialLevelTrace := initialLogLevel.GetLevel() >= proto.LogLevel_TRACE
if !initialLevelTrace {
@@ -164,11 +145,6 @@ func runForDuration(cmd *cobra.Command, args []string) error {
cmd.Println("Log level set to trace.")
}
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
}
cmd.Println("Netbird down")
time.Sleep(1 * time.Second)
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
@@ -186,25 +162,21 @@ func runForDuration(cmd *cobra.Command, args []string) error {
}
cmd.Println("\nDuration completed")
cmd.Println("Creating debug bundle...")
headerPreDown := fmt.Sprintf("----- Netbird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration)
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd))
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
Anonymize: anonymizeFlag,
Status: statusOutput,
SystemInfo: debugSystemInfoFlag,
})
if err != nil {
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
}
cmd.Println("Netbird down")
if stateWasDown {
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
time.Sleep(1 * time.Second)
if restoreUp {
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
return fmt.Errorf("failed to up: %v", status.Convert(err).Message())
}
cmd.Println("Netbird down")
cmd.Println("Netbird up")
}
if !initialLevelTrace {
@@ -214,6 +186,16 @@ func runForDuration(cmd *cobra.Command, args []string) error {
cmd.Println("Log level restored to", initialLogLevel.GetLevel())
}
cmd.Println("Creating debug bundle...")
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
Anonymize: anonymizeFlag,
Status: statusOutput,
})
if err != nil {
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
}
cmd.Println(resp.GetPath())
return nil

View File

@@ -26,7 +26,7 @@ var downCmd = &cobra.Command{
return err
}
ctx, cancel := context.WithTimeout(context.Background(), time.Second*7)
ctx, cancel := context.WithTimeout(context.Background(), time.Second*3)
defer cancel()
conn, err := DialClientGRPCServer(ctx, daemonAddr)

View File

@@ -39,11 +39,6 @@ var loginCmd = &cobra.Command{
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, hostName)
}
providedSetupKey, err := getSetupKey()
if err != nil {
return err
}
// workaround to run without service
if logFile == "console" {
err = handleRebrand(cmd)
@@ -67,7 +62,7 @@ var loginCmd = &cobra.Command{
config, _ = internal.UpdateOldManagementURL(ctx, config, configPath)
err = foregroundLogin(ctx, cmd, config, providedSetupKey)
err = foregroundLogin(ctx, cmd, config, setupKey)
if err != nil {
return fmt.Errorf("foreground login failed: %v", err)
}
@@ -86,7 +81,7 @@ var loginCmd = &cobra.Command{
client := proto.NewDaemonServiceClient(conn)
loginRequest := proto.LoginRequest{
SetupKey: providedSetupKey,
SetupKey: setupKey,
ManagementUrl: managementURL,
IsLinuxDesktopClient: isLinuxRunningDesktop(),
Hostname: hostName,

View File

@@ -37,7 +37,6 @@ const (
serverSSHAllowedFlag = "allow-server-ssh"
extraIFaceBlackListFlag = "extra-iface-blacklist"
dnsRouteIntervalFlag = "dns-router-interval"
systemInfoFlag = "system-info"
)
var (
@@ -56,7 +55,6 @@ var (
managementURL string
adminURL string
setupKey string
setupKeyPath string
hostName string
preSharedKey string
natExternalIPs []string
@@ -71,7 +69,6 @@ var (
autoConnectDisabled bool
extraIFaceBlackList []string
anonymizeFlag bool
debugSystemInfoFlag bool
dnsRouteInterval time.Duration
rootCmd = &cobra.Command{
@@ -94,15 +91,12 @@ func init() {
oldDefaultConfigPathDir = "/etc/wiretrustee/"
oldDefaultLogFileDir = "/var/log/wiretrustee/"
switch runtime.GOOS {
case "windows":
if runtime.GOOS == "windows" {
defaultConfigPathDir = os.Getenv("PROGRAMDATA") + "\\Netbird\\"
defaultLogFileDir = os.Getenv("PROGRAMDATA") + "\\Netbird\\"
oldDefaultConfigPathDir = os.Getenv("PROGRAMDATA") + "\\Wiretrustee\\"
oldDefaultLogFileDir = os.Getenv("PROGRAMDATA") + "\\Wiretrustee\\"
case "freebsd":
defaultConfigPathDir = "/var/db/netbird/"
}
defaultConfigPath = defaultConfigPathDir + "config.json"
@@ -127,10 +121,8 @@ func init() {
rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name")
rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", defaultConfigPath, "Netbird config file location")
rootCmd.PersistentFlags().StringVarP(&logLevel, "log-level", "l", "info", "sets Netbird log level")
rootCmd.PersistentFlags().StringVar(&logFile, "log-file", defaultLogFile, "sets Netbird log path. If console is specified the log will be output to stdout. If syslog is specified the log will be sent to syslog daemon.")
rootCmd.PersistentFlags().StringVar(&logFile, "log-file", defaultLogFile, "sets Netbird log path. If console is specified the log will be output to stdout")
rootCmd.PersistentFlags().StringVarP(&setupKey, "setup-key", "k", "", "Setup key obtained from the Management Service Dashboard (used to register peer)")
rootCmd.PersistentFlags().StringVar(&setupKeyPath, "setup-key-file", "", "The path to a setup key obtained from the Management Service Dashboard (used to register peer) This is ignored if the setup-key flag is provided.")
rootCmd.MarkFlagsMutuallyExclusive("setup-key", "setup-key-file")
rootCmd.PersistentFlags().StringVar(&preSharedKey, preSharedKeyFlag, "", "Sets Wireguard PreSharedKey property. If set, then only peers that have the same key can communicate.")
rootCmd.PersistentFlags().StringVarP(&hostName, "hostname", "n", "", "Sets a custom hostname for the device")
rootCmd.PersistentFlags().BoolVarP(&anonymizeFlag, "anonymize", "A", false, "anonymize IP addresses and non-netbird.io domains in logs and status output")
@@ -173,8 +165,6 @@ func init() {
upCmd.PersistentFlags().BoolVar(&rosenpassPermissive, rosenpassPermissiveFlag, false, "[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.")
upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer. If enabled, the SSH server will be permitted")
upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.")
debugCmd.PersistentFlags().BoolVarP(&debugSystemInfoFlag, systemInfoFlag, "S", false, "Adds system information to the debug bundle")
}
// SetupCloseHandler handles SIGTERM signal and exits with success
@@ -256,21 +246,6 @@ var CLIBackOffSettings = &backoff.ExponentialBackOff{
Clock: backoff.SystemClock,
}
func getSetupKey() (string, error) {
if setupKeyPath != "" && setupKey == "" {
return getSetupKeyFromFile(setupKeyPath)
}
return setupKey, nil
}
func getSetupKeyFromFile(setupKeyPath string) (string, error) {
data, err := os.ReadFile(setupKeyPath)
if err != nil {
return "", fmt.Errorf("failed to read setup key file: %v", err)
}
return strings.TrimSpace(string(data)), nil
}
func handleRebrand(cmd *cobra.Command) error {
var err error
if logFile == defaultLogFile {

View File

@@ -31,8 +31,6 @@ var installCmd = &cobra.Command{
configPath,
"--log-level",
logLevel,
"--daemon-addr",
daemonAddr,
}
if managementURL != "" {

View File

@@ -807,7 +807,7 @@ func anonymizePeerDetail(a *anonymize.Anonymizer, peer *peerStateDetailOutput) {
}
for i, route := range peer.Routes {
peer.Routes[i] = a.AnonymizeRoute(route)
peer.Routes[i] = anonymizeRoute(a, route)
}
}
@@ -843,8 +843,21 @@ func anonymizeOverview(a *anonymize.Anonymizer, overview *statusOutputOverview)
}
for i, route := range overview.Routes {
overview.Routes[i] = a.AnonymizeRoute(route)
overview.Routes[i] = anonymizeRoute(a, route)
}
overview.FQDN = a.AnonymizeDomain(overview.FQDN)
}
func anonymizeRoute(a *anonymize.Anonymizer, route string) string {
prefix, err := netip.ParsePrefix(route)
if err == nil {
ip := a.AnonymizeIPString(prefix.Addr().String())
return fmt.Sprintf("%s/%d", ip, prefix.Bits())
}
domains := strings.Split(route, ", ")
for i, domain := range domains {
domains[i] = a.AnonymizeDomain(domain)
}
return strings.Join(domains, ", ")
}

View File

@@ -11,7 +11,6 @@ import (
"go.opentelemetry.io/otel"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/util"
@@ -72,7 +71,6 @@ func startSignal(t *testing.T) (*grpc.Server, net.Listener) {
func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Listener) {
t.Helper()
lis, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatal(err)
@@ -90,11 +88,7 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
return nil, nil
}
iv, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics)
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv)
if err != nil {
t.Fatal(err)
}

View File

@@ -147,11 +147,6 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
ic.DNSRouteInterval = &dnsRouteInterval
}
providedSetupKey, err := getSetupKey()
if err != nil {
return err
}
config, err := internal.UpdateOrCreateConfig(ic)
if err != nil {
return fmt.Errorf("get config file: %v", err)
@@ -159,7 +154,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
config, _ = internal.UpdateOldManagementURL(ctx, config, configPath)
err = foregroundLogin(ctx, cmd, config, providedSetupKey)
err = foregroundLogin(ctx, cmd, config, setupKey)
if err != nil {
return fmt.Errorf("foreground login failed: %v", err)
}
@@ -204,13 +199,8 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
return nil
}
providedSetupKey, err := getSetupKey()
if err != nil {
return err
}
loginRequest := proto.LoginRequest{
SetupKey: providedSetupKey,
SetupKey: setupKey,
ManagementUrl: managementURL,
AdminURL: adminURL,
NatExternalIPs: natExternalIPs,

View File

@@ -2,7 +2,6 @@ package cmd
import (
"context"
"os"
"testing"
"time"
@@ -41,36 +40,6 @@ func TestUpDaemon(t *testing.T) {
return
}
// Test the setup-key-file flag.
tempFile, err := os.CreateTemp("", "setup-key")
if err != nil {
t.Errorf("could not create temp file, got error %v", err)
return
}
defer os.Remove(tempFile.Name())
if _, err := tempFile.Write([]byte("A2C8E62B-38F5-4553-B31E-DD66C696CEBB")); err != nil {
t.Errorf("could not write to temp file, got error %v", err)
return
}
if err := tempFile.Close(); err != nil {
t.Errorf("unable to close file, got error %v", err)
}
rootCmd.SetArgs([]string{
"login",
"--daemon-addr", "tcp://" + cliAddr,
"--setup-key-file", tempFile.Name(),
"--log-file", "",
})
if err := rootCmd.Execute(); err != nil {
t.Errorf("expected no error while running up command, got %v", err)
return
}
time.Sleep(time.Second * 3)
if status, err := state.Status(); err != nil && status != internal.StatusIdle {
t.Errorf("wrong status after login: %s, %v", internal.StatusIdle, err)
return
}
rootCmd.SetArgs([]string{
"up",
"--daemon-addr", "tcp://" + cliAddr,

View File

@@ -337,6 +337,7 @@ func validateRule(ip net.IP, packetData []byte, rules map[string]Rule, d *decode
if rule.dPort != 0 && rule.dPort == uint16(d.udp.DstPort) {
return rule.drop, true
}
return rule.drop, true
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
return rule.drop, true
}

View File

@@ -69,11 +69,6 @@ func NewOAuthFlow(ctx context.Context, config *internal.Config, isLinuxDesktopCl
return authenticateWithDeviceCodeFlow(ctx, config)
}
// On FreeBSD we currently do not support desktop environments and offer only Device Code Flow (#2384)
if runtime.GOOS == "freebsd" {
return authenticateWithDeviceCodeFlow(ctx, config)
}
pkceFlow, err := authenticateWithPKCEFlow(ctx, config)
if err != nil {
// fallback to device code flow
@@ -86,7 +81,7 @@ func NewOAuthFlow(ctx context.Context, config *internal.Config, isLinuxDesktopCl
// authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow
func authenticateWithPKCEFlow(ctx context.Context, config *internal.Config) (OAuthFlow, error) {
pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL, config.ClientCertKeyPair)
pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
if err != nil {
return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err)
}

View File

@@ -4,7 +4,6 @@ import (
"context"
"crypto/sha256"
"crypto/subtle"
"crypto/tls"
"encoding/base64"
"errors"
"fmt"
@@ -144,18 +143,6 @@ func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) (
func (p *PKCEAuthorizationFlow) startServer(server *http.Server, tokenChan chan<- *oauth2.Token, errChan chan<- error) {
mux := http.NewServeMux()
mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
cert := p.providerConfig.ClientCertPair
if cert != nil {
tr := &http.Transport{
TLSClientConfig: &tls.Config{
Certificates: []tls.Certificate{*cert},
},
}
sslClient := &http.Client{Transport: tr}
ctx := context.WithValue(req.Context(), oauth2.HTTPClient, sslClient)
req = req.WithContext(ctx)
}
token, err := p.handleRequest(req)
if err != nil {
renderPKCEFlowTmpl(w, err)

View File

@@ -2,7 +2,6 @@ package internal
import (
"context"
"crypto/tls"
"fmt"
"net/url"
"os"
@@ -58,8 +57,6 @@ type ConfigInput struct {
DisableAutoConnect *bool
ExtraIFaceBlackList []string
DNSRouteInterval *time.Duration
ClientCertPath string
ClientCertKeyPath string
}
// Config Configuration type
@@ -105,13 +102,6 @@ type Config struct {
// DNSRouteInterval is the interval in which the DNS routes are updated
DNSRouteInterval time.Duration
//Path to a certificate used for mTLS authentication
ClientCertPath string
//Path to corresponding private key of ClientCertPath
ClientCertKeyPath string
ClientCertKeyPair *tls.Certificate `json:"-"`
}
// ReadConfig read config file and return with Config. If it is not exists create a new with default values
@@ -395,26 +385,6 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
}
if input.ClientCertKeyPath != "" {
config.ClientCertKeyPath = input.ClientCertKeyPath
updated = true
}
if input.ClientCertPath != "" {
config.ClientCertPath = input.ClientCertPath
updated = true
}
if config.ClientCertPath != "" && config.ClientCertKeyPath != "" {
cert, err := tls.LoadX509KeyPair(config.ClientCertPath, config.ClientCertKeyPath)
if err != nil {
log.Error("Failed to load mTLS cert/key pair: ", err)
} else {
config.ClientCertKeyPair = &cert
log.Info("Loaded client mTLS cert/key pair")
}
}
return updated, nil
}

View File

@@ -15,12 +15,6 @@ type hostManager interface {
restoreUncleanShutdownDNS(storedDNSAddress *netip.Addr) error
}
type SystemDNSSettings struct {
Domains []string
ServerIP string
ServerPort int
}
type HostDNSConfig struct {
Domains []DomainConfig `json:"domains"`
RouteAll bool `json:"routeAll"`

View File

@@ -7,7 +7,6 @@ import (
"bytes"
"fmt"
"io"
"net"
"net/netip"
"os/exec"
"strconv"
@@ -19,7 +18,7 @@ import (
const (
netbirdDNSStateKeyFormat = "State:/Network/Service/NetBird-%s/DNS"
globalIPv4State = "State:/Network/Global/IPv4"
primaryServiceStateKeyFormat = "State:/Network/Service/%s/DNS"
primaryServiceSetupKeyFormat = "Setup:/Network/Service/%s/DNS"
keySupplementalMatchDomains = "SupplementalMatchDomains"
keySupplementalMatchDomainsNoSearch = "SupplementalMatchDomainsNoSearch"
keyServerAddresses = "ServerAddresses"
@@ -29,12 +28,12 @@ const (
scutilPath = "/usr/sbin/scutil"
searchSuffix = "Search"
matchSuffix = "Match"
localSuffix = "Local"
)
type systemConfigurator struct {
createdKeys map[string]struct{}
systemDNSSettings SystemDNSSettings
// primaryServiceID primary interface in the system. AKA the interface with the default route
primaryServiceID string
createdKeys map[string]struct{}
}
func newHostManager() (hostManager, error) {
@@ -50,6 +49,20 @@ func (s *systemConfigurator) supportCustomPort() bool {
func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig) error {
var err error
if config.RouteAll {
err = s.addDNSSetupForAll(config.ServerIP, config.ServerPort)
if err != nil {
return fmt.Errorf("add dns setup for all: %w", err)
}
} else if s.primaryServiceID != "" {
err = s.removeKeyFromSystemConfig(getKeyWithInput(primaryServiceSetupKeyFormat, s.primaryServiceID))
if err != nil {
return fmt.Errorf("remote key from system config: %w", err)
}
s.primaryServiceID = ""
log.Infof("removed %s:%d as main DNS resolver for this peer", config.ServerIP, config.ServerPort)
}
// create a file for unclean shutdown detection
if err := createUncleanShutdownIndicator(); err != nil {
log.Errorf("failed to create unclean shutdown file: %s", err)
@@ -60,19 +73,6 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig) error {
matchDomains []string
)
err = s.recordSystemDNSSettings(true)
if err != nil {
log.Errorf("unable to update record of System's DNS config: %s", err.Error())
}
if config.RouteAll {
searchDomains = append(searchDomains, "\"\"")
err = s.addLocalDNS()
if err != nil {
log.Infof("failed to enable split DNS")
}
}
for _, dConf := range config.Domains {
if dConf.Disabled {
continue
@@ -110,17 +110,23 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig) error {
}
func (s *systemConfigurator) restoreHostDNS() error {
keys := s.getRemovableKeysWithDefaults()
for _, key := range keys {
lines := ""
for key := range s.createdKeys {
lines += buildRemoveKeyOperation(key)
keyType := "search"
if strings.Contains(key, matchSuffix) {
keyType = "match"
}
log.Infof("removing %s domains from system", keyType)
err := s.removeKeyFromSystemConfig(key)
if err != nil {
log.Errorf("failed to remove %s domains from system: %s", keyType, err)
}
}
if s.primaryServiceID != "" {
lines += buildRemoveKeyOperation(getKeyWithInput(primaryServiceSetupKeyFormat, s.primaryServiceID))
log.Infof("restoring DNS resolver configuration for system")
}
_, err := runSystemConfigCommand(wrapCommand(lines))
if err != nil {
log.Errorf("got an error while cleaning the system configuration: %s", err)
return fmt.Errorf("clean system: %w", err)
}
if err := removeUncleanShutdownIndicator(); err != nil {
@@ -130,19 +136,6 @@ func (s *systemConfigurator) restoreHostDNS() error {
return nil
}
func (s *systemConfigurator) getRemovableKeysWithDefaults() []string {
if len(s.createdKeys) == 0 {
// return defaults for startup calls
return []string{getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix), getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)}
}
keys := make([]string, 0, len(s.createdKeys))
for key := range s.createdKeys {
keys = append(keys, key)
}
return keys
}
func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error {
line := buildRemoveKeyOperation(key)
_, err := runSystemConfigCommand(wrapCommand(line))
@@ -155,97 +148,6 @@ func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error {
return nil
}
func (s *systemConfigurator) addLocalDNS() error {
if s.systemDNSSettings.ServerIP == "" || len(s.systemDNSSettings.Domains) == 0 {
err := s.recordSystemDNSSettings(true)
log.Errorf("Unable to get system DNS configuration")
return err
}
localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)
if s.systemDNSSettings.ServerIP != "" && len(s.systemDNSSettings.Domains) != 0 {
err := s.addSearchDomains(localKey, strings.Join(s.systemDNSSettings.Domains, " "), s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort)
if err != nil {
return fmt.Errorf("couldn't add local network DNS conf: %w", err)
}
} else {
log.Info("Not enabling local DNS server")
}
return nil
}
func (s *systemConfigurator) recordSystemDNSSettings(force bool) error {
if s.systemDNSSettings.ServerIP != "" && len(s.systemDNSSettings.Domains) != 0 && !force {
return nil
}
systemDNSSettings, err := s.getSystemDNSSettings()
if err != nil {
return fmt.Errorf("couldn't get current DNS config: %w", err)
}
s.systemDNSSettings = systemDNSSettings
return nil
}
func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
primaryServiceKey, _, err := s.getPrimaryService()
if err != nil || primaryServiceKey == "" {
return SystemDNSSettings{}, fmt.Errorf("couldn't find the primary service key: %w", err)
}
dnsServiceKey := getKeyWithInput(primaryServiceStateKeyFormat, primaryServiceKey)
line := buildCommandLine("show", dnsServiceKey, "")
stdinCommands := wrapCommand(line)
b, err := runSystemConfigCommand(stdinCommands)
if err != nil {
return SystemDNSSettings{}, fmt.Errorf("sending the command: %w", err)
}
var dnsSettings SystemDNSSettings
inSearchDomainsArray := false
inServerAddressesArray := false
scanner := bufio.NewScanner(bytes.NewReader(b))
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
switch {
case strings.HasPrefix(line, "DomainName :"):
domainName := strings.TrimSpace(strings.Split(line, ":")[1])
dnsSettings.Domains = append(dnsSettings.Domains, domainName)
case line == "SearchDomains : <array> {":
inSearchDomainsArray = true
continue
case line == "ServerAddresses : <array> {":
inServerAddressesArray = true
continue
case line == "}":
inSearchDomainsArray = false
inServerAddressesArray = false
}
if inSearchDomainsArray {
searchDomain := strings.Split(line, " : ")[1]
dnsSettings.Domains = append(dnsSettings.Domains, searchDomain)
} else if inServerAddressesArray {
address := strings.Split(line, " : ")[1]
if ip := net.ParseIP(address); ip != nil && ip.To4() != nil {
dnsSettings.ServerIP = address
inServerAddressesArray = false // Stop reading after finding the first IPv4 address
}
}
}
if err := scanner.Err(); err != nil {
return dnsSettings, err
}
// default to 53 port
dnsSettings.ServerPort = 53
return dnsSettings, nil
}
func (s *systemConfigurator) addSearchDomains(key, domains string, ip string, port int) error {
err := s.addDNSState(key, domains, ip, port, true)
if err != nil {
@@ -292,6 +194,23 @@ func (s *systemConfigurator) addDNSState(state, domains, dnsServer string, port
return nil
}
func (s *systemConfigurator) addDNSSetupForAll(dnsServer string, port int) error {
primaryServiceKey, existingNameserver, err := s.getPrimaryService()
if err != nil || primaryServiceKey == "" {
return fmt.Errorf("couldn't find the primary service key: %w", err)
}
err = s.addDNSSetup(getKeyWithInput(primaryServiceSetupKeyFormat, primaryServiceKey), dnsServer, port, existingNameserver)
if err != nil {
return fmt.Errorf("add dns setup: %w", err)
}
log.Infof("configured %s:%d as main DNS resolver for this peer", dnsServer, port)
s.primaryServiceID = primaryServiceKey
return nil
}
func (s *systemConfigurator) getPrimaryService() (string, string, error) {
line := buildCommandLine("show", globalIPv4State, "")
stdinCommands := wrapCommand(line)
@@ -320,6 +239,19 @@ func (s *systemConfigurator) getPrimaryService() (string, string, error) {
return primaryService, router, nil
}
func (s *systemConfigurator) addDNSSetup(setupKey, dnsServer string, port int, existingDNSServer string) error {
lines := buildAddCommandLine(keySupplementalMatchDomainsNoSearch, digitSymbol+strconv.Itoa(0))
lines += buildAddCommandLine(keyServerAddresses, arraySymbol+dnsServer+" "+existingDNSServer)
lines += buildAddCommandLine(keyServerPort, digitSymbol+strconv.Itoa(port))
addDomainCommand := buildCreateStateWithOperation(setupKey, lines)
stdinCommands := wrapCommand(addDomainCommand)
_, err := runSystemConfigCommand(stdinCommands)
if err != nil {
return fmt.Errorf("applying dns setup, error: %w", err)
}
return nil
}
func (s *systemConfigurator) restoreUncleanShutdownDNS(*netip.Addr) error {
if err := s.restoreHostDNS(); err != nil {
return fmt.Errorf("restoring dns via scutil: %w", err)

View File

@@ -94,7 +94,7 @@ func NewDefaultServer(
var dnsService service
if wgInterface.IsUserspaceBind() {
dnsService = NewServiceViaMemory(wgInterface)
dnsService = newServiceViaMemory(wgInterface)
} else {
dnsService = newServiceViaListener(wgInterface, addrPort)
}
@@ -112,7 +112,7 @@ func NewDefaultServerPermanentUpstream(
statusRecorder *peer.Status,
) *DefaultServer {
log.Debugf("host dns address list is: %v", hostsDnsList)
ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder)
ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface), statusRecorder)
ds.hostsDNSHolder.set(hostsDnsList)
ds.permanent = true
ds.addHostRootZone()
@@ -130,7 +130,7 @@ func NewDefaultServerIos(
iosDnsManager IosDnsManager,
statusRecorder *peer.Status,
) *DefaultServer {
ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder)
ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface), statusRecorder)
ds.iosDnsManager = iosDnsManager
return ds
}

View File

@@ -534,7 +534,7 @@ func TestDNSServerStartStop(t *testing.T) {
func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
hostManager := &mockHostConfigurator{}
server := DefaultServer{
service: NewServiceViaMemory(&mocWGIface{}),
service: newServiceViaMemory(&mocWGIface{}),
localResolver: &localResolver{
registeredMap: make(registrationMap),
},

View File

@@ -128,9 +128,6 @@ func (s *serviceViaListener) RuntimeIP() string {
}
func (s *serviceViaListener) setListenerStatus(running bool) {
s.listenerFlagLock.Lock()
defer s.listenerFlagLock.Unlock()
s.listenerIsRunning = running
}

View File

@@ -12,7 +12,7 @@ import (
log "github.com/sirupsen/logrus"
)
type ServiceViaMemory struct {
type serviceViaMemory struct {
wgInterface WGIface
dnsMux *dns.ServeMux
runtimeIP string
@@ -22,8 +22,8 @@ type ServiceViaMemory struct {
listenerFlagLock sync.Mutex
}
func NewServiceViaMemory(wgIface WGIface) *ServiceViaMemory {
s := &ServiceViaMemory{
func newServiceViaMemory(wgIface WGIface) *serviceViaMemory {
s := &serviceViaMemory{
wgInterface: wgIface,
dnsMux: dns.NewServeMux(),
@@ -33,7 +33,7 @@ func NewServiceViaMemory(wgIface WGIface) *ServiceViaMemory {
return s
}
func (s *ServiceViaMemory) Listen() error {
func (s *serviceViaMemory) Listen() error {
s.listenerFlagLock.Lock()
defer s.listenerFlagLock.Unlock()
@@ -52,7 +52,7 @@ func (s *ServiceViaMemory) Listen() error {
return nil
}
func (s *ServiceViaMemory) Stop() {
func (s *serviceViaMemory) Stop() {
s.listenerFlagLock.Lock()
defer s.listenerFlagLock.Unlock()
@@ -67,23 +67,23 @@ func (s *ServiceViaMemory) Stop() {
s.listenerIsRunning = false
}
func (s *ServiceViaMemory) RegisterMux(pattern string, handler dns.Handler) {
func (s *serviceViaMemory) RegisterMux(pattern string, handler dns.Handler) {
s.dnsMux.Handle(pattern, handler)
}
func (s *ServiceViaMemory) DeregisterMux(pattern string) {
func (s *serviceViaMemory) DeregisterMux(pattern string) {
s.dnsMux.HandleRemove(pattern)
}
func (s *ServiceViaMemory) RuntimePort() int {
func (s *serviceViaMemory) RuntimePort() int {
return s.runtimePort
}
func (s *ServiceViaMemory) RuntimeIP() string {
func (s *serviceViaMemory) RuntimeIP() string {
return s.runtimeIP
}
func (s *ServiceViaMemory) filterDNSTraffic() (string, error) {
func (s *serviceViaMemory) filterDNSTraffic() (string, error) {
filter := s.wgInterface.GetFilter()
if filter == nil {
return "", fmt.Errorf("can't set DNS filter, filter not initialized")

View File

@@ -24,7 +24,7 @@ const (
probeTimeout = 2 * time.Second
)
const testRecord = "com."
const testRecord = "."
type upstreamClient interface {
exchange(ctx context.Context, upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error)
@@ -42,7 +42,6 @@ type upstreamResolverBase struct {
upstreamServers []string
disabled bool
failsCount atomic.Int32
successCount atomic.Int32
failsTillDeact int32
mutex sync.Mutex
reactivatePeriod time.Duration
@@ -125,7 +124,6 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
return
}
u.successCount.Add(1)
log.Tracef("took %s to query the upstream %s", t, upstream)
err = w.WriteMsg(rm)
@@ -174,11 +172,6 @@ func (u *upstreamResolverBase) probeAvailability() {
default:
}
// avoid probe if upstreams could resolve at least one query and fails count is less than failsTillDeact
if u.successCount.Load() > 0 && u.failsCount.Load() < u.failsTillDeact {
return
}
var success bool
var mu sync.Mutex
var wg sync.WaitGroup
@@ -190,7 +183,7 @@ func (u *upstreamResolverBase) probeAvailability() {
wg.Add(1)
go func() {
defer wg.Done()
err := u.testNameserver(upstream, 500*time.Millisecond)
err := u.testNameserver(upstream)
if err != nil {
errors = multierror.Append(errors, err)
log.Warnf("probing upstream nameserver %s: %s", upstream, err)
@@ -231,7 +224,7 @@ func (u *upstreamResolverBase) waitUntilResponse() {
}
for _, upstream := range u.upstreamServers {
if err := u.testNameserver(upstream, probeTimeout); err != nil {
if err := u.testNameserver(upstream); err != nil {
log.Tracef("upstream check for %s: %s", upstream, err)
} else {
// at least one upstream server is available, stop probing
@@ -251,7 +244,6 @@ func (u *upstreamResolverBase) waitUntilResponse() {
log.Infof("upstreams %s are responsive again. Adding them back to system", u.upstreamServers)
u.failsCount.Store(0)
u.successCount.Add(1)
u.reactivate()
u.disabled = false
}
@@ -273,14 +265,13 @@ func (u *upstreamResolverBase) disable(err error) {
}
log.Warnf("Upstream resolving is Disabled for %v", reactivatePeriod)
u.successCount.Store(0)
u.deactivate(err)
u.disabled = true
go u.waitUntilResponse()
}
func (u *upstreamResolverBase) testNameserver(server string, timeout time.Duration) error {
ctx, cancel := context.WithTimeout(u.ctx, timeout)
func (u *upstreamResolverBase) testNameserver(server string) error {
ctx, cancel := context.WithTimeout(u.ctx, probeTimeout)
defer cancel()
r := new(dns.Msg).SetQuestion(testRecord, dns.TypeSOA)

View File

@@ -4,7 +4,6 @@ package dns
import (
"context"
"fmt"
"net"
"syscall"
"time"
@@ -18,9 +17,9 @@ import (
type upstreamResolverIOS struct {
*upstreamResolverBase
lIP net.IP
lNet *net.IPNet
interfaceName string
lIP net.IP
lNet *net.IPNet
iIndex int
}
func newUpstreamResolver(
@@ -33,11 +32,17 @@ func newUpstreamResolver(
) (*upstreamResolverIOS, error) {
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder)
index, err := getInterfaceIndex(interfaceName)
if err != nil {
log.Debugf("unable to get interface index for %s: %s", interfaceName, err)
return nil, err
}
ios := &upstreamResolverIOS{
upstreamResolverBase: upstreamResolverBase,
lIP: ip,
lNet: net,
interfaceName: interfaceName,
iIndex: index,
}
ios.upstreamClient = ios
@@ -48,7 +53,7 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
client := &dns.Client{}
upstreamHost, _, err := net.SplitHostPort(upstream)
if err != nil {
return nil, 0, fmt.Errorf("error while parsing upstream host: %s", err)
log.Errorf("error while parsing upstream host: %s", err)
}
timeout := upstreamTimeout
@@ -60,35 +65,26 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
upstreamIP := net.ParseIP(upstreamHost)
if u.lNet.Contains(upstreamIP) || net.IP.IsPrivate(upstreamIP) {
log.Debugf("using private client to query upstream: %s", upstream)
client, err = GetClientPrivate(u.lIP, u.interfaceName, timeout)
if err != nil {
return nil, 0, fmt.Errorf("error while creating private client: %s", err)
}
client = u.getClientPrivate(timeout)
}
// Cannot use client.ExchangeContext because it overwrites our Dialer
return client.Exchange(r, upstream)
}
// GetClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface
// getClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface
// This method is needed for iOS
func GetClientPrivate(ip net.IP, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) {
index, err := getInterfaceIndex(interfaceName)
if err != nil {
log.Debugf("unable to get interface index for %s: %s", interfaceName, err)
return nil, err
}
func (u *upstreamResolverIOS) getClientPrivate(dialTimeout time.Duration) *dns.Client {
dialer := &net.Dialer{
LocalAddr: &net.UDPAddr{
IP: ip,
IP: u.lIP,
Port: 0, // Let the OS pick a free port
},
Timeout: dialTimeout,
Control: func(network, address string, c syscall.RawConn) error {
var operr error
fn := func(s uintptr) {
operr = unix.SetsockoptInt(int(s), unix.IPPROTO_IP, unix.IP_BOUND_IF, index)
operr = unix.SetsockoptInt(int(s), unix.IPPROTO_IP, unix.IP_BOUND_IF, u.iIndex)
}
if err := c.Control(fn); err != nil {
@@ -105,7 +101,7 @@ func GetClientPrivate(ip net.IP, interfaceName string, dialTimeout time.Duration
client := &dns.Client{
Dialer: dialer,
}
return client, nil
return client
}
func getInterfaceIndex(interfaceName string) (int, error) {

View File

@@ -266,23 +266,8 @@ func (e *Engine) Stop() error {
e.close()
e.wgConnWorker.Wait()
maxWaitTime := 5 * time.Second
timeout := time.After(maxWaitTime)
for {
if !e.IsWGIfaceUp() {
log.Infof("stopped Netbird Engine")
return nil
}
select {
case <-timeout:
return fmt.Errorf("timeout when waiting for interface shutdown")
default:
time.Sleep(100 * time.Millisecond)
}
}
log.Infof("stopped Netbird Engine")
return nil
}
// Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services
@@ -1548,20 +1533,3 @@ func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool {
return slices.Equal(checks.Files, oChecks.Files)
})
}
func (e *Engine) IsWGIfaceUp() bool {
if e == nil || e.wgInterface == nil {
return false
}
iface, err := net.InterfaceByName(e.wgInterface.Name())
if err != nil {
log.Debugf("failed to get interface by name %s: %v", e.wgInterface.Name(), err)
return false
}
if iface.Flags&net.FlagUp != 0 {
return true
}
return false
}

View File

@@ -36,7 +36,6 @@ import (
mgmtProto "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/route"
signal "github.com/netbirdio/netbird/signal/client"
"github.com/netbirdio/netbird/signal/proto"
@@ -1070,11 +1069,7 @@ func startManagement(t *testing.T, dataDir string) (*grpc.Server, string, error)
return nil, "", err
}
ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics)
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
if err != nil {
return nil, "", err
}

View File

@@ -4,7 +4,6 @@ package networkmonitor
import (
"context"
"errors"
"fmt"
"syscall"
"unsafe"
@@ -22,20 +21,11 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
return fmt.Errorf("failed to open routing socket: %v", err)
}
defer func() {
err := unix.Close(fd)
if err != nil && !errors.Is(err, unix.EBADF) {
if err := unix.Close(fd); err != nil {
log.Errorf("Network monitor: failed to close routing socket: %v", err)
}
}()
go func() {
<-ctx.Done()
err := unix.Close(fd)
if err != nil && !errors.Is(err, unix.EBADF) {
log.Debugf("Network monitor: closed routing socket")
}
}()
for {
select {
case <-ctx.Done():
@@ -44,9 +34,7 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
buf := make([]byte, 2048)
n, err := unix.Read(fd, buf)
if err != nil {
if !errors.Is(err, unix.EBADF) && !errors.Is(err, unix.EINVAL) {
log.Errorf("Network monitor: failed to read from routing socket: %v", err)
}
log.Errorf("Network monitor: failed to read from routing socket: %v", err)
continue
}
if n < unix.SizeofRtMsghdr {

View File

@@ -99,11 +99,6 @@ func routeChanged(nexthop systemops.Nexthop, intf *net.Interface, routes []syste
return false
}
if isSoftInterface(nexthop.Intf.Name) {
log.Tracef("network monitor: ignoring default route change for soft interface %s", nexthop.Intf.Name)
return false
}
unspec := getUnspecifiedPrefix(nexthop.IP)
defaultRoutes, foundMatchingRoute := processRoutes(nexthop, intf, routes, unspec)
@@ -124,7 +119,7 @@ func getUnspecifiedPrefix(ip netip.Addr) netip.Prefix {
return netip.PrefixFrom(netip.IPv4Unspecified(), 0)
}
func processRoutes(nexthop systemops.Nexthop, nexthopIntf *net.Interface, routes []systemops.Route, unspec netip.Prefix) ([]string, bool) {
func processRoutes(nexthop systemops.Nexthop, intf *net.Interface, routes []systemops.Route, unspec netip.Prefix) ([]string, bool) {
var defaultRoutes []string
foundMatchingRoute := false
@@ -133,7 +128,7 @@ func processRoutes(nexthop systemops.Nexthop, nexthopIntf *net.Interface, routes
routeInfo := formatRouteInfo(r)
defaultRoutes = append(defaultRoutes, routeInfo)
if r.Nexthop == nexthop.IP && compareIntf(r.Interface, nexthopIntf) == 0 {
if r.Nexthop == nexthop.IP && compareIntf(r.Interface, intf) == 0 {
foundMatchingRoute = true
log.Debugf("network monitor: found matching default route: %s", routeInfo)
}
@@ -237,18 +232,14 @@ func stateFromInt(state uint8) string {
}
func compareIntf(a, b *net.Interface) int {
switch {
case a == nil && b == nil:
if a == nil && b == nil {
return 0
case a == nil:
return -1
case b == nil:
return 1
default:
return a.Index - b.Index
}
}
func isSoftInterface(name string) bool {
return strings.Contains(strings.ToLower(name), "isatap") || strings.Contains(strings.ToLower(name), "teredo")
if a == nil {
return -1
}
if b == nil {
return 1
}
return a.Index - b.Index
}

View File

@@ -2,7 +2,6 @@ package internal
import (
"context"
"crypto/tls"
"fmt"
"net/url"
@@ -37,12 +36,10 @@ type PKCEAuthProviderConfig struct {
RedirectURLs []string
// UseIDToken indicates if the id token should be used for authentication
UseIDToken bool
//ClientCertPair is used for mTLS authentication to the IDP
ClientCertPair *tls.Certificate
}
// GetPKCEAuthorizationFlowInfo initialize a PKCEAuthorizationFlow instance and return with it
func GetPKCEAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmURL *url.URL, clientCert *tls.Certificate) (PKCEAuthorizationFlow, error) {
func GetPKCEAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmURL *url.URL) (PKCEAuthorizationFlow, error) {
// validate our peer's Wireguard PRIVATE key
myPrivateKey, err := wgtypes.ParseKey(privateKey)
if err != nil {
@@ -96,7 +93,6 @@ func GetPKCEAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmURL
Scope: protoPKCEAuthorizationFlow.GetProviderConfig().GetScope(),
RedirectURLs: protoPKCEAuthorizationFlow.GetProviderConfig().GetRedirectURLs(),
UseIDToken: protoPKCEAuthorizationFlow.GetProviderConfig().GetUseIDToken(),
ClientCertPair: clientCert,
},
}

View File

@@ -3,14 +3,12 @@ package routemanager
import (
"context"
"fmt"
"reflect"
"time"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
nbdns "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
@@ -66,7 +64,7 @@ func newClientNetworkWatcher(ctx context.Context, dnsRouteInterval time.Duration
routePeersNotifiers: make(map[string]chan struct{}),
routeUpdate: make(chan routesUpdate),
peerStateUpdate: make(chan struct{}),
handler: handlerFromRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouteInterval, statusRecorder, wgInterface),
handler: handlerFromRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouteInterval, statusRecorder),
}
return client
}
@@ -311,33 +309,22 @@ func (c *clientNetwork) sendUpdateToClientNetworkWatcher(update routesUpdate) {
}()
}
func (c *clientNetwork) handleUpdate(update routesUpdate) bool {
isUpdateMapDifferent := false
func (c *clientNetwork) handleUpdate(update routesUpdate) {
updateMap := make(map[route.ID]*route.Route)
for _, r := range update.routes {
updateMap[r.ID] = r
}
if len(c.routes) != len(updateMap) {
isUpdateMapDifferent = true
}
for id, r := range c.routes {
_, found := updateMap[id]
if !found {
close(c.routePeersNotifiers[r.Peer])
delete(c.routePeersNotifiers, r.Peer)
isUpdateMapDifferent = true
continue
}
if !reflect.DeepEqual(c.routes[id], updateMap[id]) {
isUpdateMapDifferent = true
}
}
c.routes = updateMap
return isUpdateMapDifferent
}
// peersStateAndUpdateWatcher is the main point of reacting on client network routing events.
@@ -364,19 +351,13 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() {
log.Debugf("Received a new client network route update for [%v]", c.handler)
// hash update somehow
isTrueRouteUpdate := c.handleUpdate(update)
c.handleUpdate(update)
c.updateSerial = update.updateSerial
if isTrueRouteUpdate {
log.Debug("Client network update contains different routes, recalculating routes")
err := c.recalculateRouteAndUpdatePeerAndSystem()
if err != nil {
log.Errorf("Failed to recalculate routes for network [%v]: %v", c.handler, err)
}
} else {
log.Debug("Route update is not different, skipping route recalculation")
err := c.recalculateRouteAndUpdatePeerAndSystem()
if err != nil {
log.Errorf("Failed to recalculate routes for network [%v]: %v", c.handler, err)
}
c.startPeersStatusChangeWatcher()
@@ -384,10 +365,9 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() {
}
}
func handlerFromRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, dnsRouterInteval time.Duration, statusRecorder *peer.Status, wgInterface *iface.WGIface) RouteHandler {
func handlerFromRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, dnsRouterInteval time.Duration, statusRecorder *peer.Status) RouteHandler {
if rt.IsDynamic() {
dns := nbdns.NewServiceViaMemory(wgInterface)
return dynamic.NewRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouterInteval, statusRecorder, wgInterface, fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort()))
return dynamic.NewRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouterInteval, statusRecorder)
}
return static.NewRoute(rt, routeRefCounter, allowedIPsRefCounter)
}

View File

@@ -16,7 +16,6 @@ import (
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/util"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/route"
)
@@ -48,8 +47,6 @@ type Route struct {
currentPeerKey string
cancel context.CancelFunc
statusRecorder *peer.Status
wgInterface *iface.WGIface
resolverAddr string
}
func NewRoute(
@@ -58,8 +55,6 @@ func NewRoute(
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
interval time.Duration,
statusRecorder *peer.Status,
wgInterface *iface.WGIface,
resolverAddr string,
) *Route {
return &Route{
route: rt,
@@ -68,8 +63,6 @@ func NewRoute(
interval: interval,
dynamicDomains: domainMap{},
statusRecorder: statusRecorder,
wgInterface: wgInterface,
resolverAddr: resolverAddr,
}
}
@@ -196,14 +189,9 @@ func (r *Route) startResolver(ctx context.Context) {
}
func (r *Route) update(ctx context.Context) error {
resolved, err := r.resolveDomains()
if err != nil {
if len(resolved) == 0 {
return fmt.Errorf("resolve domains: %w", err)
}
log.Warnf("Failed to resolve domains: %v", err)
}
if err := r.updateDynamicRoutes(ctx, resolved); err != nil {
if resolved, err := r.resolveDomains(); err != nil {
return fmt.Errorf("resolve domains: %w", err)
} else if err := r.updateDynamicRoutes(ctx, resolved); err != nil {
return fmt.Errorf("update dynamic routes: %w", err)
}
@@ -235,17 +223,11 @@ func (r *Route) resolve(results chan resolveResult) {
wg.Add(1)
go func(domain domain.Domain) {
defer wg.Done()
ips, err := r.getIPsFromResolver(domain)
ips, err := net.LookupIP(string(domain))
if err != nil {
log.Tracef("Failed to resolve domain %s with private resolver: %v", domain.SafeString(), err)
ips, err = net.LookupIP(string(domain))
if err != nil {
results <- resolveResult{domain: domain, err: fmt.Errorf("resolve d %s: %w", domain.SafeString(), err)}
return
}
results <- resolveResult{domain: domain, err: fmt.Errorf("resolve d %s: %w", domain.SafeString(), err)}
return
}
for _, ip := range ips {
prefix, err := util.GetPrefixFromIP(ip)
if err != nil {

View File

@@ -1,13 +0,0 @@
//go:build !ios
package dynamic
import (
"net"
"github.com/netbirdio/netbird/management/domain"
)
func (r *Route) getIPsFromResolver(domain domain.Domain) ([]net.IP, error) {
return net.LookupIP(string(domain))
}

View File

@@ -1,55 +0,0 @@
//go:build ios
package dynamic
import (
"fmt"
"net"
"time"
"github.com/miekg/dns"
nbdns "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/management/domain"
)
const dialTimeout = 10 * time.Second
func (r *Route) getIPsFromResolver(domain domain.Domain) ([]net.IP, error) {
privateClient, err := nbdns.GetClientPrivate(r.wgInterface.Address().IP, r.wgInterface.Name(), dialTimeout)
if err != nil {
return nil, fmt.Errorf("error while creating private client: %s", err)
}
msg := new(dns.Msg)
msg.SetQuestion(dns.Fqdn(string(domain)), dns.TypeA)
startTime := time.Now()
response, _, err := privateClient.Exchange(msg, r.resolverAddr)
if err != nil {
return nil, fmt.Errorf("DNS query for %s failed after %s: %s ", domain.SafeString(), time.Since(startTime), err)
}
if response.Rcode != dns.RcodeSuccess {
return nil, fmt.Errorf("dns response code: %s", dns.RcodeToString[response.Rcode])
}
ips := make([]net.IP, 0)
for _, answ := range response.Answer {
if aRecord, ok := answ.(*dns.A); ok {
ips = append(ips, aRecord.A)
}
if aaaaRecord, ok := answ.(*dns.AAAA); ok {
ips = append(ips, aaaaRecord.AAAA)
}
}
if len(ips) == 0 {
return nil, fmt.Errorf("no A or AAAA records found for %s", domain.SafeString())
}
return ips, nil
}

View File

@@ -16,7 +16,6 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
@@ -51,7 +50,7 @@ type DefaultManager struct {
statusRecorder *peer.Status
wgInterface *iface.WGIface
pubKey string
notifier *notifier.Notifier
notifier *notifier
routeRefCounter *refcounter.RouteRefCounter
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter
dnsRouteInterval time.Duration
@@ -66,8 +65,7 @@ func NewManager(
initialRoutes []*route.Route,
) *DefaultManager {
mCTX, cancel := context.WithCancel(ctx)
notifier := notifier.NewNotifier()
sysOps := systemops.NewSysOps(wgInterface, notifier)
sysOps := systemops.NewSysOps(wgInterface)
dm := &DefaultManager{
ctx: mCTX,
@@ -79,7 +77,7 @@ func NewManager(
statusRecorder: statusRecorder,
wgInterface: wgInterface,
pubKey: pubKey,
notifier: notifier,
notifier: newNotifier(),
}
dm.routeRefCounter = refcounter.New(
@@ -109,7 +107,7 @@ func NewManager(
if runtime.GOOS == "android" {
cr := dm.clientRoutes(initialRoutes)
dm.notifier.SetInitialClientRoutes(cr)
dm.notifier.setInitialClientRoutes(cr)
}
return dm
}
@@ -188,7 +186,7 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro
filteredClientRoutes := m.routeSelector.FilterSelected(newClientRoutesIDMap)
m.updateClientNetworks(updateSerial, filteredClientRoutes)
m.notifier.OnNewRoutes(filteredClientRoutes)
m.notifier.onNewRoutes(filteredClientRoutes)
if m.serverRouter != nil {
err := m.serverRouter.updateRoutes(newServerRoutesMap)
@@ -201,14 +199,14 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro
}
}
// SetRouteChangeListener set RouteListener for route change Notifier
// SetRouteChangeListener set RouteListener for route change notifier
func (m *DefaultManager) SetRouteChangeListener(listener listener.NetworkChangeListener) {
m.notifier.SetListener(listener)
m.notifier.setListener(listener)
}
// InitialRouteRange return the list of initial routes. It used by mobile systems
func (m *DefaultManager) InitialRouteRange() []string {
return m.notifier.GetInitialRouteRanges()
return m.notifier.getInitialRouteRanges()
}
// GetRouteSelector returns the route selector
@@ -228,7 +226,7 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) {
networks = m.routeSelector.FilterSelected(networks)
m.notifier.OnNewRoutes(networks)
m.notifier.onNewRoutes(networks)
m.stopObsoleteClients(networks)

View File

@@ -1,7 +1,6 @@
package notifier
package routemanager
import (
"net/netip"
"runtime"
"sort"
"strings"
@@ -11,7 +10,7 @@ import (
"github.com/netbirdio/netbird/route"
)
type Notifier struct {
type notifier struct {
initialRouteRanges []string
routeRanges []string
@@ -19,17 +18,17 @@ type Notifier struct {
listenerMux sync.Mutex
}
func NewNotifier() *Notifier {
return &Notifier{}
func newNotifier() *notifier {
return &notifier{}
}
func (n *Notifier) SetListener(listener listener.NetworkChangeListener) {
func (n *notifier) setListener(listener listener.NetworkChangeListener) {
n.listenerMux.Lock()
defer n.listenerMux.Unlock()
n.listener = listener
}
func (n *Notifier) SetInitialClientRoutes(clientRoutes []*route.Route) {
func (n *notifier) setInitialClientRoutes(clientRoutes []*route.Route) {
nets := make([]string, 0)
for _, r := range clientRoutes {
nets = append(nets, r.Network.String())
@@ -38,10 +37,7 @@ func (n *Notifier) SetInitialClientRoutes(clientRoutes []*route.Route) {
n.initialRouteRanges = nets
}
func (n *Notifier) OnNewRoutes(idMap route.HAMap) {
if runtime.GOOS != "android" {
return
}
func (n *notifier) onNewRoutes(idMap route.HAMap) {
newNets := make([]string, 0)
for _, routes := range idMap {
for _, r := range routes {
@@ -66,30 +62,7 @@ func (n *Notifier) OnNewRoutes(idMap route.HAMap) {
n.notify()
}
func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) {
newNets := make([]string, 0)
for _, prefix := range prefixes {
newNets = append(newNets, prefix.String())
}
sort.Strings(newNets)
switch runtime.GOOS {
case "android":
if !n.hasDiff(n.initialRouteRanges, newNets) {
return
}
default:
if !n.hasDiff(n.routeRanges, newNets) {
return
}
}
n.routeRanges = newNets
n.notify()
}
func (n *Notifier) notify() {
func (n *notifier) notify() {
n.listenerMux.Lock()
defer n.listenerMux.Unlock()
if n.listener == nil {
@@ -101,7 +74,7 @@ func (n *Notifier) notify() {
}(n.listener)
}
func (n *Notifier) hasDiff(a []string, b []string) bool {
func (n *notifier) hasDiff(a []string, b []string) bool {
if len(a) != len(b) {
return true
}
@@ -113,7 +86,7 @@ func (n *Notifier) hasDiff(a []string, b []string) bool {
return false
}
func (n *Notifier) GetInitialRouteRanges() []string {
func (n *notifier) getInitialRouteRanges() []string {
return addIPv6RangeIfNeeded(n.initialRouteRanges)
}

View File

@@ -3,9 +3,7 @@ package systemops
import (
"net"
"net/netip"
"sync"
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/iface"
)
@@ -20,19 +18,10 @@ type ExclusionCounter = refcounter.Counter[any, Nexthop]
type SysOps struct {
refCounter *ExclusionCounter
wgInterface *iface.WGIface
// prefixes is tracking all the current added prefixes im memory
// (this is used in iOS as all route updates require a full table update)
//nolint
prefixes map[netip.Prefix]struct{}
//nolint
mu sync.Mutex
// notifier is used to notify the system of route changes (also used on mobile)
notifier *notifier.Notifier
}
func NewSysOps(wgInterface *iface.WGIface, notifier *notifier.Notifier) *SysOps {
func NewSysOps(wgInterface *iface.WGIface) *SysOps {
return &SysOps{
wgInterface: wgInterface,
notifier: notifier,
}
}

View File

@@ -22,7 +22,7 @@ type Route struct {
Interface *net.Interface
}
func GetRoutesFromTable() ([]netip.Prefix, error) {
func getRoutesFromTable() ([]netip.Prefix, error) {
tab, err := retryFetchRIB()
if err != nil {
return nil, fmt.Errorf("fetch RIB: %v", err)

View File

@@ -36,7 +36,7 @@ func TestConcurrentRoutes(t *testing.T) {
baseIP := netip.MustParseAddr("192.0.2.0")
intf := &net.Interface{Name: "lo0"}
r := NewSysOps(nil, nil)
r := NewSysOps(nil)
var wg sync.WaitGroup
for i := 0; i < 1024; i++ {

View File

@@ -50,7 +50,7 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP) (nbnet.AddHookFunc, nbn
nexthop, err := r.addRouteToNonVPNIntf(prefix, r.wgInterface, initialNexthop)
if errors.Is(err, vars.ErrRouteNotAllowed) || errors.Is(err, vars.ErrRouteNotFound) {
log.Tracef("Adding for prefix %s: %v", prefix, err)
// These errors are not critical, but also we should not track and try to remove the routes either.
// These errors are not critical but also we should not track and try to remove the routes either.
return nexthop, refcounter.ErrIgnore
}
return nexthop, err
@@ -135,11 +135,6 @@ func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf *iface.WGIfac
return Nexthop{}, vars.ErrRouteNotAllowed
}
// Check if the prefix is part of any local subnets
if isLocal, subnet := r.isPrefixInLocalSubnets(prefix); isLocal {
return Nexthop{}, fmt.Errorf("prefix %s is part of local subnet %s: %w", prefix, subnet, vars.ErrRouteNotAllowed)
}
// Determine the exit interface and next hop for the prefix, so we can add a specific route
nexthop, err := GetNextHop(addr)
if err != nil {
@@ -172,36 +167,6 @@ func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf *iface.WGIfac
return exitNextHop, nil
}
func (r *SysOps) isPrefixInLocalSubnets(prefix netip.Prefix) (bool, *net.IPNet) {
localInterfaces, err := net.Interfaces()
if err != nil {
log.Errorf("Failed to get local interfaces: %v", err)
return false, nil
}
for _, intf := range localInterfaces {
addrs, err := intf.Addrs()
if err != nil {
log.Errorf("Failed to get addresses for interface %s: %v", intf.Name, err)
continue
}
for _, addr := range addrs {
ipnet, ok := addr.(*net.IPNet)
if !ok {
log.Errorf("Failed to convert address to IPNet: %v", addr)
continue
}
if ipnet.Contains(prefix.Addr().AsSlice()) {
return true, ipnet
}
}
}
return false, nil
}
// genericAddVPNRoute adds a new route to the vpn interface, it splits the default prefix
// in two /1 prefixes to avoid replacing the existing default route
func (r *SysOps) genericAddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
@@ -427,7 +392,7 @@ func ipToAddr(ip net.IP, intf *net.Interface) (netip.Addr, error) {
}
func existsInRouteTable(prefix netip.Prefix) (bool, error) {
routes, err := GetRoutesFromTable()
routes, err := getRoutesFromTable()
if err != nil {
return false, fmt.Errorf("get routes from table: %w", err)
}
@@ -440,7 +405,7 @@ func existsInRouteTable(prefix netip.Prefix) (bool, error) {
}
func isSubRange(prefix netip.Prefix) (bool, error) {
routes, err := GetRoutesFromTable()
routes, err := getRoutesFromTable()
if err != nil {
return false, fmt.Errorf("get routes from table: %w", err)
}

View File

@@ -68,7 +68,7 @@ func TestAddRemoveRoutes(t *testing.T) {
err = wgInterface.Create()
require.NoError(t, err, "should create testing wireguard interface")
r := NewSysOps(wgInterface, nil)
r := NewSysOps(wgInterface)
_, _, err = r.SetupRouting(nil)
require.NoError(t, err)
@@ -224,7 +224,7 @@ func TestAddExistAndRemoveRoute(t *testing.T) {
require.NoError(t, err, "InterfaceByName should not return err")
intf := &net.Interface{Index: index.Index, Name: wgInterface.Name()}
r := NewSysOps(wgInterface, nil)
r := NewSysOps(wgInterface)
// Prepare the environment
if testCase.preExistingPrefix.IsValid() {
@@ -379,7 +379,7 @@ func setupTestEnv(t *testing.T) {
assert.NoError(t, wgInterface.Close())
})
r := NewSysOps(wgInterface, nil)
r := NewSysOps(wgInterface)
_, _, err := r.SetupRouting(nil)
require.NoError(t, err, "setupRouting should not return err")
t.Cleanup(func() {

View File

@@ -1,64 +0,0 @@
//go:build ios
package systemops
import (
"net"
"net/netip"
"runtime"
log "github.com/sirupsen/logrus"
nbnet "github.com/netbirdio/netbird/util/net"
)
func (r *SysOps) SetupRouting([]net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
r.mu.Lock()
defer r.mu.Unlock()
r.prefixes = make(map[netip.Prefix]struct{})
return nil, nil, nil
}
func (r *SysOps) CleanupRouting() error {
r.mu.Lock()
defer r.mu.Unlock()
r.prefixes = make(map[netip.Prefix]struct{})
r.notify()
return nil
}
func (r *SysOps) AddVPNRoute(prefix netip.Prefix, _ *net.Interface) error {
r.mu.Lock()
defer r.mu.Unlock()
r.prefixes[prefix] = struct{}{}
r.notify()
return nil
}
func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, _ *net.Interface) error {
r.mu.Lock()
defer r.mu.Unlock()
delete(r.prefixes, prefix)
r.notify()
return nil
}
func EnableIPForwarding() error {
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
return nil
}
func IsAddrRouted(netip.Addr, []netip.Prefix) (bool, netip.Prefix) {
return false, netip.Prefix{}
}
func (r *SysOps) notify() {
prefixes := make([]netip.Prefix, 0, len(r.prefixes))
for prefix := range r.prefixes {
prefixes = append(prefixes, prefix)
}
r.notifier.OnNewPrefixes(prefixes)
}

View File

@@ -206,7 +206,7 @@ func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error
return nil
}
func GetRoutesFromTable() ([]netip.Prefix, error) {
func getRoutesFromTable() ([]netip.Prefix, error) {
v4Routes, err := getRoutes(syscall.RT_TABLE_MAIN, netlink.FAMILY_V4)
if err != nil {
return nil, fmt.Errorf("get v4 routes: %w", err)
@@ -504,7 +504,7 @@ func getAddressFamily(prefix netip.Prefix) int {
func hasSeparateRouting() ([]netip.Prefix, error) {
if isLegacy() {
return GetRoutesFromTable()
return getRoutesFromTable()
}
return nil, ErrRoutingIsSeparate
}

View File

@@ -1,4 +1,4 @@
//go:build android
//go:build ios || android
package systemops

View File

@@ -24,5 +24,5 @@ func EnableIPForwarding() error {
}
func hasSeparateRouting() ([]netip.Prefix, error) {
return GetRoutesFromTable()
return getRoutesFromTable()
}

View File

@@ -94,7 +94,7 @@ func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) erro
return nil
}
func GetRoutesFromTable() ([]netip.Prefix, error) {
func getRoutesFromTable() ([]netip.Prefix, error) {
mux.Lock()
defer mux.Unlock()

View File

@@ -73,7 +73,7 @@ var testCases = []testCase{
{
name: "To duplicate internal route without custom dialer via physical interface", // local route takes precedence
destination: "10.0.0.2:53",
expectedSourceIP: "127.0.0.1",
expectedSourceIP: "10.0.0.1",
expectedDestPrefix: "10.0.0.0/8",
expectedNextHop: "0.0.0.0",
expectedInterface: "Loopback Pseudo-Interface 1",
@@ -110,7 +110,7 @@ var testCases = []testCase{
{
name: "To more specific route (local) without custom dialer via physical interface",
destination: "127.0.10.2:53",
expectedSourceIP: "127.0.0.1",
expectedSourceIP: "10.0.0.1",
expectedDestPrefix: "127.0.0.0/8",
expectedNextHop: "0.0.0.0",
expectedInterface: "Loopback Pseudo-Interface 1",
@@ -181,6 +181,31 @@ func testRoute(t *testing.T, destination string, dialer dialer) *FindNetRouteOut
return combinedOutput
}
func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR string) string {
t.Helper()
ip, ipNet, err := net.ParseCIDR(ipAddressCIDR)
require.NoError(t, err)
subnetMaskSize, _ := ipNet.Mask.Size()
script := fmt.Sprintf(`New-NetIPAddress -InterfaceAlias "%s" -IPAddress "%s" -PrefixLength %d -PolicyStore ActiveStore -Confirm:$False`, interfaceName, ip.String(), subnetMaskSize)
_, err = exec.Command("powershell", "-Command", script).CombinedOutput()
require.NoError(t, err, "Failed to assign IP address to loopback adapter")
// Wait for the IP address to be applied
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
err = waitForIPAddress(ctx, interfaceName, ip.String())
require.NoError(t, err, "IP address not applied within timeout")
t.Cleanup(func() {
script = fmt.Sprintf(`Remove-NetIPAddress -InterfaceAlias "%s" -IPAddress "%s" -Confirm:$False`, interfaceName, ip.String())
_, err = exec.Command("powershell", "-Command", script).CombinedOutput()
require.NoError(t, err, "Failed to remove IP address from loopback adapter")
})
return interfaceName
}
func fetchOriginalGateway() (*RouteInfo, error) {
cmd := exec.Command("powershell", "-Command", "Get-NetRoute -DestinationPrefix 0.0.0.0/0 | Select-Object Nexthop, RouteMetric, InterfaceAlias | ConvertTo-Json")
output, err := cmd.CombinedOutput()
@@ -206,6 +231,30 @@ func verifyOutput(t *testing.T, output *FindNetRouteOutput, sourceIP, destPrefix
assert.Equal(t, intf, output.InterfaceAlias, "Interface mismatch")
}
func waitForIPAddress(ctx context.Context, interfaceAlias, expectedIPAddress string) error {
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return ctx.Err()
case <-ticker.C:
out, err := exec.Command("powershell", "-Command", fmt.Sprintf(`Get-NetIPAddress -InterfaceAlias "%s" | Select-Object -ExpandProperty IPAddress`, interfaceAlias)).CombinedOutput()
if err != nil {
return err
}
ipAddresses := strings.Split(strings.TrimSpace(string(out)), "\n")
for _, ip := range ipAddresses {
if strings.TrimSpace(ip) == expectedIPAddress {
return nil
}
}
}
}
}
func combineOutputs(outputs []FindNetRouteOutput) *FindNetRouteOutput {
var combined FindNetRouteOutput
@@ -236,25 +285,5 @@ func combineOutputs(outputs []FindNetRouteOutput) *FindNetRouteOutput {
func setupDummyInterfacesAndRoutes(t *testing.T) {
t.Helper()
addDummyRoute(t, "10.0.0.0/8")
}
func addDummyRoute(t *testing.T, dstCIDR string) {
t.Helper()
script := fmt.Sprintf(`New-NetRoute -DestinationPrefix "%s" -InterfaceIndex 1 -PolicyStore ActiveStore`, dstCIDR)
output, err := exec.Command("powershell", "-Command", script).CombinedOutput()
if err != nil {
t.Logf("Failed to add dummy route: %v\nOutput: %s", err, output)
t.FailNow()
}
t.Cleanup(func() {
script = fmt.Sprintf(`Remove-NetRoute -DestinationPrefix "%s" -InterfaceIndex 1 -Confirm:$false`, dstCIDR)
output, err := exec.Command("powershell", "-Command", script).CombinedOutput()
if err != nil {
t.Logf("Failed to remove dummy route: %v\nOutput: %s", err, output)
}
})
createAndSetupDummyInterface(t, "Loopback Pseudo-Interface 1", "10.0.0.1/8")
}

View File

@@ -19,7 +19,6 @@ import (
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/formatter"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/route"
)
@@ -48,7 +47,6 @@ type CustomLogger interface {
type selectRoute struct {
NetID string
Network netip.Prefix
Domains domain.List
Selected bool
}
@@ -271,14 +269,7 @@ func (c *Client) GetRoutesSelectionDetails() (*RoutesSelectionDetails, error) {
}
routesMap := engine.GetClientRoutesWithNetID()
routeManager := engine.GetRouteManager()
if routeManager == nil {
return nil, fmt.Errorf("could not get route manager")
}
routeSelector := routeManager.GetRouteSelector()
if routeSelector == nil {
return nil, fmt.Errorf("could not get route selector")
}
routeSelector := engine.GetRouteManager().GetRouteSelector()
var routes []*selectRoute
for id, rt := range routesMap {
@@ -288,7 +279,6 @@ func (c *Client) GetRoutesSelectionDetails() (*RoutesSelectionDetails, error) {
route := &selectRoute{
NetID: string(id),
Network: rt[0].Network,
Domains: rt[0].Domains,
Selected: routeSelector.IsSelected(id),
}
routes = append(routes, route)
@@ -309,40 +299,17 @@ func (c *Client) GetRoutesSelectionDetails() (*RoutesSelectionDetails, error) {
return iPrefix < jPrefix
})
resolvedDomains := c.recorder.GetResolvedDomainsStates()
return prepareRouteSelectionDetails(routes, resolvedDomains), nil
}
func prepareRouteSelectionDetails(routes []*selectRoute, resolvedDomains map[domain.Domain][]netip.Prefix) *RoutesSelectionDetails {
var routeSelection []RoutesSelectionInfo
for _, r := range routes {
domainList := make([]DomainInfo, 0)
for _, d := range r.Domains {
domainResp := DomainInfo{
Domain: d.SafeString(),
}
if prefixes, exists := resolvedDomains[d]; exists {
var ipStrings []string
for _, prefix := range prefixes {
ipStrings = append(ipStrings, prefix.Addr().String())
}
domainResp.ResolvedIPs = strings.Join(ipStrings, ", ")
}
domainList = append(domainList, domainResp)
}
domainDetails := DomainDetails{items: domainList}
routeSelection = append(routeSelection, RoutesSelectionInfo{
ID: r.NetID,
Network: r.Network.String(),
Domains: &domainDetails,
Selected: r.Selected,
})
}
routeSelectionDetails := RoutesSelectionDetails{items: routeSelection}
return &routeSelectionDetails
return &routeSelectionDetails, nil
}
func (c *Client) SelectRoute(id string) error {

View File

@@ -74,7 +74,7 @@ func (a *Auth) SaveConfigIfSSOSupported() (bool, error) {
err := a.withBackOff(a.ctx, func() (err error) {
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL, nil)
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
supportsSSO = false
err = nil

View File

@@ -16,25 +16,9 @@ type RoutesSelectionDetails struct {
type RoutesSelectionInfo struct {
ID string
Network string
Domains *DomainDetails
Selected bool
}
type DomainCollection interface {
Add(s DomainInfo) DomainCollection
Get(i int) *DomainInfo
Size() int
}
type DomainDetails struct {
items []DomainInfo
}
type DomainInfo struct {
Domain string
ResolvedIPs string
}
// Add new PeerInfo to the collection
func (array RoutesSelectionDetails) Add(s RoutesSelectionInfo) RoutesSelectionDetails {
array.items = append(array.items, s)
@@ -50,16 +34,3 @@ func (array RoutesSelectionDetails) Get(i int) *RoutesSelectionInfo {
func (array RoutesSelectionDetails) Size() int {
return len(array.items)
}
func (array DomainDetails) Add(s DomainInfo) DomainCollection {
array.items = append(array.items, s)
return array
}
func (array DomainDetails) Get(i int) *DomainInfo {
return &array.items[i]
}
func (array DomainDetails) Size() int {
return len(array.items)
}

View File

@@ -1828,9 +1828,8 @@ type DebugBundleRequest struct {
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Anonymize bool `protobuf:"varint,1,opt,name=anonymize,proto3" json:"anonymize,omitempty"`
Status string `protobuf:"bytes,2,opt,name=status,proto3" json:"status,omitempty"`
SystemInfo bool `protobuf:"varint,3,opt,name=systemInfo,proto3" json:"systemInfo,omitempty"`
Anonymize bool `protobuf:"varint,1,opt,name=anonymize,proto3" json:"anonymize,omitempty"`
Status string `protobuf:"bytes,2,opt,name=status,proto3" json:"status,omitempty"`
}
func (x *DebugBundleRequest) Reset() {
@@ -1879,13 +1878,6 @@ func (x *DebugBundleRequest) GetStatus() string {
return ""
}
func (x *DebugBundleRequest) GetSystemInfo() bool {
if x != nil {
return x.SystemInfo
}
return false
}
type DebugBundleResponse struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
@@ -2378,13 +2370,11 @@ var file_daemon_proto_rawDesc = []byte{
0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x24, 0x0a, 0x05, 0x76, 0x61, 0x6c,
0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0e, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f,
0x6e, 0x2e, 0x49, 0x50, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a,
0x02, 0x38, 0x01, 0x22, 0x6a, 0x0a, 0x12, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64,
0x02, 0x38, 0x01, 0x22, 0x4a, 0x0a, 0x12, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64,
0x6c, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1c, 0x0a, 0x09, 0x61, 0x6e, 0x6f,
0x6e, 0x79, 0x6d, 0x69, 0x7a, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x61, 0x6e,
0x6f, 0x6e, 0x79, 0x6d, 0x69, 0x7a, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75,
0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12,
0x1e, 0x0a, 0x0a, 0x73, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x49, 0x6e, 0x66, 0x6f, 0x18, 0x03, 0x20,
0x01, 0x28, 0x08, 0x52, 0x0a, 0x73, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x49, 0x6e, 0x66, 0x6f, 0x22,
0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x22,
0x29, 0x0a, 0x13, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65,
0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x70, 0x61, 0x74, 0x68, 0x18, 0x01,
0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x70, 0x61, 0x74, 0x68, 0x22, 0x14, 0x0a, 0x12, 0x47, 0x65,

View File

@@ -263,7 +263,6 @@ message Route {
message DebugBundleRequest {
bool anonymize = 1;
string status = 2;
bool systemInfo = 3;
}
message DebugBundleResponse {

View File

@@ -1,5 +1,3 @@
//go:build !android && !ios
package server
import (
@@ -8,70 +6,16 @@ import (
"context"
"fmt"
"io"
"net"
"net/netip"
"os"
"sort"
"strings"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/anonymize"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/proto"
)
const readmeContent = `Netbird debug bundle
This debug bundle contains the following files:
status.txt: Anonymized status information of the NetBird client.
client.log: Most recent, anonymized log file of the NetBird client.
routes.txt: Anonymized system routes, if --system-info flag was provided.
interfaces.txt: Anonymized network interface information, if --system-info flag was provided.
config.txt: Anonymized configuration information of the NetBird client.
Anonymization Process
The files in this bundle have been anonymized to protect sensitive information. Here's how the anonymization was applied:
IP Addresses
IPv4 addresses are replaced with addresses starting from 192.51.100.0
IPv6 addresses are replaced with addresses starting from 100::
IP addresses from non public ranges and well known addresses are not anonymized (e.g. 8.8.8.8, 100.64.0.0/10, addresses starting with 192.168., 172.16., 10., etc.).
Reoccuring IP addresses are replaced with the same anonymized address.
Note: The anonymized IP addresses in the status file do not match those in the log and routes files. However, the anonymized IP addresses are consistent within the status file and across the routes and log files.
Domains
All domain names (except for the netbird domains) are replaced with randomly generated strings ending in ".domain". Anonymized domains are consistent across all files in the bundle.
Reoccuring domain names are replaced with the same anonymized domain.
Routes
For anonymized routes, the IP addresses are replaced as described above. The prefix length remains unchanged. Note that for prefixes, the anonymized IP might not be a network address, but the prefix length is still correct.
Network Interfaces
The interfaces.txt file contains information about network interfaces, including:
- Interface name
- Interface index
- MTU (Maximum Transmission Unit)
- Flags
- IP addresses associated with each interface
The IP addresses in the interfaces file are anonymized using the same process as described above. Interface names, indexes, MTUs, and flags are not anonymized.
Configuration
The config.txt file contains anonymized configuration information of the NetBird client. Sensitive information such as private keys and SSH keys are excluded. The following fields are anonymized:
- ManagementURL
- AdminURL
- NATExternalIPs
- CustomDNSAddress
Other non-sensitive configuration options are included without anonymization.
`
// DebugBundle creates a debug bundle and returns the location.
func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (resp *proto.DebugBundleResponse, err error) {
s.mutex.Lock()
@@ -86,211 +30,93 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (
return nil, fmt.Errorf("create zip file: %w", err)
}
defer func() {
if closeErr := bundlePath.Close(); closeErr != nil && err == nil {
err = fmt.Errorf("close zip file: %w", closeErr)
if err := bundlePath.Close(); err != nil {
log.Errorf("failed to close zip file: %v", err)
}
if err != nil {
if removeErr := os.Remove(bundlePath.Name()); removeErr != nil {
log.Errorf("Failed to remove zip file: %v", removeErr)
if err2 := os.Remove(bundlePath.Name()); err2 != nil {
log.Errorf("Failed to remove zip file: %v", err2)
}
}
}()
if err := s.createArchive(bundlePath, req); err != nil {
return nil, err
archive := zip.NewWriter(bundlePath)
defer func() {
if err := archive.Close(); err != nil {
log.Errorf("failed to close archive writer: %v", err)
}
}()
if status := req.GetStatus(); status != "" {
filename := "status.txt"
if req.GetAnonymize() {
filename = "status.anon.txt"
}
statusReader := strings.NewReader(status)
if err := addFileToZip(archive, statusReader, filename); err != nil {
return nil, fmt.Errorf("add status file to zip: %w", err)
}
}
logFile, err := os.Open(s.logFile)
if err != nil {
return nil, fmt.Errorf("open log file: %w", err)
}
defer func() {
if err := logFile.Close(); err != nil {
log.Errorf("failed to close original log file: %v", err)
}
}()
filename := "client.log.txt"
var logReader io.Reader
errChan := make(chan error, 1)
if req.GetAnonymize() {
filename = "client.anon.log.txt"
var writer io.WriteCloser
logReader, writer = io.Pipe()
go s.anonymize(logFile, writer, errChan)
} else {
logReader = logFile
}
if err := addFileToZip(archive, logReader, filename); err != nil {
return nil, fmt.Errorf("add log file to zip: %w", err)
}
select {
case err := <-errChan:
if err != nil {
return nil, err
}
default:
}
return &proto.DebugBundleResponse{Path: bundlePath.Name()}, nil
}
func (s *Server) createArchive(bundlePath *os.File, req *proto.DebugBundleRequest) error {
archive := zip.NewWriter(bundlePath)
if err := s.addReadme(req, archive); err != nil {
return fmt.Errorf("add readme: %w", err)
}
if err := s.addStatus(req, archive); err != nil {
return fmt.Errorf("add status: %w", err)
}
func (s *Server) anonymize(reader io.Reader, writer io.WriteCloser, errChan chan<- error) {
scanner := bufio.NewScanner(reader)
anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
status := s.statusRecorder.GetFullStatus()
seedFromStatus(anonymizer, &status)
if err := s.addConfig(req, anonymizer, archive); err != nil {
return fmt.Errorf("add config: %w", err)
}
if req.GetSystemInfo() {
if err := s.addRoutes(req, anonymizer, archive); err != nil {
return fmt.Errorf("add routes: %w", err)
}
if err := s.addInterfaces(req, anonymizer, archive); err != nil {
return fmt.Errorf("add interfaces: %w", err)
}
}
if err := s.addLogfile(req, anonymizer, archive); err != nil {
return fmt.Errorf("add log file: %w", err)
}
if err := archive.Close(); err != nil {
return fmt.Errorf("close archive writer: %w", err)
}
return nil
}
func (s *Server) addReadme(req *proto.DebugBundleRequest, archive *zip.Writer) error {
if req.GetAnonymize() {
readmeReader := strings.NewReader(readmeContent)
if err := addFileToZip(archive, readmeReader, "README.txt"); err != nil {
return fmt.Errorf("add README file to zip: %w", err)
}
}
return nil
}
func (s *Server) addStatus(req *proto.DebugBundleRequest, archive *zip.Writer) error {
if status := req.GetStatus(); status != "" {
statusReader := strings.NewReader(status)
if err := addFileToZip(archive, statusReader, "status.txt"); err != nil {
return fmt.Errorf("add status file to zip: %w", err)
}
}
return nil
}
func (s *Server) addConfig(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error {
var configContent strings.Builder
s.addCommonConfigFields(&configContent)
if req.GetAnonymize() {
if s.config.ManagementURL != nil {
configContent.WriteString(fmt.Sprintf("ManagementURL: %s\n", anonymizer.AnonymizeURI(s.config.ManagementURL.String())))
}
if s.config.AdminURL != nil {
configContent.WriteString(fmt.Sprintf("AdminURL: %s\n", anonymizer.AnonymizeURI(s.config.AdminURL.String())))
}
configContent.WriteString(fmt.Sprintf("NATExternalIPs: %v\n", anonymizeNATExternalIPs(s.config.NATExternalIPs, anonymizer)))
if s.config.CustomDNSAddress != "" {
configContent.WriteString(fmt.Sprintf("CustomDNSAddress: %s\n", anonymizer.AnonymizeString(s.config.CustomDNSAddress)))
}
} else {
if s.config.ManagementURL != nil {
configContent.WriteString(fmt.Sprintf("ManagementURL: %s\n", s.config.ManagementURL.String()))
}
if s.config.AdminURL != nil {
configContent.WriteString(fmt.Sprintf("AdminURL: %s\n", s.config.AdminURL.String()))
}
configContent.WriteString(fmt.Sprintf("NATExternalIPs: %v\n", s.config.NATExternalIPs))
if s.config.CustomDNSAddress != "" {
configContent.WriteString(fmt.Sprintf("CustomDNSAddress: %s\n", s.config.CustomDNSAddress))
}
}
// Add config content to zip file
configReader := strings.NewReader(configContent.String())
if err := addFileToZip(archive, configReader, "config.txt"); err != nil {
return fmt.Errorf("add config file to zip: %w", err)
}
return nil
}
func (s *Server) addCommonConfigFields(configContent *strings.Builder) {
configContent.WriteString("NetBird Client Configuration:\n\n")
// Add non-sensitive fields
configContent.WriteString(fmt.Sprintf("WgIface: %s\n", s.config.WgIface))
configContent.WriteString(fmt.Sprintf("WgPort: %d\n", s.config.WgPort))
if s.config.NetworkMonitor != nil {
configContent.WriteString(fmt.Sprintf("NetworkMonitor: %v\n", *s.config.NetworkMonitor))
}
configContent.WriteString(fmt.Sprintf("IFaceBlackList: %v\n", s.config.IFaceBlackList))
configContent.WriteString(fmt.Sprintf("DisableIPv6Discovery: %v\n", s.config.DisableIPv6Discovery))
configContent.WriteString(fmt.Sprintf("RosenpassEnabled: %v\n", s.config.RosenpassEnabled))
configContent.WriteString(fmt.Sprintf("RosenpassPermissive: %v\n", s.config.RosenpassPermissive))
if s.config.ServerSSHAllowed != nil {
configContent.WriteString(fmt.Sprintf("ServerSSHAllowed: %v\n", *s.config.ServerSSHAllowed))
}
configContent.WriteString(fmt.Sprintf("DisableAutoConnect: %v\n", s.config.DisableAutoConnect))
configContent.WriteString(fmt.Sprintf("DNSRouteInterval: %s\n", s.config.DNSRouteInterval))
}
func (s *Server) addRoutes(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error {
if routes, err := systemops.GetRoutesFromTable(); err != nil {
log.Errorf("Failed to get routes: %v", err)
} else {
// TODO: get routes including nexthop
routesContent := formatRoutes(routes, req.GetAnonymize(), anonymizer)
routesReader := strings.NewReader(routesContent)
if err := addFileToZip(archive, routesReader, "routes.txt"); err != nil {
return fmt.Errorf("add routes file to zip: %w", err)
}
}
return nil
}
func (s *Server) addInterfaces(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error {
interfaces, err := net.Interfaces()
if err != nil {
return fmt.Errorf("get interfaces: %w", err)
}
interfacesContent := formatInterfaces(interfaces, req.GetAnonymize(), anonymizer)
interfacesReader := strings.NewReader(interfacesContent)
if err := addFileToZip(archive, interfacesReader, "interfaces.txt"); err != nil {
return fmt.Errorf("add interfaces file to zip: %w", err)
}
return nil
}
func (s *Server) addLogfile(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) (err error) {
logFile, err := os.Open(s.logFile)
if err != nil {
return fmt.Errorf("open log file: %w", err)
}
defer func() {
if err := logFile.Close(); err != nil {
log.Errorf("Failed to close original log file: %v", err)
if err := writer.Close(); err != nil {
log.Errorf("Failed to close writer: %v", err)
}
}()
var logReader io.Reader
if req.GetAnonymize() {
var writer *io.PipeWriter
logReader, writer = io.Pipe()
go s.anonymize(logFile, writer, anonymizer)
} else {
logReader = logFile
}
if err := addFileToZip(archive, logReader, "client.log"); err != nil {
return fmt.Errorf("add log file to zip: %w", err)
}
return nil
}
func (s *Server) anonymize(reader io.Reader, writer *io.PipeWriter, anonymizer *anonymize.Anonymizer) {
defer func() {
// always nil
_ = writer.Close()
}()
scanner := bufio.NewScanner(reader)
for scanner.Scan() {
line := anonymizer.AnonymizeString(scanner.Text())
if _, err := writer.Write([]byte(line + "\n")); err != nil {
writer.CloseWithError(fmt.Errorf("anonymize write: %w", err))
errChan <- fmt.Errorf("write line to writer: %w", err)
return
}
}
if err := scanner.Err(); err != nil {
writer.CloseWithError(fmt.Errorf("anonymize scan: %w", err))
errChan <- fmt.Errorf("read line from scanner: %w", err)
return
}
}
@@ -315,22 +141,8 @@ func (s *Server) SetLogLevel(_ context.Context, req *proto.SetLogLevelRequest) (
func addFileToZip(archive *zip.Writer, reader io.Reader, filename string) error {
header := &zip.FileHeader{
Name: filename,
Method: zip.Deflate,
Modified: time.Now(),
CreatorVersion: 20, // Version 2.0
ReaderVersion: 20, // Version 2.0
Flags: 0x800, // UTF-8 filename
}
// If the reader is a file, we can get more accurate information
if f, ok := reader.(*os.File); ok {
if stat, err := f.Stat(); err != nil {
log.Tracef("Failed to get file stat for %s: %v", filename, err)
} else {
header.Modified = stat.ModTime()
}
Name: filename,
Method: zip.Deflate,
}
writer, err := archive.CreateHeader(header)
@@ -353,13 +165,6 @@ func seedFromStatus(a *anonymize.Anonymizer, status *peer.FullStatus) {
for _, peer := range status.Peers {
a.AnonymizeDomain(peer.FQDN)
for route := range peer.GetRoutes() {
a.AnonymizeRoute(route)
}
}
for route := range status.LocalPeerState.Routes {
a.AnonymizeRoute(route)
}
for _, nsGroup := range status.NSGroupStates {
@@ -374,113 +179,3 @@ func seedFromStatus(a *anonymize.Anonymizer, status *peer.FullStatus) {
}
}
}
func formatRoutes(routes []netip.Prefix, anonymize bool, anonymizer *anonymize.Anonymizer) string {
var ipv4Routes, ipv6Routes []netip.Prefix
// Separate IPv4 and IPv6 routes
for _, route := range routes {
if route.Addr().Is4() {
ipv4Routes = append(ipv4Routes, route)
} else {
ipv6Routes = append(ipv6Routes, route)
}
}
// Sort IPv4 and IPv6 routes separately
sort.Slice(ipv4Routes, func(i, j int) bool {
return ipv4Routes[i].Bits() > ipv4Routes[j].Bits()
})
sort.Slice(ipv6Routes, func(i, j int) bool {
return ipv6Routes[i].Bits() > ipv6Routes[j].Bits()
})
var builder strings.Builder
// Format IPv4 routes
builder.WriteString("IPv4 Routes:\n")
for _, route := range ipv4Routes {
formatRoute(&builder, route, anonymize, anonymizer)
}
// Format IPv6 routes
builder.WriteString("\nIPv6 Routes:\n")
for _, route := range ipv6Routes {
formatRoute(&builder, route, anonymize, anonymizer)
}
return builder.String()
}
func formatRoute(builder *strings.Builder, route netip.Prefix, anonymize bool, anonymizer *anonymize.Anonymizer) {
if anonymize {
anonymizedIP := anonymizer.AnonymizeIP(route.Addr())
builder.WriteString(fmt.Sprintf("%s/%d\n", anonymizedIP, route.Bits()))
} else {
builder.WriteString(fmt.Sprintf("%s\n", route))
}
}
func formatInterfaces(interfaces []net.Interface, anonymize bool, anonymizer *anonymize.Anonymizer) string {
sort.Slice(interfaces, func(i, j int) bool {
return interfaces[i].Name < interfaces[j].Name
})
var builder strings.Builder
builder.WriteString("Network Interfaces:\n")
for _, iface := range interfaces {
builder.WriteString(fmt.Sprintf("\nInterface: %s\n", iface.Name))
builder.WriteString(fmt.Sprintf(" Index: %d\n", iface.Index))
builder.WriteString(fmt.Sprintf(" MTU: %d\n", iface.MTU))
builder.WriteString(fmt.Sprintf(" Flags: %v\n", iface.Flags))
addrs, err := iface.Addrs()
if err != nil {
builder.WriteString(fmt.Sprintf(" Addresses: Error retrieving addresses: %v\n", err))
} else {
builder.WriteString(" Addresses:\n")
for _, addr := range addrs {
prefix, err := netip.ParsePrefix(addr.String())
if err != nil {
builder.WriteString(fmt.Sprintf(" Error parsing address: %v\n", err))
continue
}
ip := prefix.Addr()
if anonymize {
ip = anonymizer.AnonymizeIP(ip)
}
builder.WriteString(fmt.Sprintf(" %s/%d\n", ip, prefix.Bits()))
}
}
}
return builder.String()
}
func anonymizeNATExternalIPs(ips []string, anonymizer *anonymize.Anonymizer) []string {
anonymizedIPs := make([]string, len(ips))
for i, ip := range ips {
parts := strings.SplitN(ip, "/", 2)
ip1, err := netip.ParseAddr(parts[0])
if err != nil {
anonymizedIPs[i] = ip
continue
}
ip1anon := anonymizer.AnonymizeIP(ip1)
if len(parts) == 2 {
ip2, err := netip.ParseAddr(parts[1])
if err != nil {
anonymizedIPs[i] = fmt.Sprintf("%s/%s", ip1anon, parts[1])
} else {
ip2anon := anonymizer.AnonymizeIP(ip2)
anonymizedIPs[i] = fmt.Sprintf("%s/%s", ip1anon, ip2anon)
}
} else {
anonymizedIPs[i] = ip1anon.String()
}
}
return anonymizedIPs
}

View File

@@ -582,7 +582,7 @@ func (s *Server) Up(callerCtx context.Context, _ *proto.UpRequest) (*proto.UpRes
}
// Down engine work in the daemon.
func (s *Server) Down(ctx context.Context, _ *proto.DownRequest) (*proto.DownResponse, error) {
func (s *Server) Down(_ context.Context, _ *proto.DownRequest) (*proto.DownResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
@@ -593,25 +593,7 @@ func (s *Server) Down(ctx context.Context, _ *proto.DownRequest) (*proto.DownRes
state := internal.CtxGetState(s.rootCtx)
state.Set(internal.StatusIdle)
maxWaitTime := 5 * time.Second
timeout := time.After(maxWaitTime)
engine := s.connectClient.Engine()
for {
if !engine.IsWGIfaceUp() {
return &proto.DownResponse{}, nil
}
select {
case <-ctx.Done():
return &proto.DownResponse{}, nil
case <-timeout:
return nil, fmt.Errorf("failed to shut down properly")
default:
time.Sleep(100 * time.Millisecond)
}
}
return &proto.DownResponse{}, nil
}
// Status returns the daemon status

View File

@@ -19,7 +19,6 @@ import (
mgmtProto "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/signal/proto"
signalServer "github.com/netbirdio/netbird/signal/server"
)
@@ -121,11 +120,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
return nil, "", err
}
ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics)
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
if err != nil {
return nil, "", err
}

View File

@@ -1,12 +1,11 @@
package system
import (
"testing"
log "github.com/sirupsen/logrus"
"testing"
)
func Test_sysInfoMac(t *testing.T) {
func Test_sysInfo(t *testing.T) {
t.Skip("skipping darwin test")
serialNum, prodName, manufacturer := sysInfo()
if serialNum == "" {

View File

@@ -8,7 +8,6 @@ import (
"context"
"os"
"os/exec"
"regexp"
"runtime"
"strings"
"time"
@@ -21,26 +20,6 @@ import (
"github.com/netbirdio/netbird/version"
)
type SysInfoGetter interface {
GetSysInfo() SysInfo
}
type SysInfoWrapper struct {
si sysinfo.SysInfo
}
func (s SysInfoWrapper) GetSysInfo() SysInfo {
s.si.GetSysInfo()
return SysInfo{
ChassisSerial: s.si.Chassis.Serial,
ProductSerial: s.si.Product.Serial,
BoardSerial: s.si.Board.Serial,
ProductName: s.si.Product.Name,
BoardName: s.si.Board.Name,
ProductVendor: s.si.Product.Vendor,
}
}
// GetInfo retrieves and parses the system information
func GetInfo(ctx context.Context) *Info {
info := _getInfo()
@@ -65,8 +44,7 @@ func GetInfo(ctx context.Context) *Info {
log.Warnf("failed to discover network addresses: %s", err)
}
si := SysInfoWrapper{}
serialNum, prodName, manufacturer := sysInfo(si.GetSysInfo())
serialNum, prodName, manufacturer := sysInfo()
env := Environment{
Cloud: detect_cloud.Detect(ctx),
@@ -108,36 +86,12 @@ func _getInfo() string {
return out.String()
}
func sysInfo(si SysInfo) (string, string, string) {
isascii := regexp.MustCompile("^[[:ascii:]]+$")
serials := []string{si.ChassisSerial, si.ProductSerial}
serial := ""
for _, s := range serials {
if isascii.MatchString(s) {
serial = s
if s != "Default string" {
break
}
}
func sysInfo() (serialNumber string, productName string, manufacturer string) {
var si sysinfo.SysInfo
si.GetSysInfo()
serial := si.Chassis.Serial
if (serial == "Default string" || serial == "") && si.Product.Serial != "" {
serial = si.Product.Serial
}
if serial == "" && isascii.MatchString(si.BoardSerial) {
serial = si.BoardSerial
}
var name string
for _, n := range []string{si.ProductName, si.BoardName} {
if isascii.MatchString(n) {
name = n
break
}
}
var manufacturer string
if isascii.MatchString(si.ProductVendor) {
manufacturer = si.ProductVendor
}
return serial, name, manufacturer
return serial, si.Product.Name, si.Product.Vendor
}

View File

@@ -1,12 +0,0 @@
package system
// SysInfo used to moc out the sysinfo getter
type SysInfo struct {
ChassisSerial string
ProductSerial string
BoardSerial string
ProductName string
BoardName string
ProductVendor string
}

View File

@@ -1,198 +0,0 @@
package system
import "testing"
func Test_sysInfo(t *testing.T) {
tests := []struct {
name string
sysInfo SysInfo
wantSerialNum string
wantProdName string
wantManufacturer string
}{
{
name: "Test Case 1",
sysInfo: SysInfo{
ChassisSerial: "Default string",
ProductSerial: "Default string",
BoardSerial: "M80-G8013200245",
ProductName: "B650M-HDV/M.2",
BoardName: "B650M-HDV/M.2",
ProductVendor: "ASRock",
},
wantSerialNum: "Default string",
wantProdName: "B650M-HDV/M.2",
wantManufacturer: "ASRock",
},
{
name: "Empty Chassis Serial",
sysInfo: SysInfo{
ChassisSerial: "",
ProductSerial: "Default string",
BoardSerial: "M80-G8013200245",
ProductName: "B650M-HDV/M.2",
BoardName: "B650M-HDV/M.2",
ProductVendor: "ASRock",
},
wantSerialNum: "Default string",
wantProdName: "B650M-HDV/M.2",
wantManufacturer: "ASRock",
},
{
name: "Empty Chassis Serial",
sysInfo: SysInfo{
ChassisSerial: "",
ProductSerial: "Default string",
BoardSerial: "M80-G8013200245",
ProductName: "B650M-HDV/M.2",
BoardName: "B650M-HDV/M.2",
ProductVendor: "ASRock",
},
wantSerialNum: "Default string",
wantProdName: "B650M-HDV/M.2",
wantManufacturer: "ASRock",
},
{
name: "Fallback to Product Serial",
sysInfo: SysInfo{
ChassisSerial: "Default string",
ProductSerial: "Product serial",
BoardSerial: "M80-G8013200245",
ProductName: "B650M-HDV/M.2",
BoardName: "B650M-HDV/M.2",
ProductVendor: "ASRock",
},
wantSerialNum: "Product serial",
wantProdName: "B650M-HDV/M.2",
wantManufacturer: "ASRock",
},
{
name: "Fallback to Product Serial with default string",
sysInfo: SysInfo{
ChassisSerial: "Default string",
ProductSerial: "Default string",
BoardSerial: "M80-G8013200245",
ProductName: "B650M-HDV/M.2",
BoardName: "B650M-HDV/M.2",
ProductVendor: "ASRock",
},
wantSerialNum: "Default string",
wantProdName: "B650M-HDV/M.2",
wantManufacturer: "ASRock",
},
{
name: "Non UTF-8 in Chassis Serial",
sysInfo: SysInfo{
ChassisSerial: "\x80",
ProductSerial: "Product serial",
BoardSerial: "M80-G8013200245",
ProductName: "B650M-HDV/M.2",
BoardName: "B650M-HDV/M.2",
ProductVendor: "ASRock",
},
wantSerialNum: "Product serial",
wantProdName: "B650M-HDV/M.2",
wantManufacturer: "ASRock",
},
{
name: "Non UTF-8 in Chassis Serial and Product Serial",
sysInfo: SysInfo{
ChassisSerial: "\x80",
ProductSerial: "\x80",
BoardSerial: "M80-G8013200245",
ProductName: "B650M-HDV/M.2",
BoardName: "B650M-HDV/M.2",
ProductVendor: "ASRock",
},
wantSerialNum: "M80-G8013200245",
wantProdName: "B650M-HDV/M.2",
wantManufacturer: "ASRock",
},
{
name: "Non UTF-8 in Chassis Serial and Product Serial and BoardSerial",
sysInfo: SysInfo{
ChassisSerial: "\x80",
ProductSerial: "\x80",
BoardSerial: "\x80",
ProductName: "B650M-HDV/M.2",
BoardName: "B650M-HDV/M.2",
ProductVendor: "ASRock",
},
wantSerialNum: "",
wantProdName: "B650M-HDV/M.2",
wantManufacturer: "ASRock",
},
{
name: "Empty Product Name",
sysInfo: SysInfo{
ChassisSerial: "Default string",
ProductSerial: "Default string",
BoardSerial: "M80-G8013200245",
ProductName: "",
BoardName: "boardname",
ProductVendor: "ASRock",
},
wantSerialNum: "Default string",
wantProdName: "boardname",
wantManufacturer: "ASRock",
},
{
name: "Invalid Product Name",
sysInfo: SysInfo{
ChassisSerial: "Default string",
ProductSerial: "Default string",
BoardSerial: "M80-G8013200245",
ProductName: "\x80",
BoardName: "boardname",
ProductVendor: "ASRock",
},
wantSerialNum: "Default string",
wantProdName: "boardname",
wantManufacturer: "ASRock",
},
{
name: "Invalid BoardName Name",
sysInfo: SysInfo{
ChassisSerial: "Default string",
ProductSerial: "Default string",
BoardSerial: "M80-G8013200245",
ProductName: "\x80",
BoardName: "\x80",
ProductVendor: "ASRock",
},
wantSerialNum: "Default string",
wantProdName: "",
wantManufacturer: "ASRock",
},
{
name: "Invalid chars",
sysInfo: SysInfo{
ChassisSerial: "\x80",
ProductSerial: "\x80",
BoardSerial: "\x80",
ProductName: "\x80",
BoardName: "\x80",
ProductVendor: "\x80",
},
wantSerialNum: "",
wantProdName: "",
wantManufacturer: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotSerialNum, gotProdName, gotManufacturer := sysInfo(tt.sysInfo)
if gotSerialNum != tt.wantSerialNum {
t.Errorf("sysInfo() gotSerialNum = %v, want %v", gotSerialNum, tt.wantSerialNum)
}
if gotProdName != tt.wantProdName {
t.Errorf("sysInfo() gotProdName = %v, want %v", gotProdName, tt.wantProdName)
}
if gotManufacturer != tt.wantManufacturer {
t.Errorf("sysInfo() gotManufacturer = %v, want %v", gotManufacturer, tt.wantManufacturer)
}
})
}
}

View File

@@ -15,6 +15,7 @@ import (
"strconv"
"strings"
"sync"
"syscall"
"time"
"unicode"
@@ -22,8 +23,8 @@ import (
"fyne.io/fyne/v2/app"
"fyne.io/fyne/v2/dialog"
"fyne.io/fyne/v2/widget"
"fyne.io/systray"
"github.com/cenkalti/backoff/v4"
"github.com/getlantern/systray"
log "github.com/sirupsen/logrus"
"github.com/skratchdot/open-golang/open"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
@@ -33,7 +34,6 @@ import (
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/version"
)
@@ -62,25 +62,8 @@ func main() {
var errorMSG string
flag.StringVar(&errorMSG, "error-msg", "", "displays a error message window")
tmpDir := "/tmp"
if runtime.GOOS == "windows" {
tmpDir = os.TempDir()
}
var saveLogsInFile bool
flag.BoolVar(&saveLogsInFile, "use-log-file", false, fmt.Sprintf("save logs in a file: %s/netbird-ui-PID.log", tmpDir))
flag.Parse()
if saveLogsInFile {
logFile := path.Join(tmpDir, fmt.Sprintf("netbird-ui-%d.log", os.Getpid()))
err := util.InitLog("trace", logFile)
if err != nil {
log.Errorf("error while initializing log: %v", err)
return
}
}
a := app.NewWithID("NetBird")
a.SetIcon(fyne.NewStaticResource("netbird", iconDisconnectedPNG))
@@ -93,12 +76,8 @@ func main() {
if showSettings || showRoutes {
a.Run()
} else {
running, err := isAnotherProcessRunning()
if err != nil {
log.Errorf("error while checking process: %v", err)
}
if running {
log.Warn("another process is running")
if err := checkPIDFile(); err != nil {
log.Errorf("check PID file: %v", err)
return
}
client.setDefaultFonts()
@@ -882,3 +861,104 @@ func openURL(url string) error {
}
return err
}
// checkPIDFile exists and return error, or write new.
func checkPIDFile() error {
pidFile := path.Join(os.TempDir(), "wiretrustee-ui.pid")
if piddata, err := os.ReadFile(pidFile); err == nil {
if pid, err := strconv.Atoi(string(piddata)); err == nil {
if process, err := os.FindProcess(pid); err == nil {
if err := process.Signal(syscall.Signal(0)); err == nil {
return fmt.Errorf("process already exists: %d", pid)
}
}
}
}
return os.WriteFile(pidFile, []byte(fmt.Sprintf("%d", os.Getpid())), 0o664) //nolint:gosec
}
func (s *serviceClient) setDefaultFonts() {
var (
defaultFontPath string
)
//TODO: Linux Multiple Language Support
switch runtime.GOOS {
case "darwin":
defaultFontPath = "/Library/Fonts/Arial Unicode.ttf"
case "windows":
fontPath := s.getWindowsFontFilePath()
defaultFontPath = fontPath
}
_, err := os.Stat(defaultFontPath)
if err == nil {
os.Setenv("FYNE_FONT", defaultFontPath)
}
}
func (s *serviceClient) getWindowsFontFilePath() (fontPath string) {
/*
https://learn.microsoft.com/en-us/windows/apps/design/globalizing/loc-international-fonts
https://learn.microsoft.com/en-us/typography/fonts/windows_11_font_list
*/
var (
fontFolder string = "C:/Windows/Fonts"
fontMapping = map[string]string{
"default": "Segoeui.ttf",
"zh-CN": "Msyh.ttc",
"am-ET": "Ebrima.ttf",
"nirmala": "Nirmala.ttf",
"chr-CHER-US": "Gadugi.ttf",
"zh-HK": "Msjh.ttc",
"zh-TW": "Msjh.ttc",
"ja-JP": "Yugothm.ttc",
"km-KH": "Leelawui.ttf",
"ko-KR": "Malgun.ttf",
"th-TH": "Leelawui.ttf",
"ti-ET": "Ebrima.ttf",
}
nirMalaLang = []string{
"as-IN",
"bn-BD",
"bn-IN",
"gu-IN",
"hi-IN",
"kn-IN",
"kok-IN",
"ml-IN",
"mr-IN",
"ne-NP",
"or-IN",
"pa-IN",
"si-LK",
"ta-IN",
"te-IN",
}
)
cmd := exec.Command("powershell", "-Command", "(Get-Culture).Name")
output, err := cmd.Output()
if err != nil {
log.Errorf("Failed to get Windows default language setting: %v", err)
fontPath = path.Join(fontFolder, fontMapping["default"])
return
}
defaultLanguage := strings.TrimSpace(string(output))
for _, lang := range nirMalaLang {
if defaultLanguage == lang {
fontPath = path.Join(fontFolder, fontMapping["nirmala"])
return
}
}
if font, ok := fontMapping[defaultLanguage]; ok {
fontPath = path.Join(fontFolder, font)
} else {
fontPath = path.Join(fontFolder, fontMapping["default"])
}
return
}

View File

@@ -1,26 +0,0 @@
//go:build darwin
package main
import (
"os"
"runtime"
log "github.com/sirupsen/logrus"
)
const defaultFontPath = "/Library/Fonts/Arial Unicode.ttf"
func (s *serviceClient) setDefaultFonts() {
// TODO: add other bsd paths
if runtime.GOOS != "darwin" {
return
}
if _, err := os.Stat(defaultFontPath); err != nil {
log.Errorf("Failed to find default font file: %v", err)
return
}
os.Setenv("FYNE_FONT", defaultFontPath)
}

View File

@@ -1,7 +0,0 @@
//go:build !386
package main
func (s *serviceClient) setDefaultFonts() {
//TODO: Linux Multiple Language Support
}

View File

@@ -1,91 +0,0 @@
package main
import (
"os"
"path"
"unsafe"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows"
)
func (s *serviceClient) setDefaultFonts() {
defaultFontPath := s.getWindowsFontFilePath()
if _, err := os.Stat(defaultFontPath); err != nil {
log.Errorf("Failed to find default font file: %v", err)
return
}
os.Setenv("FYNE_FONT", defaultFontPath)
}
func (s *serviceClient) getWindowsFontFilePath() string {
var (
fontFolder = "C:/Windows/Fonts"
fontMapping = map[string]string{
"default": "Segoeui.ttf",
"zh-CN": "Msyh.ttc",
"am-ET": "Ebrima.ttf",
"nirmala": "Nirmala.ttf",
"chr-CHER-US": "Gadugi.ttf",
"zh-HK": "Msjh.ttc",
"zh-TW": "Msjh.ttc",
"ja-JP": "Yugothm.ttc",
"km-KH": "Leelawui.ttf",
"ko-KR": "Malgun.ttf",
"th-TH": "Leelawui.ttf",
"ti-ET": "Ebrima.ttf",
}
nirMalaLang = []string{
"as-IN",
"bn-BD",
"bn-IN",
"gu-IN",
"hi-IN",
"kn-IN",
"kok-IN",
"ml-IN",
"mr-IN",
"ne-NP",
"or-IN",
"pa-IN",
"si-LK",
"ta-IN",
"te-IN",
}
)
// getUserDefaultLocaleName.Call() panics if the func is not found
defer func() {
if r := recover(); r != nil {
log.Errorf("Recovered from panic: %v", r)
}
}()
kernel32 := windows.NewLazySystemDLL("kernel32.dll")
getUserDefaultLocaleName := kernel32.NewProc("GetUserDefaultLocaleName")
buf := make([]uint16, 85) // LOCALE_NAME_MAX_LENGTH is usually 85
r, _, err := getUserDefaultLocaleName.Call(uintptr(unsafe.Pointer(&buf[0])), uintptr(len(buf)))
// returns 0 on failure, err is always non-nil
// https://learn.microsoft.com/en-us/windows/win32/api/winnls/nf-winnls-getuserdefaultlocalename
if r == 0 {
log.Errorf("GetUserDefaultLocaleName call failed: %v", err)
return path.Join(fontFolder, fontMapping["default"])
}
defaultLanguage := windows.UTF16ToString(buf)
for _, lang := range nirMalaLang {
if defaultLanguage == lang {
return path.Join(fontFolder, fontMapping["nirmala"])
}
}
if font, ok := fontMapping[defaultLanguage]; ok {
return path.Join(fontFolder, font)
}
return path.Join(fontFolder, fontMapping["default"])
}

Binary file not shown.

Before

Width:  |  Height:  |  Size: 8.9 KiB

After

Width:  |  Height:  |  Size: 12 KiB

View File

@@ -1,37 +0,0 @@
package main
import (
"os"
"path/filepath"
"strings"
"github.com/shirou/gopsutil/v3/process"
)
func isAnotherProcessRunning() (bool, error) {
processes, err := process.Processes()
if err != nil {
return false, err
}
pid := os.Getpid()
processName := strings.ToLower(filepath.Base(os.Args[0]))
for _, p := range processes {
if int(p.Pid) == pid {
continue
}
runningProcessPath, err := p.Exe()
// most errors are related to short-lived processes
if err != nil {
continue
}
if strings.Contains(strings.ToLower(runningProcessPath), processName) && isProcessOwnedByCurrentUser(p) {
return true, nil
}
}
return false, nil
}

View File

@@ -1,26 +0,0 @@
//go:build !windows
package main
import (
"os"
"github.com/shirou/gopsutil/v3/process"
log "github.com/sirupsen/logrus"
)
func isProcessOwnedByCurrentUser(p *process.Process) bool {
currentUserID := os.Getuid()
uids, err := p.Uids()
if err != nil {
log.Errorf("get process uids: %v", err)
return false
}
for _, id := range uids {
log.Debugf("checking process uid: %d", id)
if int(id) == currentUserID {
return true
}
}
return false
}

View File

@@ -1,24 +0,0 @@
package main
import (
"os/user"
"github.com/shirou/gopsutil/v3/process"
log "github.com/sirupsen/logrus"
)
func isProcessOwnedByCurrentUser(p *process.Process) bool {
processUsername, err := p.Username()
if err != nil {
log.Errorf("get process username error: %v", err)
return false
}
currUser, err := user.Current()
if err != nil {
log.Errorf("get current user error: %v", err)
return false
}
return processUsername == currUser.Username
}

View File

@@ -10,7 +10,7 @@ import (
func EncryptMessage(remotePubKey wgtypes.Key, ourPrivateKey wgtypes.Key, message pb.Message) ([]byte, error) {
byteResp, err := pb.Marshal(message)
if err != nil {
log.Errorf("failed marshalling message %v, %+v", err, message.String())
log.Errorf("failed marshalling message %v", err)
return nil, err
}

View File

@@ -14,29 +14,14 @@ type TextFormatter struct {
levelDesc []string
}
// SyslogFormatter formats logs into text
type SyslogFormatter struct {
levelDesc []string
}
var validLevelDesc = []string{"PANC", "FATL", "ERRO", "WARN", "INFO", "DEBG", "TRAC"}
// NewTextFormatter create new MyTextFormatter instance
func NewTextFormatter() *TextFormatter {
return &TextFormatter{
levelDesc: validLevelDesc,
levelDesc: []string{"PANC", "FATL", "ERRO", "WARN", "INFO", "DEBG", "TRAC"},
timestampFormat: time.RFC3339, // or RFC3339
}
}
// NewSyslogFormatter create new MySyslogFormatter instance
func NewSyslogFormatter() *SyslogFormatter {
return &SyslogFormatter{
levelDesc: validLevelDesc,
}
}
// Format renders a single log entry
func (f *TextFormatter) Format(entry *logrus.Entry) ([]byte, error) {
var fields string
@@ -64,20 +49,3 @@ func (f *TextFormatter) parseLevel(level logrus.Level) string {
return f.levelDesc[level]
}
// Format renders a single log entry
func (f *SyslogFormatter) Format(entry *logrus.Entry) ([]byte, error) {
var fields string
keys := make([]string, 0, len(entry.Data))
for k, v := range entry.Data {
if k == "source" {
continue
}
keys = append(keys, fmt.Sprintf("%s: %v", k, v))
}
if len(keys) > 0 {
fields = fmt.Sprintf("[%s] ", strings.Join(keys, ", "))
}
return []byte(fmt.Sprintf("%s%s\n", fields, entry.Message)), nil
}

View File

@@ -8,7 +8,7 @@ import (
"github.com/stretchr/testify/assert"
)
func TestLogTextFormat(t *testing.T) {
func TestLogMessageFormat(t *testing.T) {
someEntry := &logrus.Entry{
Data: logrus.Fields{"att1": 1, "att2": 2, "source": "some/fancy/path.go:46"},
@@ -24,20 +24,3 @@ func TestLogTextFormat(t *testing.T) {
expectedString := "^2021-02-21T01:10:30Z WARN \\[(att1: 1, att2: 2|att2: 2, att1: 1)\\] some/fancy/path.go:46: Some Message\\s+$"
assert.Regexp(t, expectedString, parsedString)
}
func TestLogSyslogFormat(t *testing.T) {
someEntry := &logrus.Entry{
Data: logrus.Fields{"att1": 1, "att2": 2, "source": "some/fancy/path.go:46"},
Time: time.Date(2021, time.Month(2), 21, 1, 10, 30, 0, time.UTC),
Level: 3,
Message: "Some Message",
}
formatter := NewSyslogFormatter()
result, _ := formatter.Format(someEntry)
parsedString := string(result)
expectedString := "^\\[(att1: 1, att2: 2|att2: 2, att1: 1)\\] Some Message\\s+$"
assert.Regexp(t, expectedString, parsedString)
}

View File

@@ -10,12 +10,6 @@ func SetTextFormatter(logger *logrus.Logger) {
logger.ReportCaller = true
logger.AddHook(NewContextHook())
}
// SetSyslogFormatter set the text formatter for given logger.
func SetSyslogFormatter(logger *logrus.Logger) {
logger.Formatter = NewSyslogFormatter()
logger.ReportCaller = true
logger.AddHook(NewContextHook())
}
// SetJSONFormatter set the JSON formatter for given logger.
func SetJSONFormatter(logger *logrus.Logger) {

57
go.mod
View File

@@ -19,26 +19,26 @@ require (
github.com/spf13/cobra v1.7.0
github.com/spf13/pflag v1.0.5
github.com/vishvananda/netlink v1.2.1-beta.2
golang.org/x/crypto v0.24.0
golang.org/x/sys v0.21.0
golang.org/x/crypto v0.23.0
golang.org/x/sys v0.20.0
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
golang.zx2c4.com/wireguard/windows v0.5.3
google.golang.org/grpc v1.64.1
google.golang.org/grpc v1.64.0
google.golang.org/protobuf v1.34.1
gopkg.in/natefinch/lumberjack.v2 v2.0.0
)
require (
fyne.io/fyne/v2 v2.5.0
fyne.io/systray v1.11.0
fyne.io/fyne/v2 v2.1.4
github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible
github.com/c-robinson/iplib v1.0.3
github.com/cilium/ebpf v0.15.0
github.com/coreos/go-iptables v0.7.0
github.com/creack/pty v1.1.18
github.com/eko/gocache/v3 v3.1.1
github.com/fsnotify/fsnotify v1.7.0
github.com/fsnotify/fsnotify v1.6.0
github.com/getlantern/systray v1.2.1
github.com/gliderlabs/ssh v0.3.4
github.com/godbus/dbus/v5 v5.1.0
github.com/golang/mock v1.6.0
@@ -82,11 +82,11 @@ require (
go.opentelemetry.io/otel/sdk/metric v1.26.0
goauthentik.io/api/v3 v3.2023051.3
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842
golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a
golang.org/x/net v0.26.0
golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028
golang.org/x/net v0.25.0
golang.org/x/oauth2 v0.19.0
golang.org/x/sync v0.7.0
golang.org/x/term v0.21.0
golang.org/x/term v0.20.0
google.golang.org/api v0.177.0
gopkg.in/yaml.v3 v3.0.1
gorm.io/driver/postgres v1.5.7
@@ -100,7 +100,7 @@ require (
cloud.google.com/go/compute/metadata v0.3.0 // indirect
dario.cat/mergo v1.0.0 // indirect
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect
github.com/BurntSushi/toml v1.4.0 // indirect
github.com/BurntSushi/toml v1.3.2 // indirect
github.com/Microsoft/go-winio v0.6.2 // indirect
github.com/Microsoft/hcsshim v0.12.3 // indirect
github.com/XiaoMi/pegasus-go-client v0.0.0-20210427083443-f3b6b08bc4c2 // indirect
@@ -115,29 +115,31 @@ require (
github.com/dgraph-io/ristretto v0.1.1 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/distribution/reference v0.6.0 // indirect
github.com/docker/docker v26.1.4+incompatible // indirect
github.com/docker/docker v26.1.3+incompatible // indirect
github.com/docker/go-connections v0.5.0 // indirect
github.com/docker/go-units v0.5.0 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/fredbi/uri v1.1.0 // indirect
github.com/fyne-io/gl-js v0.0.0-20220119005834-d2da28d9ccfe // indirect
github.com/fyne-io/glfw-js v0.0.0-20240101223322-6e1efdc71b7a // indirect
github.com/fyne-io/image v0.0.0-20220602074514-4956b0afb3d2 // indirect
github.com/go-gl/gl v0.0.0-20211210172815-726fda9656d6 // indirect
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20240506104042-037f3cc74f2a // indirect
github.com/fredbi/uri v0.0.0-20181227131451-3dcfdacbaaf3 // indirect
github.com/getlantern/context v0.0.0-20190109183933-c447772a6520 // indirect
github.com/getlantern/errors v0.0.0-20190325191628-abdb3e3e36f7 // indirect
github.com/getlantern/golog v0.0.0-20190830074920-4ef2e798c2d7 // indirect
github.com/getlantern/hex v0.0.0-20190417191902-c6586a6fe0b7 // indirect
github.com/getlantern/hidden v0.0.0-20190325191715-f02dbb02be55 // indirect
github.com/getlantern/ops v0.0.0-20190325191751-d70cb0d6f85f // indirect
github.com/go-gl/gl v0.0.0-20210813123233-e4099ee2221f // indirect
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20211024062804-40e447a793be // indirect
github.com/go-logr/logr v1.4.1 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-ole/go-ole v1.3.0 // indirect
github.com/go-redis/redis/v8 v8.11.5 // indirect
github.com/go-text/render v0.1.1-0.20240418202334-dd62631dae9b // indirect
github.com/go-text/typesetting v0.1.0 // indirect
github.com/go-stack/stack v1.8.0 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/goki/freetype v0.0.0-20181231101311-fa8a33aabaff // indirect
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/google/btree v1.0.1 // indirect
github.com/google/s2a-go v0.1.7 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect
github.com/googleapis/gax-go/v2 v2.12.3 // indirect
github.com/gopherjs/gopherjs v1.17.2 // indirect
github.com/hashicorp/errwrap v1.1.0 // indirect
github.com/hashicorp/go-uuid v1.0.2 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
@@ -145,11 +147,9 @@ require (
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
github.com/jackc/pgx/v5 v5.5.5 // indirect
github.com/jackc/puddle/v2 v2.2.1 // indirect
github.com/jeandeaual/go-locale v0.0.0-20240223122105-ce5225dcaa49 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/josharian/native v1.1.0 // indirect
github.com/jsummers/gobmp v0.0.0-20151104160322-e2ba15ffa76e // indirect
github.com/kelseyhightower/envconfig v1.4.0 // indirect
github.com/klauspost/compress v1.17.8 // indirect
github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae // indirect
@@ -161,10 +161,10 @@ require (
github.com/moby/sys/user v0.1.0 // indirect
github.com/moby/term v0.5.0 // indirect
github.com/morikuni/aec v1.0.0 // indirect
github.com/nicksnyder/go-i18n/v2 v2.4.0 // indirect
github.com/nxadm/tail v1.4.8 // indirect
github.com/opencontainers/go-digest v1.0.0 // indirect
github.com/opencontainers/image-spec v1.1.0 // indirect
github.com/oxtoacart/bpool v0.0.0-20190530202638-03653db5a59c // indirect
github.com/pegasus-kv/thrift v0.13.0 // indirect
github.com/pion/dtls/v2 v2.2.10 // indirect
github.com/pion/mdns v0.0.12 // indirect
@@ -176,24 +176,21 @@ require (
github.com/prometheus/client_model v0.6.1 // indirect
github.com/prometheus/common v0.53.0 // indirect
github.com/prometheus/procfs v0.15.0 // indirect
github.com/rymdport/portal v0.2.6 // indirect
github.com/shoenig/go-m1cpu v0.1.6 // indirect
github.com/spf13/cast v1.5.0 // indirect
github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c // indirect
github.com/srwiley/rasterx v0.0.0-20220730225603-2ab79fcdd4ef // indirect
github.com/srwiley/oksvg v0.0.0-20200311192757-870daf9aa564 // indirect
github.com/srwiley/rasterx v0.0.0-20200120212402-85cb7272f5e9 // indirect
github.com/tklauser/go-sysconf v0.3.14 // indirect
github.com/tklauser/numcpus v0.8.0 // indirect
github.com/vishvananda/netns v0.0.4 // indirect
github.com/yuin/goldmark v1.7.1 // indirect
github.com/yuin/goldmark v1.4.13 // indirect
go.opencensus.io v0.24.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 // indirect
go.opentelemetry.io/otel/sdk v1.26.0 // indirect
go.opentelemetry.io/otel/trace v1.26.0 // indirect
golang.org/x/image v0.18.0 // indirect
golang.org/x/mod v0.17.0 // indirect
golang.org/x/text v0.16.0 // indirect
golang.org/x/time v0.5.0 // indirect
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20240509183442-62759503f434 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240515191416-fc5f0ca64291 // indirect
@@ -215,5 +212,3 @@ replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-2023080111
replace github.com/pion/ice/v3 => github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e
replace github.com/libp2p/go-netroute => github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944
replace fyne.io/fyne/v2 => github.com/Jacalz/fyne/v2 v2.0.0-20240809153104-a1e60718db70

558
go.sum

File diff suppressed because it is too large Load Diff

View File

@@ -64,7 +64,7 @@ func (t *wgTunDevice) Create(routes []string, dns string, searchDomains []string
t.wrapper = newDeviceWrapper(tunDevice)
log.Debugf("attaching to interface %v", name)
t.device = device.NewDevice(t.wrapper, t.iceBind, device.NewLogger(wgLogLevel(), "[wiretrustee] "))
t.device = device.NewDevice(t.wrapper, t.iceBind, device.NewLogger(device.LogLevelSilent, "[wiretrustee] "))
// without this property mobile devices can discover remote endpoints if the configured one was wrong.
// this helps with support for the older NetBird clients that had a hardcoded direct mode
// t.device.DisableSomeRoamingForBrokenMobileSemantics()

View File

@@ -49,7 +49,7 @@ func (t *tunDevice) Create() (wgConfigurer, error) {
t.device = device.NewDevice(
t.wrapper,
t.iceBind,
device.NewLogger(wgLogLevel(), "[netbird] "),
device.NewLogger(device.LogLevelSilent, "[netbird] "),
)
err = t.assignAddr()

View File

@@ -64,7 +64,7 @@ func (t *tunDevice) Create() (wgConfigurer, error) {
t.wrapper = newDeviceWrapper(tunDevice)
log.Debug("Attaching to interface")
t.device = device.NewDevice(t.wrapper, t.iceBind, device.NewLogger(wgLogLevel(), "[wiretrustee] "))
t.device = device.NewDevice(t.wrapper, t.iceBind, device.NewLogger(device.LogLevelSilent, "[wiretrustee] "))
// without this property mobile devices can discover remote endpoints if the configured one was wrong.
// this helps with support for the older NetBird clients that had a hardcoded direct mode
// t.device.DisableSomeRoamingForBrokenMobileSemantics()

View File

@@ -54,7 +54,7 @@ func (t *tunNetstackDevice) Create() (wgConfigurer, error) {
t.device = device.NewDevice(
t.wrapper,
t.iceBind,
device.NewLogger(wgLogLevel(), "[netbird] "),
device.NewLogger(device.LogLevelSilent, "[netbird] "),
)
t.configurer = newWGUSPConfigurer(t.device, t.name)

View File

@@ -57,7 +57,7 @@ func (t *tunUSPDevice) Create() (wgConfigurer, error) {
t.device = device.NewDevice(
t.wrapper,
t.iceBind,
device.NewLogger(wgLogLevel(), "[netbird] "),
device.NewLogger(device.LogLevelSilent, "[netbird] "),
)
err = t.assignAddr()

View File

@@ -41,7 +41,6 @@ func newTunDevice(name string, address WGAddress, port int, key string, mtu int,
}
func (t *tunDevice) Create() (wgConfigurer, error) {
log.Info("create tun interface")
tunDevice, err := tun.CreateTUN(t.name, t.mtu)
if err != nil {
return nil, err
@@ -53,7 +52,7 @@ func (t *tunDevice) Create() (wgConfigurer, error) {
t.device = device.NewDevice(
t.wrapper,
t.iceBind,
device.NewLogger(wgLogLevel(), "[netbird] "),
device.NewLogger(device.LogLevelSilent, "[netbird] "),
)
luid := winipcfg.LUID(t.nativeTunDevice.LUID())

View File

@@ -1,15 +0,0 @@
package iface
import (
"os"
"golang.zx2c4.com/wireguard/device"
)
func wgLogLevel() int {
if os.Getenv("NB_WG_DEBUG") == "true" {
return device.LogLevelVerbose
} else {
return device.LogLevelSilent
}
}

View File

@@ -9,10 +9,7 @@ import (
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/client/system"
@@ -74,11 +71,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
peersUpdateManager := mgmt.NewPeersUpdateManager(nil)
eventStore := &activity.InMemoryEventStore{}
ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics)
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
if err != nil {
t.Fatal(err)
}

View File

@@ -2,6 +2,7 @@ package client
import (
"context"
"crypto/tls"
"fmt"
"io"
"sync"
@@ -10,11 +11,15 @@ import (
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/keepalive"
"github.com/cenkalti/backoff/v4"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/encryption"
@@ -46,21 +51,26 @@ type GrpcClient struct {
// NewClient creates a new client to Management service
func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsEnabled bool) (*GrpcClient, error) {
var conn *grpc.ClientConn
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
operation := func() error {
var err error
conn, err = nbgrpc.CreateConnection(addr, tlsEnabled)
if err != nil {
log.Printf("createConnection error: %v", err)
return err
}
return nil
if tlsEnabled {
transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{}))
}
err := backoff.Retry(operation, nbgrpc.Backoff(ctx))
mgmCtx, cancel := context.WithTimeout(ctx, ConnectTimeout)
defer cancel()
conn, err := grpc.DialContext(
mgmCtx,
addr,
transportOption,
nbgrpc.WithCustomDialer(),
grpc.WithBlock(),
grpc.WithKeepaliveParams(keepalive.ClientParameters{
Time: 30 * time.Second,
Timeout: 10 * time.Second,
}))
if err != nil {
log.Errorf("failed creating connection to Management Service: %v", err)
log.Errorf("failed creating connection to Management Service %v", err)
return nil, err
}
@@ -316,44 +326,25 @@ func (c *GrpcClient) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*pro
if !c.ready() {
return nil, fmt.Errorf(errMsgNoMgmtConnection)
}
loginReq, err := encryption.EncryptMessage(serverKey, c.key, req)
if err != nil {
log.Errorf("failed to encrypt message: %s", err)
return nil, err
}
var resp *proto.EncryptedMessage
operation := func() error {
mgmCtx, cancel := context.WithTimeout(context.Background(), ConnectTimeout)
defer cancel()
var err error
resp, err = c.realClient.Login(mgmCtx, &proto.EncryptedMessage{
WgPubKey: c.key.PublicKey().String(),
Body: loginReq,
})
if err != nil {
// retry only on context canceled
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.Canceled {
return err
}
return backoff.Permanent(err)
}
return nil
}
err = backoff.Retry(operation, nbgrpc.Backoff(c.ctx))
mgmCtx, cancel := context.WithTimeout(c.ctx, ConnectTimeout)
defer cancel()
resp, err := c.realClient.Login(mgmCtx, &proto.EncryptedMessage{
WgPubKey: c.key.PublicKey().String(),
Body: loginReq,
})
if err != nil {
log.Errorf("failed to login to Management Service: %v", err)
return nil, err
}
loginResp := &proto.LoginResponse{}
err = encryption.DecryptMessage(serverKey, c.key, resp.Body, loginResp)
if err != nil {
log.Errorf("failed to decrypt login response: %s", err)
log.Errorf("failed to decrypt registration message: %s", err)
return nil, err
}

View File

@@ -190,7 +190,7 @@ var (
return fmt.Errorf("failed to initialize integrated peer validator: %v", err)
}
accountManager, err := server.BuildManager(ctx, store, peersUpdateManager, idpManager, mgmtSingleAccModeDomain,
dnsDomain, eventStore, geo, userDeleteFromIDPEnabled, integratedPeerValidator, appMetrics)
dnsDomain, eventStore, geo, userDeleteFromIDPEnabled, integratedPeerValidator)
if err != nil {
return fmt.Errorf("failed to build default manager: %v", err)
}

View File

@@ -18,8 +18,6 @@ import (
"github.com/eko/gocache/v3/cache"
cacheStore "github.com/eko/gocache/v3/store"
"github.com/hashicorp/go-multierror"
"github.com/miekg/dns"
gocache "github.com/patrickmn/go-cache"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
@@ -39,7 +37,6 @@ import (
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/route"
)
@@ -68,12 +65,10 @@ type AccountManager interface {
SaveSetupKey(ctx context.Context, accountID string, key *SetupKey, userID string) (*SetupKey, error)
CreateUser(ctx context.Context, accountID, initiatorUserID string, key *UserInfo) (*UserInfo, error)
DeleteUser(ctx context.Context, accountID, initiatorUserID string, targetUserID string) error
DeleteRegularUsers(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string) error
InviteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error
ListSetupKeys(ctx context.Context, accountID, userID string) ([]*SetupKey, error)
SaveUser(ctx context.Context, accountID, initiatorUserID string, update *User) (*UserInfo, error)
SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *User, addIfNotExists bool) (*UserInfo, error)
SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*User, addIfNotExists bool) ([]*UserInfo, error)
GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error)
GetAccountByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (*Account, error)
GetAccountFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*Account, *User, error)
@@ -100,9 +95,7 @@ type AccountManager interface {
GetAllGroups(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error)
GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error)
SaveGroup(ctx context.Context, accountID, userID string, group *nbgroup.Group) error
SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error
DeleteGroup(ctx context.Context, accountId, userId, groupID string) error
DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error
ListGroups(ctx context.Context, accountId string) ([]*nbgroup.Group, error)
GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error
GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error
@@ -140,8 +133,8 @@ type AccountManager interface {
UpdateIntegratedValidatorGroups(ctx context.Context, accountID string, userID string, groups []string) error
GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error)
GetValidatedPeers(account *Account) (map[string]struct{}, error)
SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error)
OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error
SyncAndMarkPeer(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error)
CancelPeerRoutines(ctx context.Context, peer *nbpeer.Peer) error
SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error)
@@ -175,8 +168,6 @@ type DefaultAccountManager struct {
userDeleteFromIDPEnabled bool
integratedPeerValidator integrated_validator.IntegratedValidator
metrics telemetry.AppMetrics
}
// Settings represents Account settings structure that can be modified via API and Dashboard
@@ -408,16 +399,8 @@ func (a *Account) GetGroup(groupID string) *nbgroup.Group {
return a.Groups[groupID]
}
// GetPeerNetworkMap returns the networkmap for the given peer ID.
func (a *Account) GetPeerNetworkMap(
ctx context.Context,
peerID string,
peersCustomZone nbdns.CustomZone,
validatedPeersMap map[string]struct{},
metrics *telemetry.AccountManagerMetrics,
) *NetworkMap {
start := time.Now()
// GetPeerNetworkMap returns a group by ID if exists, nil otherwise
func (a *Account) GetPeerNetworkMap(ctx context.Context, peerID, dnsDomain string, validatedPeersMap map[string]struct{}) *NetworkMap {
peer := a.Peers[peerID]
if peer == nil {
return &NetworkMap{
@@ -453,7 +436,7 @@ func (a *Account) GetPeerNetworkMap(
if dnsManagementStatus {
var zones []nbdns.CustomZone
peersCustomZone := getPeersCustomZone(ctx, a, dnsDomain)
if peersCustomZone.Domain != "" {
zones = append(zones, peersCustomZone)
}
@@ -461,7 +444,7 @@ func (a *Account) GetPeerNetworkMap(
dnsUpdate.NameServerGroups = getPeerNSGroups(a, peerID)
}
nm := &NetworkMap{
return &NetworkMap{
Peers: peersToConnect,
Network: a.Network.Copy(),
Routes: routesUpdate,
@@ -469,60 +452,6 @@ func (a *Account) GetPeerNetworkMap(
OfflinePeers: expiredPeers,
FirewallRules: firewallRules,
}
if metrics != nil {
objectCount := int64(len(peersToConnect) + len(expiredPeers) + len(routesUpdate) + len(firewallRules))
metrics.CountNetworkMapObjects(objectCount)
metrics.CountGetPeerNetworkMapDuration(time.Since(start))
}
return nm
}
func (a *Account) GetPeersCustomZone(ctx context.Context, dnsDomain string) nbdns.CustomZone {
var merr *multierror.Error
if dnsDomain == "" {
log.WithContext(ctx).Error("no dns domain is set, returning empty zone")
return nbdns.CustomZone{}
}
customZone := nbdns.CustomZone{
Domain: dns.Fqdn(dnsDomain),
Records: make([]nbdns.SimpleRecord, 0, len(a.Peers)),
}
domainSuffix := "." + dnsDomain
var sb strings.Builder
for _, peer := range a.Peers {
if peer.DNSLabel == "" {
merr = multierror.Append(merr, fmt.Errorf("peer %s has an empty DNS label", peer.Name))
continue
}
sb.Grow(len(peer.DNSLabel) + len(domainSuffix))
sb.WriteString(peer.DNSLabel)
sb.WriteString(domainSuffix)
customZone.Records = append(customZone.Records, nbdns.SimpleRecord{
Name: sb.String(),
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: defaultTTL,
RData: peer.IP.String(),
})
sb.Reset()
}
go func() {
if merr != nil {
log.WithContext(ctx).Errorf("error generating custom zone for account %s: %v", a.Id, merr)
}
}()
return customZone
}
// GetExpiredPeers returns peers that have been expired
@@ -839,6 +768,10 @@ func (a *Account) GetPeer(peerID string) *nbpeer.Peer {
// SetJWTGroups updates the user's auto groups by synchronizing JWT groups.
// Returns true if there are changes in the JWT group membership.
func (a *Account) SetJWTGroups(userID string, groupsNames []string) bool {
if len(groupsNames) == 0 {
return false
}
user, ok := a.Users[userID]
if !ok {
return false
@@ -922,7 +855,7 @@ func (a *Account) UserGroupsAddToPeers(userID string, groups ...string) {
func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) {
for _, gid := range groups {
group, ok := a.Groups[gid]
if !ok || group.Name == "All" {
if !ok {
continue
}
update := make([]string, 0, len(group.Peers))
@@ -940,18 +873,10 @@ func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) {
}
// BuildManager creates a new DefaultAccountManager with a provided Store
func BuildManager(
ctx context.Context,
store Store,
peersUpdateManager *PeersUpdateManager,
idpManager idp.Manager,
singleAccountModeDomain string,
dnsDomain string,
eventStore activity.Store,
geo *geolocation.Geolocation,
func BuildManager(ctx context.Context, store Store, peersUpdateManager *PeersUpdateManager, idpManager idp.Manager,
singleAccountModeDomain string, dnsDomain string, eventStore activity.Store, geo *geolocation.Geolocation,
userDeleteFromIDPEnabled bool,
integratedPeerValidator integrated_validator.IntegratedValidator,
metrics telemetry.AppMetrics,
) (*DefaultAccountManager, error) {
am := &DefaultAccountManager{
Store: store,
@@ -966,7 +891,6 @@ func BuildManager(
peerLoginExpiry: NewDefaultScheduler(),
userDeleteFromIDPEnabled: userDeleteFromIDPEnabled,
integratedPeerValidator: integratedPeerValidator,
metrics: metrics,
}
allAccounts := store.GetAllAccounts(ctx)
// enable single account mode only if configured by user and number of existing accounts is not grater than 1
@@ -1052,7 +976,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be smaller than one hour")
}
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
@@ -1103,7 +1027,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) {
return func() (time.Duration, bool) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
@@ -1202,7 +1126,7 @@ func (am *DefaultAccountManager) warmupIDPCache(ctx context.Context) error {
// DeleteAccount deletes an account and all its users from local store and from the remote IDP if the requester is an admin and account owner
func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, userID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
@@ -1662,7 +1586,7 @@ func (am *DefaultAccountManager) MarkPATUsed(ctx context.Context, tokenID string
return err
}
unlock := am.Store.AcquireWriteLockByUID(ctx, account.Id)
unlock := am.Store.AcquireAccountWriteLock(ctx, account.Id)
defer unlock()
account, err = am.Store.GetAccountByUser(ctx, user.Id)
@@ -1745,7 +1669,7 @@ func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims
if err != nil {
return nil, nil, err
}
unlock := am.Store.AcquireWriteLockByUID(ctx, newAcc.Id)
unlock := am.Store.AcquireAccountWriteLock(ctx, newAcc.Id)
alreadyUnlocked := false
defer func() {
if !alreadyUnlocked {
@@ -1802,6 +1726,7 @@ func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims
log.WithContext(ctx).Errorf("failed to save account: %v", err)
} else {
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId)
// todo: optimize this as part of the group optimizations
am.updateAccountPeers(ctx, account)
unlock()
alreadyUnlocked = true
@@ -1901,7 +1826,7 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.C
account, err := am.Store.GetAccountByUser(ctx, claims.UserId)
if err == nil {
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, account.Id)
unlockAccount := am.Store.AcquireAccountWriteLock(ctx, account.Id)
defer unlockAccount()
account, err = am.Store.GetAccountByUser(ctx, claims.UserId)
if err != nil {
@@ -1921,7 +1846,7 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.C
return account, nil
} else if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
if domainAccount != nil {
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccount.Id)
unlockAccount := am.Store.AcquireAccountWriteLock(ctx, domainAccount.Id)
defer unlockAccount()
domainAccount, err = am.Store.GetAccountByPrivateDomain(ctx, claims.Domain)
if err != nil {
@@ -1935,11 +1860,17 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.C
}
}
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
accountUnlock := am.Store.AcquireReadLockByUID(ctx, accountID)
defer accountUnlock()
peerUnlock := am.Store.AcquireWriteLockByUID(ctx, peerPubKey)
defer peerUnlock()
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
accountID, err := am.Store.GetAccountIDByPeerPubKey(ctx, peerPubKey)
if err != nil {
if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound {
return nil, nil, nil, status.Errorf(status.Unauthenticated, "peer not registered")
}
return nil, nil, nil, err
}
unlock := am.Store.AcquireAccountReadLock(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
@@ -1959,20 +1890,26 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID
return peer, netMap, postureChecks, nil
}
func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error {
accountUnlock := am.Store.AcquireReadLockByUID(ctx, accountID)
defer accountUnlock()
peerUnlock := am.Store.AcquireWriteLockByUID(ctx, peerPubKey)
defer peerUnlock()
func (am *DefaultAccountManager) CancelPeerRoutines(ctx context.Context, peer *nbpeer.Peer) error {
accountID, err := am.Store.GetAccountIDByPeerPubKey(ctx, peer.Key)
if err != nil {
if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound {
return status.Errorf(status.Unauthenticated, "peer not registered")
}
return err
}
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
}
err = am.MarkPeerConnected(ctx, peerPubKey, false, nil, account)
err = am.MarkPeerConnected(ctx, peer.Key, false, nil, account)
if err != nil {
log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err)
log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peer.Key, err)
}
return nil
@@ -1985,7 +1922,7 @@ func (am *DefaultAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey st
return err
}
unlock := am.Store.AcquireReadLockByUID(ctx, accountID)
unlock := am.Store.AcquireAccountReadLock(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)

View File

@@ -24,7 +24,6 @@ import (
"github.com/netbirdio/netbird/management/server/jwtclaims"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/route"
)
@@ -411,8 +410,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
validatedPeers[p] = struct{}{}
}
customZone := account.GetPeersCustomZone(context.Background(), "netbird.io")
networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, validatedPeers, nil)
networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, "netbird.io", validatedPeers)
assert.Len(t, networkMap.Peers, len(testCase.expectedPeers))
assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers))
}
@@ -2221,13 +2219,6 @@ func TestAccount_SetJWTGroups(t *testing.T) {
assert.Len(t, account.Users["user2"].AutoGroups, 1, "new group should be added")
assert.Contains(t, account.Groups, account.Users["user2"].AutoGroups[0], "groups must contain group3 from user groups")
})
t.Run("remove all JWT groups", func(t *testing.T) {
updated := account.SetJWTGroups("user1", []string{})
assert.True(t, updated, "account should be updated")
assert.Len(t, account.Users["user1"].AutoGroups, 1, "only non-JWT groups should remain")
assert.Contains(t, account.Users["user1"].AutoGroups, "group1", " group1 should still be present")
})
}
func TestAccount_UserGroupsAddToPeers(t *testing.T) {
@@ -2295,13 +2286,7 @@ func TestAccount_UserGroupsRemoveFromPeers(t *testing.T) {
})
}
type TB interface {
Cleanup(func())
Helper()
TempDir() string
}
func createManager(t TB) (*DefaultAccountManager, error) {
func createManager(t *testing.T) (*DefaultAccountManager, error) {
t.Helper()
store, err := createStore(t)
@@ -2310,12 +2295,7 @@ func createManager(t TB) (*DefaultAccountManager, error) {
}
eventStore := &activity.InMemoryEventStore{}
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
if err != nil {
return nil, err
}
manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics)
manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{})
if err != nil {
return nil, err
}
@@ -2323,7 +2303,7 @@ func createManager(t TB) (*DefaultAccountManager, error) {
return manager, nil
}
func createStore(t TB) (Store, error) {
func createStore(t *testing.T) (Store, error) {
t.Helper()
dataDir := t.TempDir()
store, cleanUp, err := NewTestStoreFromJson(context.Background(), dataDir)

View File

@@ -56,10 +56,6 @@ type Config struct {
func (c Config) GetAuthAudiences() []string {
audiences := []string{c.HttpConfig.AuthAudience}
if c.HttpConfig.ExtraAuthAudience != "" {
audiences = append(audiences, c.HttpConfig.ExtraAuthAudience)
}
if c.DeviceAuthorizationFlow != nil && c.DeviceAuthorizationFlow.ProviderConfig.Audience != "" {
audiences = append(audiences, c.DeviceAuthorizationFlow.ProviderConfig.Audience)
}
@@ -94,8 +90,6 @@ type HttpServerConfig struct {
OIDCConfigEndpoint string
// IdpSignKeyRefreshEnabled identifies the signing key is currently being rotated or not
IdpSignKeyRefreshEnabled bool
// Extra audience
ExtraAuthAudience string
}
// Host represents a Wiretrustee host (e.g. STUN, TURN, Signal)

View File

@@ -4,8 +4,8 @@ import (
"context"
"fmt"
"strconv"
"sync"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
nbdns "github.com/netbirdio/netbird/dns"
@@ -17,50 +17,6 @@ import (
const defaultTTL = 300
// DNSConfigCache is a thread-safe cache for DNS configuration components
type DNSConfigCache struct {
CustomZones sync.Map
NameServerGroups sync.Map
}
// GetCustomZone retrieves a cached custom zone
func (c *DNSConfigCache) GetCustomZone(key string) (*proto.CustomZone, bool) {
if c == nil {
return nil, false
}
if value, ok := c.CustomZones.Load(key); ok {
return value.(*proto.CustomZone), true
}
return nil, false
}
// SetCustomZone stores a custom zone in the cache
func (c *DNSConfigCache) SetCustomZone(key string, value *proto.CustomZone) {
if c == nil {
return
}
c.CustomZones.Store(key, value)
}
// GetNameServerGroup retrieves a cached name server group
func (c *DNSConfigCache) GetNameServerGroup(key string) (*proto.NameServerGroup, bool) {
if c == nil {
return nil, false
}
if value, ok := c.NameServerGroups.Load(key); ok {
return value.(*proto.NameServerGroup), true
}
return nil, false
}
// SetNameServerGroup stores a name server group in the cache
func (c *DNSConfigCache) SetNameServerGroup(key string, value *proto.NameServerGroup) {
if c == nil {
return
}
c.NameServerGroups.Store(key, value)
}
type lookupMap map[string]struct{}
// DNSSettings defines dns settings at the account level
@@ -80,7 +36,7 @@ func (d DNSSettings) Copy() DNSSettings {
// GetDNSSettings validates a user role and returns the DNS settings for the provided account ID
func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*DNSSettings, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
@@ -102,7 +58,7 @@ func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID s
// SaveDNSSettings validates a user role and updates the account's DNS settings
func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *DNSSettings) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
@@ -152,78 +108,75 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta)
}
// todo: check if before/after groups are in use by dns, acl, routes and if it has peers
am.updateAccountPeers(ctx, account)
return nil
}
// toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache
func toProtocolDNSConfig(update nbdns.Config, cache *DNSConfigCache) *proto.DNSConfig {
protoUpdate := &proto.DNSConfig{
ServiceEnable: update.ServiceEnable,
CustomZones: make([]*proto.CustomZone, 0, len(update.CustomZones)),
NameServerGroups: make([]*proto.NameServerGroup, 0, len(update.NameServerGroups)),
}
func toProtocolDNSConfig(update nbdns.Config) *proto.DNSConfig {
protoUpdate := &proto.DNSConfig{ServiceEnable: update.ServiceEnable}
for _, zone := range update.CustomZones {
cacheKey := zone.Domain
if cachedZone, exists := cache.GetCustomZone(cacheKey); exists {
protoUpdate.CustomZones = append(protoUpdate.CustomZones, cachedZone)
} else {
protoZone := convertToProtoCustomZone(zone)
cache.SetCustomZone(cacheKey, protoZone)
protoUpdate.CustomZones = append(protoUpdate.CustomZones, protoZone)
protoZone := &proto.CustomZone{Domain: zone.Domain}
for _, record := range zone.Records {
protoZone.Records = append(protoZone.Records, &proto.SimpleRecord{
Name: record.Name,
Type: int64(record.Type),
Class: record.Class,
TTL: int64(record.TTL),
RData: record.RData,
})
}
protoUpdate.CustomZones = append(protoUpdate.CustomZones, protoZone)
}
for _, nsGroup := range update.NameServerGroups {
cacheKey := nsGroup.ID
if cachedGroup, exists := cache.GetNameServerGroup(cacheKey); exists {
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, cachedGroup)
} else {
protoGroup := convertToProtoNameServerGroup(nsGroup)
cache.SetNameServerGroup(cacheKey, protoGroup)
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, protoGroup)
protoGroup := &proto.NameServerGroup{
Primary: nsGroup.Primary,
Domains: nsGroup.Domains,
SearchDomainsEnabled: nsGroup.SearchDomainsEnabled,
}
for _, ns := range nsGroup.NameServers {
protoNS := &proto.NameServer{
IP: ns.IP.String(),
Port: int64(ns.Port),
NSType: int64(ns.NSType),
}
protoGroup.NameServers = append(protoGroup.NameServers, protoNS)
}
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, protoGroup)
}
return protoUpdate
}
// Helper function to convert nbdns.CustomZone to proto.CustomZone
func convertToProtoCustomZone(zone nbdns.CustomZone) *proto.CustomZone {
protoZone := &proto.CustomZone{
Domain: zone.Domain,
Records: make([]*proto.SimpleRecord, 0, len(zone.Records)),
func getPeersCustomZone(ctx context.Context, account *Account, dnsDomain string) nbdns.CustomZone {
if dnsDomain == "" {
log.WithContext(ctx).Errorf("no dns domain is set, returning empty zone")
return nbdns.CustomZone{}
}
for _, record := range zone.Records {
protoZone.Records = append(protoZone.Records, &proto.SimpleRecord{
Name: record.Name,
Type: int64(record.Type),
Class: record.Class,
TTL: int64(record.TTL),
RData: record.RData,
})
}
return protoZone
}
// Helper function to convert nbdns.NameServerGroup to proto.NameServerGroup
func convertToProtoNameServerGroup(nsGroup *nbdns.NameServerGroup) *proto.NameServerGroup {
protoGroup := &proto.NameServerGroup{
Primary: nsGroup.Primary,
Domains: nsGroup.Domains,
SearchDomainsEnabled: nsGroup.SearchDomainsEnabled,
NameServers: make([]*proto.NameServer, 0, len(nsGroup.NameServers)),
customZone := nbdns.CustomZone{
Domain: dns.Fqdn(dnsDomain),
}
for _, ns := range nsGroup.NameServers {
protoGroup.NameServers = append(protoGroup.NameServers, &proto.NameServer{
IP: ns.IP.String(),
Port: int64(ns.Port),
NSType: int64(ns.NSType),
for _, peer := range account.Peers {
if peer.DNSLabel == "" {
log.WithContext(ctx).Errorf("found a peer with empty dns label. It was probably caused by a invalid character in its name. Peer Name: %s", peer.Name)
continue
}
customZone.Records = append(customZone.Records, nbdns.SimpleRecord{
Name: dns.Fqdn(peer.DNSLabel + "." + dnsDomain),
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: defaultTTL,
RData: peer.IP.String(),
})
}
return protoGroup
return customZone
}
func getPeerNSGroups(account *Account, peerID string) []*nbdns.NameServerGroup {

View File

@@ -2,14 +2,9 @@ package server
import (
"context"
"fmt"
"net/netip"
"reflect"
"testing"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/dns"
@@ -200,11 +195,7 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
return nil, err
}
eventStore := &activity.InMemoryEventStore{}
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MocIntegratedValidator{}, metrics)
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MocIntegratedValidator{})
}
func createDNSStore(t *testing.T) (Store, error) {
@@ -329,150 +320,3 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, erro
return am.Store.GetAccount(context.Background(), account.Id)
}
func generateTestData(size int) nbdns.Config {
config := nbdns.Config{
ServiceEnable: true,
CustomZones: make([]nbdns.CustomZone, size),
NameServerGroups: make([]*nbdns.NameServerGroup, size),
}
for i := 0; i < size; i++ {
config.CustomZones[i] = nbdns.CustomZone{
Domain: fmt.Sprintf("domain%d.com", i),
Records: []nbdns.SimpleRecord{
{
Name: fmt.Sprintf("record%d", i),
Type: 1,
Class: "IN",
TTL: 3600,
RData: "192.168.1.1",
},
},
}
config.NameServerGroups[i] = &nbdns.NameServerGroup{
ID: fmt.Sprintf("group%d", i),
Primary: i == 0,
Domains: []string{fmt.Sprintf("domain%d.com", i)},
SearchDomainsEnabled: true,
NameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("8.8.8.8"),
Port: 53,
NSType: 1,
},
},
}
}
return config
}
func BenchmarkToProtocolDNSConfig(b *testing.B) {
sizes := []int{10, 100, 1000}
for _, size := range sizes {
testData := generateTestData(size)
b.Run(fmt.Sprintf("WithCache-Size%d", size), func(b *testing.B) {
cache := &DNSConfigCache{}
b.ResetTimer()
for i := 0; i < b.N; i++ {
toProtocolDNSConfig(testData, cache)
}
})
b.Run(fmt.Sprintf("WithoutCache-Size%d", size), func(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
cache := &DNSConfigCache{}
toProtocolDNSConfig(testData, cache)
}
})
}
}
func TestToProtocolDNSConfigWithCache(t *testing.T) {
var cache DNSConfigCache
// Create two different configs
config1 := nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "example.com",
Records: []nbdns.SimpleRecord{
{Name: "www", Type: 1, Class: "IN", TTL: 300, RData: "192.168.1.1"},
},
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
ID: "group1",
Name: "Group 1",
NameServers: []nbdns.NameServer{
{IP: netip.MustParseAddr("8.8.8.8"), Port: 53},
},
},
},
}
config2 := nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "example.org",
Records: []nbdns.SimpleRecord{
{Name: "mail", Type: 1, Class: "IN", TTL: 300, RData: "192.168.1.2"},
},
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
ID: "group2",
Name: "Group 2",
NameServers: []nbdns.NameServer{
{IP: netip.MustParseAddr("8.8.4.4"), Port: 53},
},
},
},
}
// First run with config1
result1 := toProtocolDNSConfig(config1, &cache)
// Second run with config2
result2 := toProtocolDNSConfig(config2, &cache)
// Third run with config1 again
result3 := toProtocolDNSConfig(config1, &cache)
// Verify that result1 and result3 are identical
if !reflect.DeepEqual(result1, result3) {
t.Errorf("Results are not identical when run with the same input. Expected %v, got %v", result1, result3)
}
// Verify that result2 is different from result1 and result3
if reflect.DeepEqual(result1, result2) || reflect.DeepEqual(result2, result3) {
t.Errorf("Results should be different for different inputs")
}
// Verify that the cache contains elements from both configs
if _, exists := cache.GetCustomZone("example.com"); !exists {
t.Errorf("Cache should contain custom zone for example.com")
}
if _, exists := cache.GetCustomZone("example.org"); !exists {
t.Errorf("Cache should contain custom zone for example.org")
}
if _, exists := cache.GetNameServerGroup("group1"); !exists {
t.Errorf("Cache should contain name server group 'group1'")
}
if _, exists := cache.GetNameServerGroup("group2"); !exists {
t.Errorf("Cache should contain name server group 'group2'")
}
}

View File

@@ -13,7 +13,7 @@ import (
// GetEvents returns a list of activity events of an account
func (am *DefaultAccountManager) GetEvents(ctx context.Context, accountID, userID string) ([]*activity.Event, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)

View File

@@ -39,8 +39,8 @@ type FileStore struct {
mux sync.Mutex `json:"-"`
storeFile string `json:"-"`
// sync.Mutex indexed by resource ID
resourceLocks sync.Map `json:"-"`
// sync.Mutex indexed by accountID
accountLocks sync.Map `json:"-"`
globalAccountLock sync.Mutex `json:"-"`
metrics telemetry.AppMetrics `json:"-"`
@@ -281,26 +281,26 @@ func (s *FileStore) AcquireGlobalLock(ctx context.Context) (unlock func()) {
return unlock
}
// AcquireWriteLockByUID acquires an ID lock for writing to a resource and returns a function that releases the lock
func (s *FileStore) AcquireWriteLockByUID(ctx context.Context, uniqueID string) (unlock func()) {
log.WithContext(ctx).Debugf("acquiring lock for ID %s", uniqueID)
// AcquireAccountWriteLock acquires account lock for writing to a resource and returns a function that releases the lock
func (s *FileStore) AcquireAccountWriteLock(ctx context.Context, accountID string) (unlock func()) {
log.WithContext(ctx).Debugf("acquiring lock for account %s", accountID)
start := time.Now()
value, _ := s.resourceLocks.LoadOrStore(uniqueID, &sync.Mutex{})
value, _ := s.accountLocks.LoadOrStore(accountID, &sync.Mutex{})
mtx := value.(*sync.Mutex)
mtx.Lock()
unlock = func() {
mtx.Unlock()
log.WithContext(ctx).Debugf("released lock for ID %s in %v", uniqueID, time.Since(start))
log.WithContext(ctx).Debugf("released lock for account %s in %v", accountID, time.Since(start))
}
return unlock
}
// AcquireReadLockByUID acquires an ID lock for reading a resource and returns a function that releases the lock
// AcquireAccountReadLock AcquireAccountWriteLock acquires account lock for reading a resource and returns a function that releases the lock
// This method is still returns a write lock as file store can't handle read locks
func (s *FileStore) AcquireReadLockByUID(ctx context.Context, uniqueID string) (unlock func()) {
return s.AcquireWriteLockByUID(ctx, uniqueID)
func (s *FileStore) AcquireAccountReadLock(ctx context.Context, accountID string) (unlock func()) {
return s.AcquireAccountWriteLock(ctx, accountID)
}
func (s *FileStore) SaveAccount(ctx context.Context, account *Account) error {
@@ -666,26 +666,6 @@ func (s *FileStore) SaveInstallationID(ctx context.Context, ID string) error {
return s.persist(ctx, s.storeFile)
}
// SavePeer saves the peer in the account
func (s *FileStore) SavePeer(_ context.Context, accountID string, peer *nbpeer.Peer) error {
s.mux.Lock()
defer s.mux.Unlock()
account, err := s.getAccount(accountID)
if err != nil {
return err
}
newPeer := peer.Copy()
account.Peers[peer.ID] = newPeer
s.PeerKeyID2AccountID[peer.Key] = accountID
s.PeerID2AccountID[peer.ID] = accountID
return nil
}
// SavePeerStatus stores the PeerStatus in memory. It doesn't attempt to persist data to speed up things.
// PeerStatus will be saved eventually when some other changes occur.
func (s *FileStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error {
@@ -766,11 +746,3 @@ func (s *FileStore) Close(ctx context.Context) error {
func (s *FileStore) GetStoreEngine() StoreEngine {
return FileStoreEngine
}
func (s *FileStore) SaveUsers(accountID string, users map[string]*User) error {
return status.Errorf(status.Internal, "SaveUsers is not implemented")
}
func (s *FileStore) SaveGroups(accountID string, groups map[string]*nbgroup.Group) error {
return status.Errorf(status.Internal, "SaveGroups is not implemented")
}

View File

@@ -9,7 +9,6 @@ import (
"path"
"strconv"
log "github.com/sirupsen/logrus"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
@@ -31,8 +30,6 @@ func loadGeolocationDatabases(dataDir string) error {
continue
}
log.Infof("geo location file %s not found , file will be downloaded", file)
switch file {
case MMDBFileName:
extractFunc := func(src string, dst string) error {

View File

@@ -2,12 +2,8 @@ package server
import (
"context"
"errors"
"fmt"
"slices"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/route"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
@@ -27,7 +23,7 @@ func (e *GroupLinkError) Error() string {
// GetGroup object of the peers
func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupID, userID string) (*nbgroup.Group, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
@@ -54,7 +50,7 @@ func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupI
// GetAllGroups returns all groups in an account
func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID string, userID string) ([]*nbgroup.Group, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
@@ -81,7 +77,7 @@ func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID str
// GetGroupByName filters all groups in an account by name and returns the one with the most peers
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
@@ -114,87 +110,64 @@ func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName,
// SaveGroup object of the peers
func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userID string, newGroup *nbgroup.Group) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
defer unlock()
return am.SaveGroups(ctx, accountID, userID, []*nbgroup.Group{newGroup})
}
// SaveGroups adds new groups to the account.
// Note: This function does not acquire the global lock.
// It is the caller's responsibility to ensure proper locking is in place before invoking this method.
func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error {
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
}
var eventsToStore []func()
for _, newGroup := range newGroups {
if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI {
return status.Errorf(status.InvalidArgument, "%s group without ID set", newGroup.Issued)
}
if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI {
existingGroup, err := account.FindGroupByName(newGroup.Name)
if err != nil {
s, ok := status.FromError(err)
if !ok || s.ErrorType != status.NotFound {
return err
}
}
// Avoid duplicate groups only for the API issued groups.
// Integration or JWT groups can be duplicated as they are coming from the IdP that we don't have control of.
if existingGroup != nil {
return status.Errorf(status.AlreadyExists, "group with name %s already exists", newGroup.Name)
}
newGroup.ID = xid.New().String()
}
for _, peerID := range newGroup.Peers {
if account.Peers[peerID] == nil {
return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID)
}
}
oldGroup := account.Groups[newGroup.ID]
account.Groups[newGroup.ID] = newGroup
events := am.prepareGroupEvents(ctx, userID, accountID, newGroup, oldGroup, account)
eventsToStore = append(eventsToStore, events...)
if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI {
return status.Errorf(status.InvalidArgument, "%s group without ID set", newGroup.Issued)
}
if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI {
existingGroup, err := account.FindGroupByName(newGroup.Name)
if err != nil {
s, ok := status.FromError(err)
if !ok || s.ErrorType != status.NotFound {
return err
}
}
// avoid duplicate groups only for the API issued groups. Integration or JWT groups can be duplicated as they are
// coming from the IdP that we don't have control of.
if existingGroup != nil {
return status.Errorf(status.AlreadyExists, "group with name %s already exists", newGroup.Name)
}
newGroup.ID = xid.New().String()
}
for _, peerID := range newGroup.Peers {
if account.Peers[peerID] == nil {
return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID)
}
}
oldGroup, exists := account.Groups[newGroup.ID]
account.Groups[newGroup.ID] = newGroup
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err
}
// todo: check if groups is in use by dns, acl, routes and before/after peers
am.updateAccountPeers(ctx, account)
for _, storeEvent := range eventsToStore {
storeEvent()
}
return nil
}
// prepareGroupEvents prepares a list of event functions to be stored.
func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID string, accountID string, newGroup, oldGroup *nbgroup.Group, account *Account) []func() {
var eventsToStore []func()
// the following snippet tracks the activity and stores the group events in the event store.
// It has to happen after all the operations have been successfully performed.
addedPeers := make([]string, 0)
removedPeers := make([]string, 0)
if oldGroup != nil {
if exists {
addedPeers = difference(newGroup.Peers, oldGroup.Peers)
removedPeers = difference(oldGroup.Peers, newGroup.Peers)
} else {
addedPeers = append(addedPeers, newGroup.Peers...)
eventsToStore = append(eventsToStore, func() {
am.StoreEvent(ctx, userID, newGroup.ID, accountID, activity.GroupCreated, newGroup.EventMeta())
})
am.StoreEvent(ctx, userID, newGroup.ID, accountID, activity.GroupCreated, newGroup.EventMeta())
}
for _, p := range addedPeers {
@@ -203,14 +176,11 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID
log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID)
continue
}
peerCopy := peer // copy to avoid closure issues
eventsToStore = append(eventsToStore, func() {
am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupAddedToPeer,
map[string]any{
"group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peerCopy.IP.String(),
"peer_fqdn": peerCopy.FQDN(am.GetDNSDomain()),
})
})
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupAddedToPeer,
map[string]any{
"group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peer.IP.String(),
"peer_fqdn": peer.FQDN(am.GetDNSDomain()),
})
}
for _, p := range removedPeers {
@@ -219,17 +189,14 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID
log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID)
continue
}
peerCopy := peer // copy to avoid closure issues
eventsToStore = append(eventsToStore, func() {
am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupRemovedFromPeer,
map[string]any{
"group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peerCopy.IP.String(),
"peer_fqdn": peerCopy.FQDN(am.GetDNSDomain()),
})
})
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupRemovedFromPeer,
map[string]any{
"group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peer.IP.String(),
"peer_fqdn": peer.FQDN(am.GetDNSDomain()),
})
}
return eventsToStore
return nil
}
// difference returns the elements in `a` that aren't in `b`.
@@ -247,9 +214,9 @@ func difference(a, b []string) []string {
return diff
}
// DeleteGroup object of the peers.
// DeleteGroup object of the peers
func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, userId, groupID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountId)
unlock := am.Store.AcquireAccountWriteLock(ctx, accountId)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountId)
@@ -257,14 +224,96 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, use
return err
}
group, ok := account.Groups[groupID]
g, ok := account.Groups[groupID]
if !ok {
return nil
}
if err = validateDeleteGroup(account, group, userId); err != nil {
return err
// disable a deleting integration group if the initiator is not an admin service user
if g.Issued == nbgroup.GroupIssuedIntegration {
executingUser := account.Users[userId]
if executingUser == nil {
return status.Errorf(status.NotFound, "user not found")
}
if executingUser.Role != UserRoleAdmin || !executingUser.IsServiceUser {
return status.Errorf(status.PermissionDenied, "only service users with admin power can delete integration group")
}
}
// check route links
for _, r := range account.Routes {
for _, g := range r.Groups {
if g == groupID {
return &GroupLinkError{"route", string(r.NetID)}
}
}
for _, g := range r.PeerGroups {
if g == groupID {
return &GroupLinkError{"route", string(r.NetID)}
}
}
}
// check DNS links
for _, dns := range account.NameServerGroups {
for _, g := range dns.Groups {
if g == groupID {
return &GroupLinkError{"name server groups", dns.Name}
}
}
}
// check ACL links
for _, policy := range account.Policies {
for _, rule := range policy.Rules {
for _, src := range rule.Sources {
if src == groupID {
return &GroupLinkError{"policy", policy.Name}
}
}
for _, dst := range rule.Destinations {
if dst == groupID {
return &GroupLinkError{"policy", policy.Name}
}
}
}
}
// check setup key links
for _, setupKey := range account.SetupKeys {
for _, grp := range setupKey.AutoGroups {
if grp == groupID {
return &GroupLinkError{"setup key", setupKey.Name}
}
}
}
// check user links
for _, user := range account.Users {
for _, grp := range user.AutoGroups {
if grp == groupID {
return &GroupLinkError{"user", user.Id}
}
}
}
// check DisabledManagementGroups
for _, disabledMgmGrp := range account.DNSSettings.DisabledManagementGroups {
if disabledMgmGrp == groupID {
return &GroupLinkError{"disabled DNS management groups", g.Name}
}
}
// check integrated peer validator groups
if account.Settings.Extra != nil {
for _, integratedPeerValidatorGroups := range account.Settings.Extra.IntegratedValidatorGroups {
if groupID == integratedPeerValidatorGroups {
return &GroupLinkError{"integrated validator", g.Name}
}
}
}
delete(account.Groups, groupID)
account.Network.IncSerial()
@@ -272,60 +321,17 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, use
return err
}
am.StoreEvent(ctx, userId, groupID, accountId, activity.GroupDeleted, group.EventMeta())
am.StoreEvent(ctx, userId, groupID, accountId, activity.GroupDeleted, g.EventMeta())
// todo: check if groups is in use by dns, acl, routes and if it has peers
am.updateAccountPeers(ctx, account)
return nil
}
// DeleteGroups deletes groups from an account.
// Note: This function does not acquire the global lock.
// It is the caller's responsibility to ensure proper locking is in place before invoking this method.
//
// If an error occurs while deleting a group, the function skips it and continues deleting other groups.
// Errors are collected and returned at the end.
func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error {
account, err := am.Store.GetAccount(ctx, accountId)
if err != nil {
return err
}
var allErrors error
deletedGroups := make([]*nbgroup.Group, 0, len(groupIDs))
for _, groupID := range groupIDs {
group, ok := account.Groups[groupID]
if !ok {
continue
}
if err := validateDeleteGroup(account, group, userId); err != nil {
allErrors = errors.Join(allErrors, fmt.Errorf("failed to delete group %s: %w", groupID, err))
continue
}
delete(account.Groups, groupID)
deletedGroups = append(deletedGroups, group)
}
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err
}
for _, g := range deletedGroups {
am.StoreEvent(ctx, userId, g.ID, accountId, activity.GroupDeleted, g.EventMeta())
}
am.updateAccountPeers(ctx, account)
return allErrors
}
// ListGroups objects of the peers
func (am *DefaultAccountManager) ListGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
@@ -343,7 +349,7 @@ func (am *DefaultAccountManager) ListGroups(ctx context.Context, accountID strin
// GroupAddPeer appends peer to the group
func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
@@ -372,6 +378,7 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
return err
}
// todo: check if groups is in use by dns, acl, routes
am.updateAccountPeers(ctx, account)
return nil
@@ -379,7 +386,7 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
// GroupDeletePeer removes peer from the group
func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, groupID, peerID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
@@ -402,106 +409,8 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
}
}
// todo: check if groups is in use by dns, acl, routes
am.updateAccountPeers(ctx, account)
return nil
}
func validateDeleteGroup(account *Account, group *nbgroup.Group, userID string) error {
// disable a deleting integration group if the initiator is not an admin service user
if group.Issued == nbgroup.GroupIssuedIntegration {
executingUser := account.Users[userID]
if executingUser == nil {
return status.Errorf(status.NotFound, "user not found")
}
if executingUser.Role != UserRoleAdmin || !executingUser.IsServiceUser {
return status.Errorf(status.PermissionDenied, "only service users with admin power can delete integration group")
}
}
if isLinked, linkedRoute := isGroupLinkedToRoute(account.Routes, group.ID); isLinked {
return &GroupLinkError{"route", string(linkedRoute.NetID)}
}
if isLinked, linkedDns := isGroupLinkedToDns(account.NameServerGroups, group.ID); isLinked {
return &GroupLinkError{"name server groups", linkedDns.Name}
}
if isLinked, linkedPolicy := isGroupLinkedToPolicy(account.Policies, group.ID); isLinked {
return &GroupLinkError{"policy", linkedPolicy.Name}
}
if isLinked, linkedSetupKey := isGroupLinkedToSetupKey(account.SetupKeys, group.ID); isLinked {
return &GroupLinkError{"setup key", linkedSetupKey.Name}
}
if isLinked, linkedUser := isGroupLinkedToUser(account.Users, group.ID); isLinked {
return &GroupLinkError{"user", linkedUser.Id}
}
if slices.Contains(account.DNSSettings.DisabledManagementGroups, group.ID) {
return &GroupLinkError{"disabled DNS management groups", group.Name}
}
if account.Settings.Extra != nil {
if slices.Contains(account.Settings.Extra.IntegratedValidatorGroups, group.ID) {
return &GroupLinkError{"integrated validator", group.Name}
}
}
return nil
}
// isGroupLinkedToRoute checks if a group is linked to any route in the account.
func isGroupLinkedToRoute(routes map[route.ID]*route.Route, groupID string) (bool, *route.Route) {
for _, r := range routes {
if slices.Contains(r.Groups, groupID) || slices.Contains(r.PeerGroups, groupID) {
return true, r
}
}
return false, nil
}
// isGroupLinkedToPolicy checks if a group is linked to any policy in the account.
func isGroupLinkedToPolicy(policies []*Policy, groupID string) (bool, *Policy) {
for _, policy := range policies {
for _, rule := range policy.Rules {
if slices.Contains(rule.Sources, groupID) || slices.Contains(rule.Destinations, groupID) {
return true, policy
}
}
}
return false, nil
}
// isGroupLinkedToDns checks if a group is linked to any nameserver group in the account.
func isGroupLinkedToDns(nameServerGroups map[string]*nbdns.NameServerGroup, groupID string) (bool, *nbdns.NameServerGroup) {
for _, dns := range nameServerGroups {
for _, g := range dns.Groups {
if g == groupID {
return true, dns
}
}
}
return false, nil
}
// isGroupLinkedToSetupKey checks if a group is linked to any setup key in the account.
func isGroupLinkedToSetupKey(setupKeys map[string]*SetupKey, groupID string) (bool, *SetupKey) {
for _, setupKey := range setupKeys {
if slices.Contains(setupKey.AutoGroups, groupID) {
return true, setupKey
}
}
return false, nil
}
// isGroupLinkedToUser checks if a group is linked to any user in the account.
func isGroupLinkedToUser(users map[string]*User, groupID string) (bool, *User) {
for _, user := range users {
if slices.Contains(user.AutoGroups, groupID) {
return true, user
}
}
return false, nil
}

View File

@@ -3,14 +3,12 @@ package server
import (
"context"
"errors"
"fmt"
"testing"
nbdns "github.com/netbirdio/netbird/dns"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/route"
"github.com/stretchr/testify/assert"
)
const (
@@ -23,7 +21,7 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) {
t.Error("failed to create account manager")
}
_, account, err := initTestGroupAccount(am)
account, err := initTestGroupAccount(am)
if err != nil {
t.Error("failed to init testing account")
}
@@ -58,7 +56,7 @@ func TestDefaultAccountManager_DeleteGroup(t *testing.T) {
t.Error("failed to create account manager")
}
_, account, err := initTestGroupAccount(am)
account, err := initTestGroupAccount(am)
if err != nil {
t.Error("failed to init testing account")
}
@@ -134,136 +132,7 @@ func TestDefaultAccountManager_DeleteGroup(t *testing.T) {
}
}
func TestDefaultAccountManager_DeleteGroups(t *testing.T) {
am, err := createManager(t)
assert.NoError(t, err, "Failed to create account manager")
manager, account, err := initTestGroupAccount(am)
assert.NoError(t, err, "Failed to init testing account")
groups := make([]*nbgroup.Group, 10)
for i := 0; i < 10; i++ {
groups[i] = &nbgroup.Group{
ID: fmt.Sprintf("group-%d", i+1),
AccountID: account.Id,
Name: fmt.Sprintf("group-%d", i+1),
Issued: nbgroup.GroupIssuedAPI,
}
}
err = manager.SaveGroups(context.Background(), account.Id, groupAdminUserID, groups)
assert.NoError(t, err, "Failed to save test groups")
testCases := []struct {
name string
groupIDs []string
expectedReasons []string
expectedDeleted []string
expectedNotDeleted []string
}{
{
name: "route",
groupIDs: []string{"grp-for-route"},
expectedReasons: []string{"route"},
},
{
name: "route with peer groups",
groupIDs: []string{"grp-for-route2"},
expectedReasons: []string{"route"},
},
{
name: "name server groups",
groupIDs: []string{"grp-for-name-server-grp"},
expectedReasons: []string{"name server groups"},
},
{
name: "policy",
groupIDs: []string{"grp-for-policies"},
expectedReasons: []string{"policy"},
},
{
name: "setup keys",
groupIDs: []string{"grp-for-keys"},
expectedReasons: []string{"setup key"},
},
{
name: "users",
groupIDs: []string{"grp-for-users"},
expectedReasons: []string{"user"},
},
{
name: "integration",
groupIDs: []string{"grp-for-integration"},
expectedReasons: []string{"only service users with admin power can delete integration group"},
},
{
name: "successfully delete multiple groups",
groupIDs: []string{"group-1", "group-2"},
expectedDeleted: []string{"group-1", "group-2"},
},
{
name: "delete non-existent group",
groupIDs: []string{"non-existent-group"},
expectedDeleted: []string{"non-existent-group"},
},
{
name: "delete multiple groups with mixed results",
groupIDs: []string{"group-3", "grp-for-policies", "group-4", "grp-for-users"},
expectedReasons: []string{"policy", "user"},
expectedDeleted: []string{"group-3", "group-4"},
expectedNotDeleted: []string{"grp-for-policies", "grp-for-users"},
},
{
name: "delete groups with multiple errors",
groupIDs: []string{"grp-for-policies", "grp-for-users"},
expectedReasons: []string{"policy", "user"},
expectedNotDeleted: []string{"grp-for-policies", "grp-for-users"},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err = am.DeleteGroups(context.Background(), account.Id, groupAdminUserID, tc.groupIDs)
if len(tc.expectedReasons) > 0 {
assert.Error(t, err)
var foundExpectedErrors int
wrappedErr, ok := err.(interface{ Unwrap() []error })
assert.Equal(t, ok, true)
for _, e := range wrappedErr.Unwrap() {
var sErr *status.Error
if errors.As(e, &sErr) {
assert.Contains(t, tc.expectedReasons, sErr.Message, "unexpected error message")
foundExpectedErrors++
}
var gErr *GroupLinkError
if errors.As(e, &gErr) {
assert.Contains(t, tc.expectedReasons, gErr.Resource, "unexpected error resource")
foundExpectedErrors++
}
}
assert.Equal(t, len(tc.expectedReasons), foundExpectedErrors, "not all expected errors were found")
} else {
assert.NoError(t, err)
}
for _, groupID := range tc.expectedDeleted {
_, err := am.GetGroup(context.Background(), account.Id, groupID, groupAdminUserID)
assert.Error(t, err, "group should have been deleted: %s", groupID)
}
for _, groupID := range tc.expectedNotDeleted {
group, err := am.GetGroup(context.Background(), account.Id, groupID, groupAdminUserID)
assert.NoError(t, err, "group should not have been deleted: %s", groupID)
assert.NotNil(t, group, "group should exist: %s", groupID)
}
})
}
}
func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *Account, error) {
func initTestGroupAccount(am *DefaultAccountManager) (*Account, error) {
accountID := "testingAcc"
domain := "example.com"
@@ -367,7 +236,7 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *A
err := am.Store.SaveAccount(context.Background(), account)
if err != nil {
return nil, nil, err
return nil, err
}
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForRoute)
@@ -378,9 +247,5 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *A
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForUsers)
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForIntegration)
acc, err := am.Store.GetAccount(context.Background(), account.Id)
if err != nil {
return nil, nil, err
}
return am, acc, nil
return am.Store.GetAccount(context.Background(), account.Id)
}

View File

@@ -156,7 +156,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
log.WithContext(ctx).Tracef("peer system meta has to be provided on sync. Peer %s, remote addr %s", peerKey.String(), realIP)
}
peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), extractPeerMeta(ctx, syncReq.GetMeta()), realIP)
peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, peerKey.String(), extractPeerMeta(ctx, syncReq.GetMeta()), realIP)
if err != nil {
return mapError(ctx, err)
}
@@ -179,11 +179,11 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart))
}
return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv)
return s.handleUpdates(ctx, peerKey, peer, updates, srv)
}
// handleUpdates sends updates to the connected peer until the updates channel is closed.
func (s *GRPCServer) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *UpdateMessage, srv proto.ManagementService_SyncServer) error {
func (s *GRPCServer) handleUpdates(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *UpdateMessage, srv proto.ManagementService_SyncServer) error {
for {
select {
// condition when there are some updates
@@ -194,12 +194,12 @@ func (s *GRPCServer) handleUpdates(ctx context.Context, accountID string, peerKe
if !open {
log.WithContext(ctx).Debugf("updates channel for peer %s was closed", peerKey.String())
s.cancelPeerRoutines(ctx, accountID, peer)
s.cancelPeerRoutines(ctx, peer)
return nil
}
log.WithContext(ctx).Debugf("received an update for peer %s", peerKey.String())
if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv); err != nil {
if err := s.sendUpdate(ctx, peerKey, peer, update, srv); err != nil {
return err
}
@@ -207,7 +207,7 @@ func (s *GRPCServer) handleUpdates(ctx context.Context, accountID string, peerKe
case <-srv.Context().Done():
// happens when connection drops, e.g. client disconnects
log.WithContext(ctx).Debugf("stream of peer %s has been closed", peerKey.String())
s.cancelPeerRoutines(ctx, accountID, peer)
s.cancelPeerRoutines(ctx, peer)
return srv.Context().Err()
}
}
@@ -215,10 +215,10 @@ func (s *GRPCServer) handleUpdates(ctx context.Context, accountID string, peerKe
// sendUpdate encrypts the update message using the peer key and the server's wireguard key,
// then sends the encrypted message to the connected peer via the sync server.
func (s *GRPCServer) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, update *UpdateMessage, srv proto.ManagementService_SyncServer) error {
func (s *GRPCServer) sendUpdate(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, update *UpdateMessage, srv proto.ManagementService_SyncServer) error {
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, update.Update)
if err != nil {
s.cancelPeerRoutines(ctx, accountID, peer)
s.cancelPeerRoutines(ctx, peer)
return status.Errorf(codes.Internal, "failed processing update message")
}
err = srv.SendMsg(&proto.EncryptedMessage{
@@ -226,17 +226,17 @@ func (s *GRPCServer) sendUpdate(ctx context.Context, accountID string, peerKey w
Body: encryptedResp,
})
if err != nil {
s.cancelPeerRoutines(ctx, accountID, peer)
s.cancelPeerRoutines(ctx, peer)
return status.Errorf(codes.Internal, "failed sending update message")
}
log.WithContext(ctx).Debugf("sent an update to peer %s", peerKey.String())
return nil
}
func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer) {
func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, peer *nbpeer.Peer) {
s.peersUpdateManager.CloseChannel(ctx, peer.ID)
s.turnCredentialsManager.CancelRefresh(peer.ID)
_ = s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key)
_ = s.accountManager.CancelPeerRoutines(ctx, peer)
s.ephemeralManager.OnPeerDisconnected(ctx, peer)
}
@@ -533,46 +533,53 @@ func toPeerConfig(peer *nbpeer.Peer, network *Network, dnsName string) *proto.Pe
}
}
func toSyncResponse(ctx context.Context, config *Config, peer *nbpeer.Peer, turnCredentials *TURNCredentials, networkMap *NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache) *proto.SyncResponse {
response := &proto.SyncResponse{
WiretrusteeConfig: toWiretrusteeConfig(config, turnCredentials),
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName),
func toRemotePeerConfig(peers []*nbpeer.Peer, dnsName string) []*proto.RemotePeerConfig {
remotePeers := []*proto.RemotePeerConfig{}
for _, rPeer := range peers {
fqdn := rPeer.FQDN(dnsName)
remotePeers = append(remotePeers, &proto.RemotePeerConfig{
WgPubKey: rPeer.Key,
AllowedIps: []string{fmt.Sprintf(AllowedIPsFormat, rPeer.IP)},
SshConfig: &proto.SSHConfig{SshPubKey: []byte(rPeer.SSHKey)},
Fqdn: fqdn,
})
}
return remotePeers
}
func toSyncResponse(ctx context.Context, config *Config, peer *nbpeer.Peer, turnCredentials *TURNCredentials, networkMap *NetworkMap, dnsName string, checks []*posture.Checks) *proto.SyncResponse {
wtConfig := toWiretrusteeConfig(config, turnCredentials)
pConfig := toPeerConfig(peer, networkMap.Network, dnsName)
remotePeers := toRemotePeerConfig(networkMap.Peers, dnsName)
routesUpdate := toProtocolRoutes(networkMap.Routes)
dnsUpdate := toProtocolDNSConfig(networkMap.DNSConfig)
offlinePeers := toRemotePeerConfig(networkMap.OfflinePeers, dnsName)
firewallRules := toProtocolFirewallRules(networkMap.FirewallRules)
return &proto.SyncResponse{
WiretrusteeConfig: wtConfig,
PeerConfig: pConfig,
RemotePeers: remotePeers,
RemotePeersIsEmpty: len(remotePeers) == 0,
NetworkMap: &proto.NetworkMap{
Serial: networkMap.Network.CurrentSerial(),
Routes: toProtocolRoutes(networkMap.Routes),
DNSConfig: toProtocolDNSConfig(networkMap.DNSConfig, dnsCache),
Serial: networkMap.Network.CurrentSerial(),
PeerConfig: pConfig,
RemotePeers: remotePeers,
OfflinePeers: offlinePeers,
RemotePeersIsEmpty: len(remotePeers) == 0,
Routes: routesUpdate,
DNSConfig: dnsUpdate,
FirewallRules: firewallRules,
FirewallRulesIsEmpty: len(firewallRules) == 0,
},
Checks: toProtocolChecks(ctx, checks),
}
response.NetworkMap.PeerConfig = response.PeerConfig
allPeers := make([]*proto.RemotePeerConfig, 0, len(networkMap.Peers)+len(networkMap.OfflinePeers))
allPeers = appendRemotePeerConfig(allPeers, networkMap.Peers, dnsName)
response.RemotePeers = allPeers
response.NetworkMap.RemotePeers = allPeers
response.RemotePeersIsEmpty = len(allPeers) == 0
response.NetworkMap.RemotePeersIsEmpty = response.RemotePeersIsEmpty
response.NetworkMap.OfflinePeers = appendRemotePeerConfig(nil, networkMap.OfflinePeers, dnsName)
firewallRules := toProtocolFirewallRules(networkMap.FirewallRules)
response.NetworkMap.FirewallRules = firewallRules
response.NetworkMap.FirewallRulesIsEmpty = len(firewallRules) == 0
return response
}
func appendRemotePeerConfig(dst []*proto.RemotePeerConfig, peers []*nbpeer.Peer, dnsName string) []*proto.RemotePeerConfig {
for _, rPeer := range peers {
dst = append(dst, &proto.RemotePeerConfig{
WgPubKey: rPeer.Key,
AllowedIps: []string{rPeer.IP.String() + "/32"},
SshConfig: &proto.SSHConfig{SshPubKey: []byte(rPeer.SSHKey)},
Fqdn: rPeer.FQDN(dnsName),
})
}
return dst
}
// IsHealthy indicates whether the service is healthy
@@ -590,7 +597,7 @@ func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, p
} else {
turnCredentials = nil
}
plainResp := toSyncResponse(ctx, s.config, peer, turnCredentials, networkMap, s.accountManager.GetDNSDomain(), postureChecks, nil)
plainResp := toSyncResponse(ctx, s.config, peer, turnCredentials, networkMap, s.accountManager.GetDNSDomain(), postureChecks)
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp)
if err != nil {

View File

@@ -526,43 +526,6 @@ components:
- revoked
- auto_groups
- usage_limit
CreateSetupKeyRequest:
type: object
properties:
name:
description: Setup Key name
type: string
example: Default key
type:
description: Setup key type, one-off for single time usage and reusable
type: string
example: reusable
expires_in:
description: Expiration time in seconds
type: integer
minimum: 86400
maximum: 31536000
example: 86400
auto_groups:
description: List of group IDs to auto-assign to peers registered with this key
type: array
items:
type: string
example: "ch8i4ug6lnn4g9hqv7m0"
usage_limit:
description: A number of times this key can be used. The value of 0 indicates the unlimited usage.
type: integer
example: 0
ephemeral:
description: Indicate that the peer will be ephemeral or not
type: boolean
example: true
required:
- name
- type
- expires_in
- auto_groups
- usage_limit
PersonalAccessToken:
type: object
properties:
@@ -1843,7 +1806,7 @@ paths:
content:
'application/json':
schema:
$ref: '#/components/schemas/CreateSetupKeyRequest'
$ref: '#/components/schemas/SetupKeyRequest'
responses:
'200':
description: A Setup Keys Object

Some files were not shown because too many files have changed in this diff Show More