Compare commits

...

70 Commits

Author SHA1 Message Date
Zoltan Papp
38f2a59d1b Add comment 2024-06-12 10:56:21 +02:00
Zoltan Papp
9504012920 Set the proper buffer size in the client code 2024-06-09 21:10:57 +02:00
Zoltan Papp
5e93d117cf Use buf pool
- eliminate reader function generation
- fix write to closed channel panic
2024-06-09 20:33:35 +02:00
Zoltan Papp
8c70b7d7ff Replace ws lib on client side 2024-06-09 12:41:52 +02:00
Zoltan Papp
ed8def4d9b Protect ws writing in Gorilla ws 2024-06-07 16:07:35 +02:00
Zoltán Papp
1e115e3893 Merge branch 'main' into feature/relay 2024-06-06 13:38:40 +02:00
Viktor Liu
deffe037aa Respect env for debug and routes sub commands (#2026) 2024-06-06 10:59:10 +02:00
Zoltan Papp
fed9e587af Add close message type 2024-06-05 19:49:30 +02:00
Zoltan Papp
983d7bafbe Remove unused variables from peer conn (#2074)
Remove unused variables from peer conn
2024-06-04 17:04:50 +02:00
Zoltan Papp
a40d4d2f32 - add comments
- avoid double closing messages
- add cleanup routine for relay manager
2024-06-04 14:40:35 +02:00
Gabriel Górski
4da29451d0 Add missing openid scope when requesting JWT token (#2089)
According to the Zitadel documentation, `openid` scope is required
when requesting JWT tokens.

Apparently Zitadel was accepting requests without it until very
recently. Now lack thereof causes 400 Bad Requests which makes it
impossible to authenticate to the Netbird dashboard.

https://zitadel.com/docs/guides/integrate/service-users/client-credentials#2-authenticating-a-service-user-and-request-a-token
2024-06-04 10:46:24 +02:00
Zoltán Papp
15818b72c6 Add alternative ws server implementation 2024-06-03 21:38:37 +02:00
Zoltán Papp
0556dc1860 Avoid nil pointer exception in test in case of err 2024-06-03 21:36:46 +02:00
Zoltán Papp
2b369cd28f Add quic transporter 2024-06-03 20:17:43 +02:00
Zoltán Papp
9d44a476c6 Fix double unlock in client.go 2024-06-03 20:14:39 +02:00
Viktor Liu
9b3449753e Ignore candidates whose IP falls into a routed network. (#2084)
This will prevent peer connections via other peers.
2024-06-03 17:31:37 +02:00
Maycon Santos
456629811b Prevent using expired ctx when sending metrics (#2088) 2024-06-03 12:41:15 +02:00
Zoltán Papp
57ddb5f262 Add comment 2024-06-03 11:22:16 +02:00
Zoltan Papp
4ced07dd8d Fix close conn threading issue 2024-06-03 01:37:56 +02:00
Zoltán Papp
3430b81622 Add relay server tracking 2024-06-01 11:48:15 +02:00
Zoltán Papp
fd4ad15c83 Move reconnection logic to separated struct 2024-06-01 11:25:00 +02:00
Zoltan Papp
c311d0d19e Fill the UI version info in system meta on Android (#2077) 2024-05-31 17:26:56 +02:00
pascal-fischer
521f7dd39f Improve login performance (#2061) 2024-05-31 16:41:12 +02:00
pascal-fischer
f9ec0a9a2e Fix PKCE auth html (#2079) 2024-05-30 17:22:58 +02:00
pascal-fischer
012235ff12 Add FindExistingPostureCheck (#2075) 2024-05-30 15:22:42 +02:00
Zoltán Papp
4ff069a102 Support multiple server 2024-05-29 16:40:26 +02:00
Zoltán Papp
7cc3964a4d Use mux for http server
Without it can not start multiple http
server instances for unit tests
2024-05-29 16:11:58 +02:00
Zoltan Papp
6d627f1923 Code cleaning 2024-05-28 01:27:53 +02:00
Zoltan Papp
076ce69a24 Add reconnect logic 2024-05-28 01:00:25 +02:00
Maycon Santos
f176807ebe Add extra logs for account not found, peer login and getAccount (#2053) 2024-05-27 12:29:28 +02:00
Maycon Santos
d4c47eaf8a Don't allow delete group from peer groups (#2055) 2024-05-27 11:06:43 +02:00
Zoltán Papp
645a1f31a7 Fix writing/reading to a closed conn 2024-05-27 10:25:08 +02:00
Zoltán Papp
b4aa7e50f9 Close sockets on server cmd 2024-05-27 09:42:27 +02:00
Bethuel Mmbaga
d35a79d3b5 Upgrade gRPC and OpenTelemetry packages for compatibility (#2003)
Upgrades `go.opentelemetry.io/otel` from version` v1.11.1` to `v1.26.0`. The upgrade addresses compatibility issues caused by the removal of several sub-packages in the latest OpenTelemetry release, which were causing broken dependencies.

**Key Changes:**
- Upgraded `go.opentelemetry.io/otel` from `v1.11.1` to `v1.26.0`.

- Fixed broken dependencies by replacing the deprecated sub-packages:
  - `go.opentelemetry.io/otel/metric/instrument`
  - `go.opentelemetry.io/otel/metric/instrument/asyncint64`
  - `go.opentelemetry.io/otel/metric/instrument/syncint64`
  
- Upgraded `google.golang.org/grpc` from `v1.56.3`  to `v1.64.0` which deprecate `Dial` and `DialContext` to `NewClient`.
2024-05-27 08:39:18 +02:00
Maycon Santos
6a2929011d Refactor firewall manager check (#2054)
Some systems don't play nice with a test chain
So we dropped the idea, and instead we check for the filter table

With this check, we might face a case where iptables is selected once and on the 
next netbird up/down it will go back to using nftables
2024-05-27 08:37:32 +02:00
Zoltán Papp
173ca25dac Fix in client the close event 2024-05-26 22:14:33 +02:00
Maycon Santos
e877c9d6c1 Update CODE_OF_CONDUCT.md (#2048) 2024-05-24 17:29:14 +02:00
Maycon Santos
7a1c96ebf4 Remove extra error mapping (#2050) 2024-05-24 14:46:11 +02:00
Zoltan Papp
41fe9f84ec Extend integrated validator with error handling (#2044) 2024-05-24 13:29:25 +02:00
Viktor Liu
d13fb0e379 Restore netbird state and log level after debug (#2047) 2024-05-24 13:27:41 +02:00
Maycon Santos
f3214527ea Use info log-level for firewall manager discover (#2045)
* Use info log-level for firewall manager discover

* Update client/firewall/create_linux.go

Co-authored-by: Viktor Liu <17948409+lixmal@users.noreply.github.com>

---------

Co-authored-by: Viktor Liu <17948409+lixmal@users.noreply.github.com>
2024-05-24 13:03:19 +02:00
Maycon Santos
69048bfd34 Revert "Accept any XDG_ environment variable to determine desktop (#2037)" (#2042)
This reverts commit 67e2185964.
2024-05-23 23:15:02 +02:00
Maycon Santos
29a2d93873 Log global lock acquisition per user (#2039) 2024-05-23 17:09:58 +02:00
Maycon Santos
6b01b0020e Enhance firewall manager checks to detect unsupported iptables (#2038)
Our nftables firewall manager may cause issues when rules are created using older iptable versions
2024-05-23 16:09:51 +02:00
Maycon Santos
9d3db68805 Return the proper error when a peer is deleted (#2035)
this fixes an issue causing peers to keep retrying the connection after a peer is removed from the management system
2024-05-23 14:59:09 +02:00
Maycon Santos
2e315311e0 Fix the initial daemon retry interval (#2036) 2024-05-23 14:52:52 +02:00
Zoltán Papp
36b2cd16cc Remove channel binding logic 2024-05-23 13:24:02 +02:00
Maycon Santos
67e2185964 Accept any XDG_ environment variable to determine desktop (#2037) 2024-05-23 12:34:19 +02:00
Maycon Santos
89149dc6f4 Increase the status checks timeout (#2033)
Some systems might respond with a small delay depending on various factors. Increasing the timeout to reduce the number of false-positive reports
2024-05-23 10:54:01 +02:00
Matthew R Kasun
5a1f8f13a2 use the next available port for wireguard (#2024)
check if WgPort is available, if not find the next free port
2024-05-22 18:42:56 +02:00
Viktor Liu
e71059d245 Add dummy ipv6 to macos interface (#2025) 2024-05-22 12:32:01 +02:00
Maycon Santos
91fa2e20a0 Store location information in peer event meta (#1994) 2024-05-22 12:31:16 +02:00
Zoltan Papp
61034aaf4d Gracefully conn worker shutdown (#2022)
Because the connWorker are operating with the e.peerConns list we must ensure all workers exited before we modify the content of the e.peerConns list.
If we do not do that the engine will start new connWorkers for the exists ones, and they start connection for the same peers in parallel.
2024-05-22 11:15:29 +02:00
Zoltán Papp
0a05f8b4d4 Use buffer pool and protect exported functions 2024-05-22 00:38:41 +02:00
Zoltán Papp
e82c0a55a3 Set to blocking the message queue 2024-05-21 16:21:29 +02:00
Zoltán Papp
13eb457132 Add registration response message to the communication 2024-05-21 15:51:37 +02:00
Maycon Santos
b8717b8956 Update the GUI status when daemon unavailable (#2012)
in case we got no status we mark the GUI app as disconnected
2024-05-21 15:45:49 +02:00
Zoltan Papp
1c9c9ae47e Remove sync.pool 2024-05-20 11:38:23 +02:00
Zoltan Papp
9ac5a1ed3f Add udp listener and did some change for debug purpose. 2024-05-19 12:41:06 +02:00
Zoltan Papp
d4eaec5cbd Followup messages modification 2024-05-17 23:41:47 +02:00
Zoltan Papp
6ae7a790f2 Fix buffer handling 2024-05-17 23:29:47 +02:00
Zoltan Papp
49dfbc82d9 Add relay cmd 2024-05-17 20:24:06 +02:00
Zoltan Papp
57a89cf0cc Add initial relay code 2024-05-17 17:43:28 +02:00
pascal-fischer
50201d63c2 Increase garbage collection on ios (#1981) 2024-05-17 15:58:29 +02:00
pascal-fischer
d11b39282b Enable namserver deactivation if unresponsive on iOS (#1982) 2024-05-17 12:59:46 +02:00
Viktor Liu
bd58eea8ea Refactor network monitor to wait for stop (#1992) 2024-05-17 09:43:18 +02:00
Bethuel Mmbaga
a5811a2d7d Implement experimental PostgreSQL store (#1939)
* migrate sqlite store to
 generic sql store

* fix conflicts

* init postgres store

* Add postgres store tests

* Refactor postgres store engine name

* fix tests

* Run postgres store tests on linux only

* fix tests

* Refactor

* cascade policy rules on policy deletion

* fix tests

* run postgres cases in new db

* close store connection after tests

* refactor

* using testcontainers

* sync go sum

* remove postgres service

* remove store cleanup

* go mod tidy

* remove env

* use postgres as engine and initialize test store with testcontainer

---------

Co-authored-by: Maycon Santos <mlsmaycon@gmail.com>
2024-05-16 19:28:37 +03:00
Bethuel Mmbaga
a680f80ed9 Add installer support for Synology (#1984)
* add installer support for the synology

* skip ui installation for Synology

* Fix conflicts
2024-05-15 19:03:49 +03:00
Thorleif Jacobsen
10fbdc2c4a CentOS installations might have "apt" as "annotation processing tool", fixed so it checks for apt-get (#1955) 2024-05-15 16:33:12 +02:00
Viktor Liu
1444fbe104 Don't cancel proxy ctx on conn close (#1986) 2024-05-15 09:10:57 +02:00
111 changed files with 5498 additions and 1347 deletions

View File

@@ -15,7 +15,7 @@ jobs:
strategy:
matrix:
arch: [ '386','amd64' ]
store: [ 'jsonfile', 'sqlite' ]
store: [ 'jsonfile', 'sqlite', 'postgres']
runs-on: ubuntu-latest
steps:
- name: Install Go

View File

@@ -130,3 +130,10 @@ issues:
- path: mock\.go
linters:
- nilnil
# Exclude specific deprecation warnings for grpc methods
- linters:
- staticcheck
text: "grpc.DialContext is deprecated"
- linters:
- staticcheck
text: "grpc.WithBlock is deprecated"

View File

@@ -5,7 +5,7 @@
We as members, contributors, and leaders pledge to make participation in our
community a harassment-free experience for everyone, regardless of age, body
size, visible or invisible disability, ethnicity, sex characteristics, gender
identity and expression, level of experience, education, socio-economic status,
identity and expression, level of experience, education, socioeconomic status,
nationality, personal appearance, race, caste, color, religion, or sexual
identity and orientation.

View File

@@ -57,15 +57,17 @@ type Client struct {
ctxCancel context.CancelFunc
ctxCancelLock *sync.Mutex
deviceName string
uiVersion string
networkChangeListener listener.NetworkChangeListener
}
// NewClient instantiate a new Client
func NewClient(cfgFile, deviceName string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client {
func NewClient(cfgFile, deviceName string, uiVersion string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client {
net.SetAndroidProtectSocketFn(tunAdapter.ProtectSocket)
return &Client{
cfgFile: cfgFile,
deviceName: deviceName,
uiVersion: uiVersion,
tunAdapter: tunAdapter,
iFaceDiscover: iFaceDiscover,
recorder: peer.NewRecorder(""),
@@ -88,6 +90,9 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead
var ctx context.Context
//nolint
ctxWithValues := context.WithValue(context.Background(), system.DeviceNameCtxKey, c.deviceName)
//nolint
ctxWithValues = context.WithValue(ctxWithValues, system.UiVersionCtxKey, c.uiVersion)
c.ctxCancelLock.Lock()
ctx, c.ctxCancel = context.WithCancel(ctxWithValues)
defer c.ctxCancel()

View File

@@ -3,13 +3,14 @@ package cmd
import (
"context"
"fmt"
"strings"
"time"
"github.com/spf13/cobra"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/server"
)
var debugCmd = &cobra.Command{
@@ -58,7 +59,7 @@ var forCmd = &cobra.Command{
}
func debugBundle(cmd *cobra.Command, _ []string) error {
conn, err := getClient(cmd.Context())
conn, err := getClient(cmd)
if err != nil {
return err
}
@@ -79,14 +80,14 @@ func debugBundle(cmd *cobra.Command, _ []string) error {
}
func setLogLevel(cmd *cobra.Command, args []string) error {
conn, err := getClient(cmd.Context())
conn, err := getClient(cmd)
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
level := parseLogLevel(args[0])
level := server.ParseLogLevel(args[0])
if level == proto.LogLevel_UNKNOWN {
return fmt.Errorf("unknown log level: %s. Available levels are: panic, fatal, error, warn, info, debug, trace\n", args[0])
}
@@ -102,34 +103,13 @@ func setLogLevel(cmd *cobra.Command, args []string) error {
return nil
}
func parseLogLevel(level string) proto.LogLevel {
switch strings.ToLower(level) {
case "panic":
return proto.LogLevel_PANIC
case "fatal":
return proto.LogLevel_FATAL
case "error":
return proto.LogLevel_ERROR
case "warn":
return proto.LogLevel_WARN
case "info":
return proto.LogLevel_INFO
case "debug":
return proto.LogLevel_DEBUG
case "trace":
return proto.LogLevel_TRACE
default:
return proto.LogLevel_UNKNOWN
}
}
func runForDuration(cmd *cobra.Command, args []string) error {
duration, err := time.ParseDuration(args[0])
if err != nil {
return fmt.Errorf("invalid duration format: %v", err)
}
conn, err := getClient(cmd.Context())
conn, err := getClient(cmd)
if err != nil {
return err
}
@@ -137,18 +117,33 @@ func runForDuration(cmd *cobra.Command, args []string) error {
client := proto.NewDaemonServiceClient(conn)
stat, err := client.Status(cmd.Context(), &proto.StatusRequest{})
if err != nil {
return fmt.Errorf("failed to get status: %v", status.Convert(err).Message())
}
restoreUp := stat.Status == string(internal.StatusConnected) || stat.Status == string(internal.StatusConnecting)
initialLogLevel, err := client.GetLogLevel(cmd.Context(), &proto.GetLogLevelRequest{})
if err != nil {
return fmt.Errorf("failed to get log level: %v", status.Convert(err).Message())
}
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
}
cmd.Println("Netbird down")
_, err = client.SetLogLevel(cmd.Context(), &proto.SetLogLevelRequest{
Level: proto.LogLevel_TRACE,
})
if err != nil {
return fmt.Errorf("failed to set log level to trace: %v", status.Convert(err).Message())
initialLevelTrace := initialLogLevel.GetLevel() >= proto.LogLevel_TRACE
if !initialLevelTrace {
_, err = client.SetLogLevel(cmd.Context(), &proto.SetLogLevelRequest{
Level: proto.LogLevel_TRACE,
})
if err != nil {
return fmt.Errorf("failed to set log level to TRACE: %v", status.Convert(err).Message())
}
cmd.Println("Log level set to trace.")
}
cmd.Println("Log level set to trace.")
time.Sleep(1 * time.Second)
@@ -175,10 +170,22 @@ func runForDuration(cmd *cobra.Command, args []string) error {
}
cmd.Println("Netbird down")
// TODO reset log level
time.Sleep(1 * time.Second)
if restoreUp {
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
return fmt.Errorf("failed to up: %v", status.Convert(err).Message())
}
cmd.Println("Netbird up")
}
if !initialLevelTrace {
if _, err := client.SetLogLevel(cmd.Context(), &proto.SetLogLevelRequest{Level: initialLogLevel.GetLevel()}); err != nil {
return fmt.Errorf("failed to restore log level: %v", status.Convert(err).Message())
}
cmd.Println("Log level restored to", initialLogLevel.GetLevel())
}
cmd.Println("Creating debug bundle...")
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{

View File

@@ -2,9 +2,10 @@ package cmd
import (
"context"
"github.com/netbirdio/netbird/util"
"time"
"github.com/netbirdio/netbird/util"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"

View File

@@ -353,8 +353,11 @@ func migrateToNetbird(oldPath, newPath string) bool {
return true
}
func getClient(ctx context.Context) (*grpc.ClientConn, error) {
conn, err := DialClientGRPCServer(ctx, daemonAddr)
func getClient(cmd *cobra.Command) (*grpc.ClientConn, error) {
SetFlagsFromEnvVars(rootCmd)
cmd.SetOut(cmd.OutOrStdout())
conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr)
if err != nil {
return nil, fmt.Errorf("failed to connect to daemon error: %v\n"+
"If the daemon is not running please run: "+

View File

@@ -49,7 +49,7 @@ func init() {
}
func routesList(cmd *cobra.Command, _ []string) error {
conn, err := getClient(cmd.Context())
conn, err := getClient(cmd)
if err != nil {
return err
}
@@ -79,7 +79,7 @@ func routesList(cmd *cobra.Command, _ []string) error {
}
func routesSelect(cmd *cobra.Command, args []string) error {
conn, err := getClient(cmd.Context())
conn, err := getClient(cmd)
if err != nil {
return err
}
@@ -106,7 +106,7 @@ func routesSelect(cmd *cobra.Command, args []string) error {
}
func routesDeselect(cmd *cobra.Command, args []string) error {
conn, err := getClient(cmd.Context())
conn, err := getClient(cmd)
if err != nil {
return err
}

View File

@@ -14,6 +14,7 @@ import (
"google.golang.org/grpc"
"github.com/netbirdio/management-integrations/integrations"
clientProto "github.com/netbirdio/netbird/client/proto"
client "github.com/netbirdio/netbird/client/server"
mgmtProto "github.com/netbirdio/netbird/management/proto"
@@ -69,10 +70,11 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
t.Fatal(err)
}
s := grpc.NewServer()
store, err := mgmt.NewStoreFromJson(config.Datadir, nil)
store, cleanUp, err := mgmt.NewTestStoreFromJson(config.Datadir)
if err != nil {
t.Fatal(err)
}
t.Cleanup(cleanUp)
peersUpdateManager := mgmt.NewPeersUpdateManager(nil)
eventStore := &activity.InMemoryEventStore{}

View File

@@ -42,20 +42,20 @@ func NewFirewall(context context.Context, iface IFaceMapper) (firewall.Manager,
switch check() {
case IPTABLES:
log.Debug("creating an iptables firewall manager")
log.Info("creating an iptables firewall manager")
fm, errFw = nbiptables.Create(context, iface)
if errFw != nil {
log.Errorf("failed to create iptables manager: %s", errFw)
}
case NFTABLES:
log.Debug("creating an nftables firewall manager")
log.Info("creating an nftables firewall manager")
fm, errFw = nbnftables.Create(context, iface)
if errFw != nil {
log.Errorf("failed to create nftables manager: %s", errFw)
}
default:
errFw = fmt.Errorf("no firewall manager found")
log.Debug("no firewall manager found, try to use userspace packet filtering firewall")
log.Info("no firewall manager found, trying to use userspace packet filtering firewall")
}
if iface.IsUserspaceBind() {
@@ -85,16 +85,58 @@ func NewFirewall(context context.Context, iface IFaceMapper) (firewall.Manager,
// check returns the firewall type based on common lib checks. It returns UNKNOWN if no firewall is found.
func check() FWType {
nf := nftables.Conn{}
if _, err := nf.ListChains(); err == nil && os.Getenv(SKIP_NFTABLES_ENV) != "true" {
return NFTABLES
useIPTABLES := false
var iptablesChains []string
ip, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
if err == nil && isIptablesClientAvailable(ip) {
major, minor, _ := ip.GetIptablesVersion()
// use iptables when its version is lower than 1.8.0 which doesn't work well with our nftables manager
if major < 1 || (major == 1 && minor < 8) {
return IPTABLES
}
useIPTABLES = true
iptablesChains, err = ip.ListChains("filter")
if err != nil {
log.Errorf("failed to list iptables chains: %s", err)
useIPTABLES = false
}
}
ip, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
if err != nil {
return UNKNOWN
nf := nftables.Conn{}
if chains, err := nf.ListChains(); err == nil && os.Getenv(SKIP_NFTABLES_ENV) != "true" {
if !useIPTABLES {
return NFTABLES
}
// search for chains where table is filter
// if we find one, we assume that nftables manager can be used with iptables
for _, chain := range chains {
if chain.Table.Name == "filter" {
return NFTABLES
}
}
// check tables for the following constraints:
// 1. there is no chain in nftables for the filter table and there is at least one chain in iptables, we assume that nftables manager can not be used
// 2. there is no tables or more than one table, we assume that nftables manager can be used
// 3. there is only one table and its name is filter, we assume that nftables manager can not be used, since there was no chain in it
// 4. if we find an error we log and continue with iptables check
nbTablesList, err := nf.ListTables()
switch {
case err == nil && len(iptablesChains) > 0:
return IPTABLES
case err == nil && len(nbTablesList) != 1:
return NFTABLES
case err == nil && len(nbTablesList) == 1 && nbTablesList[0].Name == "filter":
return IPTABLES
case err != nil:
log.Errorf("failed to list nftables tables on fw manager discovery: %s", err)
}
}
if isIptablesClientAvailable(ip) {
if useIPTABLES {
return IPTABLES
}

View File

@@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"net"
"runtime"
"runtime/debug"
"strings"
@@ -91,6 +92,9 @@ func (c *ConnectClient) RunOniOS(
networkChangeListener listener.NetworkChangeListener,
dnsManager dns.IosDnsManager,
) error {
// Set GC percent to 5% to reduce memory usage as iOS only allows 50MB of memory for the extension.
debug.SetGCPercent(5)
mobileDependency := MobileDependency{
FileDescriptor: fileDescriptor,
NetworkChangeListener: networkChangeListener,
@@ -327,6 +331,15 @@ func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.Pe
engineConf.PreSharedKey = &preSharedKey
}
port, err := freePort(config.WgPort)
if err != nil {
return nil, err
}
if port != config.WgPort {
log.Infof("using %d as wireguard port: %d is in use", port, config.WgPort)
}
engineConf.WgPort = port
return engineConf, nil
}
@@ -376,3 +389,20 @@ func statusRecorderToSignalConnStateNotifier(statusRecorder *peer.Status) signal
notifier, _ := sri.(signal.ConnStateNotifier)
return notifier
}
func freePort(start int) (int, error) {
addr := net.UDPAddr{}
if start == 0 {
start = iface.DefaultWgPort
}
for x := start; x <= 65535; x++ {
addr.Port = x
conn, err := net.ListenUDP("udp", &addr)
if err != nil {
continue
}
conn.Close()
return x, nil
}
return 0, errors.New("no free ports")
}

View File

@@ -0,0 +1,57 @@
package internal
import (
"net"
"testing"
)
func Test_freePort(t *testing.T) {
tests := []struct {
name string
port int
want int
wantErr bool
}{
{
name: "available",
port: 51820,
want: 51820,
wantErr: false,
},
{
name: "notavailable",
port: 51830,
want: 51831,
wantErr: false,
},
{
name: "noports",
port: 65535,
want: 0,
wantErr: true,
},
}
for _, tt := range tests {
c1, err := net.ListenUDP("udp", &net.UDPAddr{Port: 51830})
if err != nil {
t.Errorf("freePort error = %v", err)
}
c2, err := net.ListenUDP("udp", &net.UDPAddr{Port: 65535})
if err != nil {
t.Errorf("freePort error = %v", err)
}
t.Run(tt.name, func(t *testing.T) {
got, err := freePort(tt.port)
if (err != nil) != tt.wantErr {
t.Errorf("freePort() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("freePort() = %v, want %v", got, tt.want)
}
})
c1.Close()
c2.Close()
}
}

View File

@@ -5,7 +5,6 @@ import (
"errors"
"fmt"
"net"
"runtime"
"sync"
"sync/atomic"
"time"
@@ -260,13 +259,10 @@ func (u *upstreamResolverBase) disable(err error) {
return
}
// todo test the deactivation logic, it seems to affect the client
if runtime.GOOS != "ios" {
log.Warnf("Upstream resolving is Disabled for %v", reactivatePeriod)
u.deactivate(err)
u.disabled = true
go u.waitUntilResponse()
}
log.Warnf("Upstream resolving is Disabled for %v", reactivatePeriod)
u.deactivate(err)
u.disabled = true
go u.waitUntilResponse()
}
func (u *upstreamResolverBase) testNameserver(server string) error {

View File

@@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"maps"
"math/rand"
"net"
"net/netip"
@@ -117,7 +118,8 @@ type Engine struct {
TURNs []*stun.URI
// clientRoutes is the most recent list of clientRoutes received from the Management Service
clientRoutes route.HAMap
clientRoutes route.HAMap
clientRoutesMu sync.RWMutex
clientCtx context.Context
clientCancel context.CancelFunc
@@ -133,7 +135,7 @@ type Engine struct {
// networkSerial is the latest CurrentSerial (state ID) of the network sent by the Management service
networkSerial uint64
networkWatcher *networkmonitor.NetworkWatcher
networkMonitor *networkmonitor.NetworkMonitor
sshServerFunc func(hostKeyPEM []byte, addr string) (nbssh.Server, error)
sshServer nbssh.Server
@@ -150,6 +152,8 @@ type Engine struct {
signalProbe *Probe
relayProbe *Probe
wgProbe *Probe
wgConnWorker sync.WaitGroup
}
// Peer is an instance of the Connection Peer
@@ -212,7 +216,6 @@ func NewEngineWithProbes(
networkSerial: 0,
sshServerFunc: nbssh.DefaultSSHServer,
statusRecorder: statusRecorder,
networkWatcher: networkmonitor.New(),
mgmProbe: mgmProbe,
signalProbe: signalProbe,
relayProbe: relayProbe,
@@ -229,20 +232,26 @@ func (e *Engine) Stop() error {
}
// stopping network monitor first to avoid starting the engine again
e.networkWatcher.Stop()
if e.networkMonitor != nil {
e.networkMonitor.Stop()
}
log.Info("Network monitor: stopped")
err := e.removeAllPeers()
if err != nil {
return err
}
e.clientRoutesMu.Lock()
e.clientRoutes = nil
e.clientRoutesMu.Unlock()
// very ugly but we want to remove peers from the WireGuard interface first before removing interface.
// Removing peers happens in the conn.Close() asynchronously
time.Sleep(500 * time.Millisecond)
e.close()
e.wgConnWorker.Wait()
log.Infof("stopped Netbird Engine")
return nil
}
@@ -259,7 +268,7 @@ func (e *Engine) Start() error {
}
e.ctx, e.cancel = context.WithCancel(e.clientCtx)
e.wgProxyFactory = wgproxy.NewFactory(e.clientCtx, e.config.WgPort)
e.wgProxyFactory = wgproxy.NewFactory(e.ctx, e.config.WgPort)
wgIface, err := e.newWgIface()
if err != nil {
@@ -344,20 +353,8 @@ func (e *Engine) Start() error {
e.receiveManagementEvents()
e.receiveProbeEvents()
if e.config.NetworkMonitor {
// starting network monitor at the very last to avoid disruptions
go e.networkWatcher.Start(e.ctx, func() {
log.Infof("Network monitor detected network change, restarting engine")
if err := e.Stop(); err != nil {
log.Errorf("Failed to stop engine: %v", err)
}
if err := e.Start(); err != nil {
log.Errorf("Failed to start engine: %v", err)
}
})
} else {
log.Infof("Network monitor is disabled, not starting")
}
// starting network monitor at the very last to avoid disruptions
e.startNetworkMonitor()
return nil
}
@@ -745,7 +742,9 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
log.Errorf("failed to update clientRoutes, err: %v", err)
}
e.clientRoutesMu.Lock()
e.clientRoutes = clientRoutes
e.clientRoutesMu.Unlock()
protoDNSConfig := networkMap.GetDNSConfig()
if protoDNSConfig == nil {
@@ -879,18 +878,25 @@ func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error {
log.Warnf("error adding peer %s to status recorder, got error: %v", peerKey, err)
}
e.wgConnWorker.Add(1)
go e.connWorker(conn, peerKey)
}
return nil
}
func (e *Engine) connWorker(conn *peer.Conn, peerKey string) {
defer e.wgConnWorker.Done()
for {
// randomize starting time a bit
min := 500
max := 2000
time.Sleep(time.Duration(rand.Intn(max-min)+min) * time.Millisecond)
duration := time.Duration(rand.Intn(max-min)+min) * time.Millisecond
select {
case <-e.ctx.Done():
return
case <-time.After(duration):
}
// if peer has been removed -> give up
if !e.peerExists(peerKey) {
@@ -977,7 +983,6 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, e
WgConfig: wgConfig,
LocalWgPort: e.config.WgPort,
NATExternalIPs: e.parseNATExternalIPMappings(),
UserspaceBind: e.wgInterface.IsUserspaceBind(),
RosenpassPubKey: e.getRosenpassPubKey(),
RosenpassAddr: e.getRosenpassAddr(),
}
@@ -1040,8 +1045,6 @@ func (e *Engine) receiveSignalEvents() {
return err
}
conn.RegisterProtoSupportMeta(msg.Body.GetFeaturesSupported())
var rosenpassPubKey []byte
rosenpassAddr := ""
if msg.GetBody().GetRosenpassConfig() != nil {
@@ -1064,8 +1067,6 @@ func (e *Engine) receiveSignalEvents() {
return err
}
conn.RegisterProtoSupportMeta(msg.GetBody().GetFeaturesSupported())
var rosenpassPubKey []byte
rosenpassAddr := ""
if msg.GetBody().GetRosenpassConfig() != nil {
@@ -1088,7 +1089,8 @@ func (e *Engine) receiveSignalEvents() {
log.Errorf("failed on parsing remote candidate %s -> %s", candidate, err)
return err
}
conn.OnRemoteCandidate(candidate)
conn.OnRemoteCandidate(candidate, e.GetClientRoutes())
case sProto.Body_MODE:
}
@@ -1282,11 +1284,17 @@ func (e *Engine) newDnsServer() ([]*route.Route, dns.Server, error) {
// GetClientRoutes returns the current routes from the route map
func (e *Engine) GetClientRoutes() route.HAMap {
return e.clientRoutes
e.clientRoutesMu.RLock()
defer e.clientRoutesMu.RUnlock()
return maps.Clone(e.clientRoutes)
}
// GetClientRoutesWithNetID returns the current routes from the route map, but the keys consist of the network ID only
func (e *Engine) GetClientRoutesWithNetID() map[route.NetID][]*route.Route {
e.clientRoutesMu.RLock()
defer e.clientRoutesMu.RUnlock()
routes := make(map[route.NetID][]*route.Route, len(e.clientRoutes))
for id, v := range e.clientRoutes {
routes[id.NetID()] = v
@@ -1399,3 +1407,26 @@ func (e *Engine) probeSTUNs() []relay.ProbeResult {
func (e *Engine) probeTURNs() []relay.ProbeResult {
return relay.ProbeAll(e.ctx, relay.ProbeTURN, e.TURNs)
}
func (e *Engine) startNetworkMonitor() {
if !e.config.NetworkMonitor {
log.Infof("Network monitor is disabled, not starting")
return
}
e.networkMonitor = networkmonitor.New()
go func() {
err := e.networkMonitor.Start(e.ctx, func() {
log.Infof("Network monitor detected network change, restarting engine")
if err := e.Stop(); err != nil {
log.Errorf("Failed to stop engine: %v", err)
}
if err := e.Start(); err != nil {
log.Errorf("Failed to start engine: %v", err)
}
})
if err != nil && !errors.Is(err, networkmonitor.ErrStopped) {
log.Errorf("Network monitor: %v", err)
}
}()
}

View File

@@ -229,6 +229,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
t.Fatal(err)
}
engine.udpMux = bind.NewUniversalUDPMuxDefault(bind.UniversalUDPMuxParams{UDPConn: conn})
engine.ctx = ctx
type testCase struct {
name string
@@ -408,6 +409,7 @@ func TestEngine_Sync(t *testing.T) {
WgPrivateKey: key,
WgPort: 33100,
}, MobileDependency{}, peer.NewRecorder("https://mgm"))
engine.ctx = ctx
engine.dnsServer = &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
@@ -566,6 +568,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
WgPrivateKey: key,
WgPort: 33100,
}, MobileDependency{}, peer.NewRecorder("https://mgm"))
engine.ctx = ctx
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
@@ -735,6 +738,8 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
WgPrivateKey: key,
WgPort: 33100,
}, MobileDependency{}, peer.NewRecorder("https://mgm"))
engine.ctx = ctx
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
@@ -1003,7 +1008,9 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
WgPort: wgPort,
}
return NewEngine(ctx, cancel, signalClient, mgmtClient, conf, MobileDependency{}, peer.NewRecorder("https://mgm")), nil
e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, conf, MobileDependency{}, peer.NewRecorder("https://mgm")), nil
e.ctx = ctx
return e, err
}
func startSignal() (*grpc.Server, string, error) {
@@ -1042,7 +1049,7 @@ func startManagement(dataDir string) (*grpc.Server, string, error) {
return nil, "", err
}
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
store, err := server.NewStoreFromJson(config.Datadir, nil)
store, _, err := server.NewTestStoreFromJson(config.Datadir)
if err != nil {
return nil, "", err
}

View File

@@ -2,14 +2,20 @@ package networkmonitor
import (
"context"
"errors"
"sync"
)
// NetworkWatcher watches for changes in network configuration.
type NetworkWatcher struct {
var ErrStopped = errors.New("monitor has been stopped")
// NetworkMonitor watches for changes in network configuration.
type NetworkMonitor struct {
cancel context.CancelFunc
wg sync.WaitGroup
mu sync.Mutex
}
// New creates a new network monitor.
func New() *NetworkWatcher {
return &NetworkWatcher{}
func New() *NetworkMonitor {
return &NetworkMonitor{}
}

View File

@@ -31,7 +31,7 @@ func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interfac
for {
select {
case <-ctx.Done():
return ctx.Err()
return ErrStopped
default:
buf := make([]byte, 2048)
n, err := unix.Read(fd, buf)
@@ -63,7 +63,7 @@ func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interfac
}
log.Infof("Network monitor: monitored interface (%s) is down.", ifinfo.Name)
callback()
go callback()
// handle route changes
case unix.RTM_ADD, syscall.RTM_DELETE:
@@ -84,11 +84,11 @@ func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interfac
switch msg.Type {
case unix.RTM_ADD:
log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf)
callback()
go callback()
case unix.RTM_DELETE:
if intfv4 != nil && route.Gw.Compare(nexthopv4) == 0 || intfv6 != nil && route.Gw.Compare(nexthopv6) == 0 {
log.Infof("Network monitor: default route removed: via %s, interface %s", route.Gw, intf)
callback()
go callback()
}
}
}

View File

@@ -5,6 +5,7 @@ package networkmonitor
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"runtime/debug"
@@ -15,20 +16,18 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager"
)
// Start begins watching for network changes and calls the callback function and stops when a change is detected.
func (nw *NetworkWatcher) Start(ctx context.Context, callback func()) {
if nw.cancel != nil {
log.Warn("Network monitor: already running, stopping previous watcher")
nw.Stop()
}
// Start begins monitoring network changes. When a change is detected, it calls the callback asynchronously and returns.
func (nw *NetworkMonitor) Start(ctx context.Context, callback func()) (err error) {
if ctx.Err() != nil {
log.Info("Network monitor: not starting, context is already cancelled")
return
return ctx.Err()
}
nw.mu.Lock()
ctx, nw.cancel = context.WithCancel(ctx)
defer nw.Stop()
nw.mu.Unlock()
nw.wg.Add(1)
defer nw.wg.Done()
var nexthop4, nexthop6 netip.Addr
var intf4, intf6 *net.Interface
@@ -56,27 +55,30 @@ func (nw *NetworkWatcher) Start(ctx context.Context, callback func()) {
expBackOff := backoff.WithContext(backoff.NewExponentialBackOff(), ctx)
if err := backoff.Retry(operation, expBackOff); err != nil {
log.Errorf("Network monitor: failed to get default next hops: %v", err)
return
return fmt.Errorf("failed to get default next hops: %w", err)
}
// recover in case sys ops panic
defer func() {
if r := recover(); r != nil {
log.Errorf("Network monitor: panic occurred: %v, stack trace: %s", r, string(debug.Stack()))
err = fmt.Errorf("panic occurred: %v, stack trace: %s", r, string(debug.Stack()))
}
}()
if err := checkChange(ctx, nexthop4, intf4, nexthop6, intf6, callback); err != nil && !errors.Is(err, context.Canceled) {
log.Errorf("Network monitor: failed to start: %v", err)
if err := checkChange(ctx, nexthop4, intf4, nexthop6, intf6, callback); err != nil {
return fmt.Errorf("check change: %w", err)
}
return nil
}
// Stop stops the network monitor.
func (nw *NetworkWatcher) Stop() {
func (nw *NetworkMonitor) Stop() {
nw.mu.Lock()
defer nw.mu.Unlock()
if nw.cancel != nil {
nw.cancel()
nw.cancel = nil
log.Info("Network monitor: stopped")
nw.wg.Wait()
}
}

View File

@@ -36,7 +36,7 @@ func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interfac
for {
select {
case <-ctx.Done():
return ctx.Err()
return ErrStopped
// handle interface state changes
case update := <-linkChan:
@@ -47,12 +47,12 @@ func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interfac
switch update.Header.Type {
case syscall.RTM_DELLINK:
log.Infof("Network monitor: monitored interface (%s) is gone", update.Link.Attrs().Name)
callback()
go callback()
return nil
case syscall.RTM_NEWLINK:
if (update.IfInfomsg.Flags&syscall.IFF_RUNNING) == 0 && update.Link.Attrs().OperState == netlink.OperDown {
log.Infof("Network monitor: monitored interface (%s) is down.", update.Link.Attrs().Name)
callback()
go callback()
return nil
}
}
@@ -67,12 +67,12 @@ func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interfac
// triggered on added/replaced routes
case syscall.RTM_NEWROUTE:
log.Infof("Network monitor: default route changed: via %s, interface %d", route.Gw, route.LinkIndex)
callback()
go callback()
return nil
case syscall.RTM_DELROUTE:
if intfv4 != nil && route.Gw.Equal(nexthopv4.AsSlice()) || intfv6 != nil && route.Gw.Equal(nexthop6.AsSlice()) {
log.Infof("Network monitor: default route removed: via %s, interface %d", route.Gw, route.LinkIndex)
callback()
go callback()
return nil
}
}

View File

@@ -4,8 +4,9 @@ package networkmonitor
import "context"
func (nw *NetworkWatcher) Start(context.Context, func()) {
func (nw *NetworkMonitor) Start(context.Context, func()) error {
return nil
}
func (nw *NetworkWatcher) Stop() {
func (nw *NetworkMonitor) Stop() {
}

View File

@@ -48,10 +48,10 @@ func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interfac
for {
select {
case <-ctx.Done():
return ctx.Err()
return ErrStopped
case <-ticker.C:
if changed(nexthopv4, intfv4, neighborv4, nexthopv6, intfv6, neighborv6) {
callback()
go callback()
return nil
}
}

View File

@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"net"
"net/netip"
"runtime"
"strings"
"sync"
@@ -18,7 +19,7 @@ import (
"github.com/netbirdio/netbird/client/internal/wgproxy"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/iface/bind"
signal "github.com/netbirdio/netbird/signal/client"
"github.com/netbirdio/netbird/route"
sProto "github.com/netbirdio/netbird/signal/proto"
nbnet "github.com/netbirdio/netbird/util/net"
"github.com/netbirdio/netbird/version"
@@ -68,9 +69,6 @@ type ConnConfig struct {
NATExternalIPs []string
// UsesBind indicates whether the WireGuard interface is userspace and uses bind.ICEBind
UserspaceBind bool
// RosenpassPubKey is this peer's Rosenpass public key
RosenpassPubKey []byte
// RosenpassPubKey is this peer's RosenpassAddr server address (IP:port)
@@ -133,32 +131,15 @@ type Conn struct {
wgProxyFactory *wgproxy.Factory
wgProxy wgproxy.Proxy
remoteModeCh chan ModeMessage
meta meta
adapter iface.TunAdapter
iFaceDiscover stdnet.ExternalIFaceDiscover
sentExtraSrflx bool
remoteEndpoint *net.UDPAddr
remoteConn *ice.Conn
connID nbnet.ConnectionID
beforeAddPeerHooks []BeforeAddPeerHookFunc
afterRemovePeerHooks []AfterRemovePeerHookFunc
}
// meta holds meta information about a connection
type meta struct {
protoSupport signal.FeaturesSupport
}
// ModeMessage represents a connection mode chosen by the peer
type ModeMessage struct {
// Direct indicates that it decided to use a direct connection
Direct bool
}
// GetConf returns the connection config
func (conn *Conn) GetConf() ConnConfig {
return conn.config
@@ -185,7 +166,6 @@ func NewConn(config ConnConfig, statusRecorder *Status, wgProxyFactory *wgproxy.
remoteOffersCh: make(chan OfferAnswer),
remoteAnswerCh: make(chan OfferAnswer),
statusRecorder: statusRecorder,
remoteModeCh: make(chan ModeMessage, 1),
wgProxyFactory: wgProxyFactory,
adapter: adapter,
iFaceDiscover: iFaceDiscover,
@@ -353,7 +333,7 @@ func (conn *Conn) Open(ctx context.Context) error {
err = conn.agent.GatherCandidates()
if err != nil {
return err
return fmt.Errorf("gather candidates: %v", err)
}
// will block until connection succeeded
@@ -370,14 +350,12 @@ func (conn *Conn) Open(ctx context.Context) error {
return err
}
// dynamically set remote WireGuard port is other side specified a different one from the default one
// dynamically set remote WireGuard port if other side specified a different one from the default one
remoteWgPort := iface.DefaultWgPort
if remoteOfferAnswer.WgListenPort != 0 {
remoteWgPort = remoteOfferAnswer.WgListenPort
}
conn.remoteConn = remoteConn
// the ice connection has been established successfully so we are ready to start the proxy
remoteAddr, err := conn.configureConnection(remoteConn, remoteWgPort, remoteOfferAnswer.RosenpassPubKey,
remoteOfferAnswer.RosenpassAddr)
@@ -435,7 +413,6 @@ func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, rem
}
endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String())
conn.remoteEndpoint = endpointUdpAddr
log.Debugf("Conn resolved IP for %s: %s", endpoint, endpointUdpAddr.IP)
conn.connID = nbnet.GenerateConnID()
@@ -487,6 +464,10 @@ func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, rem
return nil, err
}
if runtime.GOOS == "ios" {
runtime.GC()
}
if conn.onConnected != nil {
conn.onConnected(conn.config.Key, remoteRosenpassPubKey, ipNet.IP.String(), remoteRosenpassAddr)
}
@@ -617,40 +598,39 @@ func (conn *Conn) SetSendSignalMessage(handler func(message *sProto.Message) err
// onICECandidate is a callback attached to an ICE Agent to receive new local connection candidates
// and then signals them to the remote peer
func (conn *Conn) onICECandidate(candidate ice.Candidate) {
if candidate != nil {
// TODO: reported port is incorrect for CandidateTypeHost, makes understanding ICE use via logs confusing as port is ignored
log.Debugf("discovered local candidate %s", candidate.String())
go func() {
err := conn.signalCandidate(candidate)
if err != nil {
log.Errorf("failed signaling candidate to the remote peer %s %s", conn.config.Key, err)
}
// sends an extra server reflexive candidate to the remote peer with our related port (usually the wireguard port)
// this is useful when network has an existing port forwarding rule for the wireguard port and this peer
if !conn.sentExtraSrflx && candidate.Type() == ice.CandidateTypeServerReflexive && candidate.Port() != candidate.RelatedAddress().Port {
relatedAdd := candidate.RelatedAddress()
extraSrflx, err := ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{
Network: candidate.NetworkType().String(),
Address: candidate.Address(),
Port: relatedAdd.Port,
Component: candidate.Component(),
RelAddr: relatedAdd.Address,
RelPort: relatedAdd.Port,
})
if err != nil {
log.Errorf("failed creating extra server reflexive candidate %s", err)
return
}
err = conn.signalCandidate(extraSrflx)
if err != nil {
log.Errorf("failed signaling the extra server reflexive candidate to the remote peer %s: %s", conn.config.Key, err)
return
}
conn.sentExtraSrflx = true
}
}()
// nil means candidate gathering has been ended
if candidate == nil {
return
}
// TODO: reported port is incorrect for CandidateTypeHost, makes understanding ICE use via logs confusing as port is ignored
log.Debugf("discovered local candidate %s", candidate.String())
go func() {
err := conn.signalCandidate(candidate)
if err != nil {
log.Errorf("failed signaling candidate to the remote peer %s %s", conn.config.Key, err)
}
}()
if !conn.shouldSendExtraSrflxCandidate(candidate) {
return
}
// sends an extra server reflexive candidate to the remote peer with our related port (usually the wireguard port)
// this is useful when network has an existing port forwarding rule for the wireguard port and this peer
extraSrflx, err := extraSrflxCandidate(candidate)
if err != nil {
log.Errorf("failed creating extra server reflexive candidate %s", err)
return
}
conn.sentExtraSrflx = true
go func() {
err = conn.signalCandidate(extraSrflx)
if err != nil {
log.Errorf("failed signaling the extra server reflexive candidate to the remote peer %s: %s", conn.config.Key, err)
}
}()
}
func (conn *Conn) onICESelectedCandidatePair(c1 ice.Candidate, c2 ice.Candidate) {
@@ -775,7 +755,7 @@ func (conn *Conn) OnRemoteAnswer(answer OfferAnswer) bool {
}
// OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer.
func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate) {
func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMap) {
log.Debugf("OnRemoteCandidate from peer %s -> %s", conn.config.Key, candidate.String())
go func() {
conn.mu.Lock()
@@ -785,6 +765,10 @@ func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate) {
return
}
if candidateViaRoutes(candidate, haRoutes) {
return
}
err := conn.agent.AddRemoteCandidate(candidate)
if err != nil {
log.Errorf("error while handling remote candidate from peer %s", conn.config.Key)
@@ -797,8 +781,49 @@ func (conn *Conn) GetKey() string {
return conn.config.Key
}
// RegisterProtoSupportMeta register supported proto message in the connection metadata
func (conn *Conn) RegisterProtoSupportMeta(support []uint32) {
protoSupport := signal.ParseFeaturesSupported(support)
conn.meta.protoSupport = protoSupport
func (conn *Conn) shouldSendExtraSrflxCandidate(candidate ice.Candidate) bool {
if !conn.sentExtraSrflx && candidate.Type() == ice.CandidateTypeServerReflexive && candidate.Port() != candidate.RelatedAddress().Port {
return true
}
return false
}
func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive, error) {
relatedAdd := candidate.RelatedAddress()
return ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{
Network: candidate.NetworkType().String(),
Address: candidate.Address(),
Port: relatedAdd.Port,
Component: candidate.Component(),
RelAddr: relatedAdd.Address,
RelPort: relatedAdd.Port,
})
}
func candidateViaRoutes(candidate ice.Candidate, clientRoutes route.HAMap) bool {
var routePrefixes []netip.Prefix
for _, routes := range clientRoutes {
if len(routes) > 0 && routes[0] != nil {
routePrefixes = append(routePrefixes, routes[0].Network)
}
}
addr, err := netip.ParseAddr(candidate.Address())
if err != nil {
log.Errorf("Failed to parse IP address %s: %v", candidate.Address(), err)
return false
}
for _, prefix := range routePrefixes {
// default route is
if prefix.Bits() == 0 {
continue
}
if prefix.Contains(addr) {
log.Debugf("Ignoring candidate [%s], its address is part of routed network %s", candidate.String(), prefix)
return true
}
}
return false
}

View File

@@ -170,7 +170,7 @@ func ProbeAll(
var wg sync.WaitGroup
for i, uri := range relays {
ctx, cancel := context.WithTimeout(ctx, 1*time.Second)
ctx, cancel := context.WithTimeout(ctx, 2*time.Second)
defer cancel()
wg.Add(1)

View File

@@ -43,11 +43,6 @@ func routeCmd(action string, prefix netip.Prefix, nexthop netip.Addr, intf *net.
}
if prefix.Addr().Is6() {
inet = "-inet6"
// Special case for IPv6 split default route, pointing to the wg interface fails
// TODO: Remove once we have IPv6 support on the interface
if prefix.Bits() == 1 {
intf = &net.Interface{Name: "lo0"}
}
}
args := []string{"-n", action, inet, network}

View File

@@ -1,7 +1,7 @@
<!DOCTYPE html>
<html>
<html lang="en">
<head>
<meta name="viewport" content="width=device-width, initial-scale=1"/>
<meta content="width=device-width, initial-scale=1" name="viewport"/>
<style>
body {
display: flex;
@@ -50,16 +50,17 @@
color: black;
}
</style>
<title>NetBird Login Successful</title>
</head>
<body>
<div class="container">
<div class="logo">
<img src="https://img.mailinblue.com/6211297/images/content_library/original/64bd4ce82e1ea753e439b6a2.png">
<img alt="netbird_logo" src="https://img.mailinblue.com/6211297/images/content_library/original/64bd4ce82e1ea753e439b6a2.png">
</div>
<br>
{{ if .Error }}
<svg xmlns="http://www.w3.org/2000/svg" height="50" viewBox="0 0 100 100">
<circle cx="50" cy="50" r="45" fill="none" stroke="red" stroke-width="3"/>
<svg height="50" viewBox="0 0 100 100" xmlns="http://www.w3.org/2000/svg">
<circle cx="50" cy="50" fill="none" r="45" stroke="red" stroke-width="3"/>
<path d="M30 30 L70 70 M30 70 L70 30" fill="none" stroke="red" stroke-width="3"/>
</svg>
<div class="content">
@@ -69,8 +70,8 @@
{{ .Error }}.
</div>
{{ else }}
<svg xmlns="http://www.w3.org/2000/svg" height="50" viewBox="0 0 100 100">
<circle cx="50" cy="50" r="45" fill="none" stroke="#5cb85c" stroke-width="3"/>
<svg height="50" viewBox="0 0 100 100" xmlns="http://www.w3.org/2000/svg">
<circle cx="50" cy="50" fill="none" r="45" stroke="#5cb85c" stroke-width="3"/>
<path d="M30 50 L45 65 L70 35" fill="none" stroke="#5cb85c" stroke-width="5"/>
</svg>
<div class="content">

View File

@@ -109,7 +109,6 @@ func (p *WGEBPFProxy) AddTurnConn(turnConn net.Conn) (net.Addr, error) {
// CloseConn doing nothing because this type of proxy implementation does not store the connection
func (p *WGEBPFProxy) CloseConn() error {
p.cancel()
return nil
}

View File

@@ -1806,6 +1806,91 @@ func (x *DebugBundleResponse) GetPath() string {
return ""
}
type GetLogLevelRequest struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
}
func (x *GetLogLevelRequest) Reset() {
*x = GetLogLevelRequest{}
if protoimpl.UnsafeEnabled {
mi := &file_daemon_proto_msgTypes[26]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *GetLogLevelRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*GetLogLevelRequest) ProtoMessage() {}
func (x *GetLogLevelRequest) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[26]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use GetLogLevelRequest.ProtoReflect.Descriptor instead.
func (*GetLogLevelRequest) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{26}
}
type GetLogLevelResponse struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Level LogLevel `protobuf:"varint,1,opt,name=level,proto3,enum=daemon.LogLevel" json:"level,omitempty"`
}
func (x *GetLogLevelResponse) Reset() {
*x = GetLogLevelResponse{}
if protoimpl.UnsafeEnabled {
mi := &file_daemon_proto_msgTypes[27]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *GetLogLevelResponse) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*GetLogLevelResponse) ProtoMessage() {}
func (x *GetLogLevelResponse) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[27]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use GetLogLevelResponse.ProtoReflect.Descriptor instead.
func (*GetLogLevelResponse) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{27}
}
func (x *GetLogLevelResponse) GetLevel() LogLevel {
if x != nil {
return x.Level
}
return LogLevel_UNKNOWN
}
type SetLogLevelRequest struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
@@ -1817,7 +1902,7 @@ type SetLogLevelRequest struct {
func (x *SetLogLevelRequest) Reset() {
*x = SetLogLevelRequest{}
if protoimpl.UnsafeEnabled {
mi := &file_daemon_proto_msgTypes[26]
mi := &file_daemon_proto_msgTypes[28]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -1830,7 +1915,7 @@ func (x *SetLogLevelRequest) String() string {
func (*SetLogLevelRequest) ProtoMessage() {}
func (x *SetLogLevelRequest) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[26]
mi := &file_daemon_proto_msgTypes[28]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -1843,7 +1928,7 @@ func (x *SetLogLevelRequest) ProtoReflect() protoreflect.Message {
// Deprecated: Use SetLogLevelRequest.ProtoReflect.Descriptor instead.
func (*SetLogLevelRequest) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{26}
return file_daemon_proto_rawDescGZIP(), []int{28}
}
func (x *SetLogLevelRequest) GetLevel() LogLevel {
@@ -1862,7 +1947,7 @@ type SetLogLevelResponse struct {
func (x *SetLogLevelResponse) Reset() {
*x = SetLogLevelResponse{}
if protoimpl.UnsafeEnabled {
mi := &file_daemon_proto_msgTypes[27]
mi := &file_daemon_proto_msgTypes[29]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -1875,7 +1960,7 @@ func (x *SetLogLevelResponse) String() string {
func (*SetLogLevelResponse) ProtoMessage() {}
func (x *SetLogLevelResponse) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[27]
mi := &file_daemon_proto_msgTypes[29]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -1888,7 +1973,7 @@ func (x *SetLogLevelResponse) ProtoReflect() protoreflect.Message {
// Deprecated: Use SetLogLevelResponse.ProtoReflect.Descriptor instead.
func (*SetLogLevelResponse) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{27}
return file_daemon_proto_rawDescGZIP(), []int{29}
}
var File_daemon_proto protoreflect.FileDescriptor
@@ -2138,67 +2223,77 @@ var file_daemon_proto_rawDesc = []byte{
0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x22,
0x29, 0x0a, 0x13, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65,
0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x70, 0x61, 0x74, 0x68, 0x18, 0x01,
0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x70, 0x61, 0x74, 0x68, 0x22, 0x3c, 0x0a, 0x12, 0x53, 0x65,
0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x70, 0x61, 0x74, 0x68, 0x22, 0x14, 0x0a, 0x12, 0x47, 0x65,
0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74,
0x12, 0x26, 0x0a, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32,
0x10, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65,
0x6c, 0x52, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x22, 0x15, 0x0a, 0x13, 0x53, 0x65, 0x74, 0x4c,
0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x2a,
0x62, 0x0a, 0x08, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x0b, 0x0a, 0x07, 0x55,
0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x09, 0x0a, 0x05, 0x50, 0x41, 0x4e, 0x49,
0x43, 0x10, 0x01, 0x12, 0x09, 0x0a, 0x05, 0x46, 0x41, 0x54, 0x41, 0x4c, 0x10, 0x02, 0x12, 0x09,
0x0a, 0x05, 0x45, 0x52, 0x52, 0x4f, 0x52, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x57, 0x41, 0x52,
0x4e, 0x10, 0x04, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x4e, 0x46, 0x4f, 0x10, 0x05, 0x12, 0x09, 0x0a,
0x05, 0x44, 0x45, 0x42, 0x55, 0x47, 0x10, 0x06, 0x12, 0x09, 0x0a, 0x05, 0x54, 0x52, 0x41, 0x43,
0x45, 0x10, 0x07, 0x32, 0xee, 0x05, 0x0a, 0x0d, 0x44, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x53, 0x65,
0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x36, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x14,
0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71,
0x75, 0x65, 0x73, 0x74, 0x1a, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f,
0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a,
0x0c, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1b, 0x2e,
0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f,
0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65,
0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e,
0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x2d, 0x0a, 0x02, 0x55, 0x70,
0x12, 0x11, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x71, 0x75,
0x65, 0x73, 0x74, 0x1a, 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52,
0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x39, 0x0a, 0x06, 0x53, 0x74, 0x61,
0x74, 0x75, 0x73, 0x12, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61,
0x74, 0x75, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x64, 0x61, 0x65,
0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e,
0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, 0x04, 0x44, 0x6f, 0x77, 0x6e, 0x12, 0x13, 0x2e, 0x64,
0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73,
0x74, 0x1a, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52,
0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x42, 0x0a, 0x09, 0x47, 0x65, 0x74,
0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x18, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e,
0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74,
0x1a, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e,
0x66, 0x69, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x45, 0x0a,
0x0a, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x19, 0x2e, 0x64, 0x61,
0x22, 0x3d, 0x0a, 0x13, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52,
0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x26, 0x0a, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c,
0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x10, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e,
0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x22,
0x3c, 0x0a, 0x12, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65,
0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x26, 0x0a, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x18, 0x01,
0x20, 0x01, 0x28, 0x0e, 0x32, 0x10, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f,
0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x22, 0x15, 0x0a,
0x13, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70,
0x6f, 0x6e, 0x73, 0x65, 0x2a, 0x62, 0x0a, 0x08, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c,
0x12, 0x0b, 0x0a, 0x07, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x09, 0x0a,
0x05, 0x50, 0x41, 0x4e, 0x49, 0x43, 0x10, 0x01, 0x12, 0x09, 0x0a, 0x05, 0x46, 0x41, 0x54, 0x41,
0x4c, 0x10, 0x02, 0x12, 0x09, 0x0a, 0x05, 0x45, 0x52, 0x52, 0x4f, 0x52, 0x10, 0x03, 0x12, 0x08,
0x0a, 0x04, 0x57, 0x41, 0x52, 0x4e, 0x10, 0x04, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x4e, 0x46, 0x4f,
0x10, 0x05, 0x12, 0x09, 0x0a, 0x05, 0x44, 0x45, 0x42, 0x55, 0x47, 0x10, 0x06, 0x12, 0x09, 0x0a,
0x05, 0x54, 0x52, 0x41, 0x43, 0x45, 0x10, 0x07, 0x32, 0xb8, 0x06, 0x0a, 0x0d, 0x44, 0x61, 0x65,
0x6d, 0x6f, 0x6e, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x36, 0x0a, 0x05, 0x4c, 0x6f,
0x67, 0x69, 0x6e, 0x12, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67,
0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d,
0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65,
0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67,
0x69, 0x6e, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74,
0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a,
0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f,
0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12,
0x2d, 0x0a, 0x02, 0x55, 0x70, 0x12, 0x11, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55,
0x70, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f,
0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x39,
0x0a, 0x06, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f,
0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a,
0x16, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52,
0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, 0x04, 0x44, 0x6f, 0x77,
0x6e, 0x12, 0x13, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52,
0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e,
0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x42,
0x0a, 0x09, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x18, 0x2e, 0x64, 0x61,
0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65,
0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47,
0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65,
0x22, 0x00, 0x12, 0x45, 0x0a, 0x0a, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73,
0x12, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f,
0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e, 0x64, 0x61,
0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52,
0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e,
0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e,
0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f,
0x75, 0x74, 0x65, 0x73, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65,
0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73,
0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63,
0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22,
0x00, 0x12, 0x4d, 0x0a, 0x0e, 0x44, 0x65, 0x73, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75,
0x74, 0x65, 0x73, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c,
0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74,
0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74,
0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00,
0x12, 0x48, 0x0a, 0x0b, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x12,
0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75,
0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61,
0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65,
0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x53, 0x65,
0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d,
0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65,
0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53,
0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e,
0x73, 0x65, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06,
0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x53, 0x65, 0x6c,
0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d,
0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52,
0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e,
0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70,
0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4d, 0x0a, 0x0e, 0x44, 0x65, 0x73, 0x65, 0x6c, 0x65,
0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f,
0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65,
0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53,
0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f,
0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75,
0x6e, 0x64, 0x6c, 0x65, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65,
0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74,
0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42,
0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12,
0x48, 0x0a, 0x0b, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a,
0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65,
0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65,
0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52,
0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x53, 0x65, 0x74,
0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f,
0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71,
0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65,
0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73,
0x65, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70,
0x72, 0x6f, 0x74, 0x6f, 0x33,
}
var (
@@ -2214,7 +2309,7 @@ func file_daemon_proto_rawDescGZIP() []byte {
}
var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 1)
var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 28)
var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 30)
var file_daemon_proto_goTypes = []interface{}{
(LogLevel)(0), // 0: daemon.LogLevel
(*LoginRequest)(nil), // 1: daemon.LoginRequest
@@ -2243,16 +2338,18 @@ var file_daemon_proto_goTypes = []interface{}{
(*Route)(nil), // 24: daemon.Route
(*DebugBundleRequest)(nil), // 25: daemon.DebugBundleRequest
(*DebugBundleResponse)(nil), // 26: daemon.DebugBundleResponse
(*SetLogLevelRequest)(nil), // 27: daemon.SetLogLevelRequest
(*SetLogLevelResponse)(nil), // 28: daemon.SetLogLevelResponse
(*timestamp.Timestamp)(nil), // 29: google.protobuf.Timestamp
(*duration.Duration)(nil), // 30: google.protobuf.Duration
(*GetLogLevelRequest)(nil), // 27: daemon.GetLogLevelRequest
(*GetLogLevelResponse)(nil), // 28: daemon.GetLogLevelResponse
(*SetLogLevelRequest)(nil), // 29: daemon.SetLogLevelRequest
(*SetLogLevelResponse)(nil), // 30: daemon.SetLogLevelResponse
(*timestamp.Timestamp)(nil), // 31: google.protobuf.Timestamp
(*duration.Duration)(nil), // 32: google.protobuf.Duration
}
var file_daemon_proto_depIdxs = []int32{
19, // 0: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus
29, // 1: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp
29, // 2: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp
30, // 3: daemon.PeerState.latency:type_name -> google.protobuf.Duration
31, // 1: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp
31, // 2: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp
32, // 3: daemon.PeerState.latency:type_name -> google.protobuf.Duration
16, // 4: daemon.FullStatus.managementState:type_name -> daemon.ManagementState
15, // 5: daemon.FullStatus.signalState:type_name -> daemon.SignalState
14, // 6: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState
@@ -2260,34 +2357,37 @@ var file_daemon_proto_depIdxs = []int32{
17, // 8: daemon.FullStatus.relays:type_name -> daemon.RelayState
18, // 9: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState
24, // 10: daemon.ListRoutesResponse.routes:type_name -> daemon.Route
0, // 11: daemon.SetLogLevelRequest.level:type_name -> daemon.LogLevel
1, // 12: daemon.DaemonService.Login:input_type -> daemon.LoginRequest
3, // 13: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest
5, // 14: daemon.DaemonService.Up:input_type -> daemon.UpRequest
7, // 15: daemon.DaemonService.Status:input_type -> daemon.StatusRequest
9, // 16: daemon.DaemonService.Down:input_type -> daemon.DownRequest
11, // 17: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest
20, // 18: daemon.DaemonService.ListRoutes:input_type -> daemon.ListRoutesRequest
22, // 19: daemon.DaemonService.SelectRoutes:input_type -> daemon.SelectRoutesRequest
22, // 20: daemon.DaemonService.DeselectRoutes:input_type -> daemon.SelectRoutesRequest
25, // 21: daemon.DaemonService.DebugBundle:input_type -> daemon.DebugBundleRequest
27, // 22: daemon.DaemonService.SetLogLevel:input_type -> daemon.SetLogLevelRequest
2, // 23: daemon.DaemonService.Login:output_type -> daemon.LoginResponse
4, // 24: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse
6, // 25: daemon.DaemonService.Up:output_type -> daemon.UpResponse
8, // 26: daemon.DaemonService.Status:output_type -> daemon.StatusResponse
10, // 27: daemon.DaemonService.Down:output_type -> daemon.DownResponse
12, // 28: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse
21, // 29: daemon.DaemonService.ListRoutes:output_type -> daemon.ListRoutesResponse
23, // 30: daemon.DaemonService.SelectRoutes:output_type -> daemon.SelectRoutesResponse
23, // 31: daemon.DaemonService.DeselectRoutes:output_type -> daemon.SelectRoutesResponse
26, // 32: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse
28, // 33: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse
23, // [23:34] is the sub-list for method output_type
12, // [12:23] is the sub-list for method input_type
12, // [12:12] is the sub-list for extension type_name
12, // [12:12] is the sub-list for extension extendee
0, // [0:12] is the sub-list for field type_name
0, // 11: daemon.GetLogLevelResponse.level:type_name -> daemon.LogLevel
0, // 12: daemon.SetLogLevelRequest.level:type_name -> daemon.LogLevel
1, // 13: daemon.DaemonService.Login:input_type -> daemon.LoginRequest
3, // 14: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest
5, // 15: daemon.DaemonService.Up:input_type -> daemon.UpRequest
7, // 16: daemon.DaemonService.Status:input_type -> daemon.StatusRequest
9, // 17: daemon.DaemonService.Down:input_type -> daemon.DownRequest
11, // 18: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest
20, // 19: daemon.DaemonService.ListRoutes:input_type -> daemon.ListRoutesRequest
22, // 20: daemon.DaemonService.SelectRoutes:input_type -> daemon.SelectRoutesRequest
22, // 21: daemon.DaemonService.DeselectRoutes:input_type -> daemon.SelectRoutesRequest
25, // 22: daemon.DaemonService.DebugBundle:input_type -> daemon.DebugBundleRequest
27, // 23: daemon.DaemonService.GetLogLevel:input_type -> daemon.GetLogLevelRequest
29, // 24: daemon.DaemonService.SetLogLevel:input_type -> daemon.SetLogLevelRequest
2, // 25: daemon.DaemonService.Login:output_type -> daemon.LoginResponse
4, // 26: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse
6, // 27: daemon.DaemonService.Up:output_type -> daemon.UpResponse
8, // 28: daemon.DaemonService.Status:output_type -> daemon.StatusResponse
10, // 29: daemon.DaemonService.Down:output_type -> daemon.DownResponse
12, // 30: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse
21, // 31: daemon.DaemonService.ListRoutes:output_type -> daemon.ListRoutesResponse
23, // 32: daemon.DaemonService.SelectRoutes:output_type -> daemon.SelectRoutesResponse
23, // 33: daemon.DaemonService.DeselectRoutes:output_type -> daemon.SelectRoutesResponse
26, // 34: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse
28, // 35: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse
30, // 36: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse
25, // [25:37] is the sub-list for method output_type
13, // [13:25] is the sub-list for method input_type
13, // [13:13] is the sub-list for extension type_name
13, // [13:13] is the sub-list for extension extendee
0, // [0:13] is the sub-list for field type_name
}
func init() { file_daemon_proto_init() }
@@ -2609,7 +2709,7 @@ func file_daemon_proto_init() {
}
}
file_daemon_proto_msgTypes[26].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*SetLogLevelRequest); i {
switch v := v.(*GetLogLevelRequest); i {
case 0:
return &v.state
case 1:
@@ -2621,6 +2721,30 @@ func file_daemon_proto_init() {
}
}
file_daemon_proto_msgTypes[27].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*GetLogLevelResponse); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_daemon_proto_msgTypes[28].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*SetLogLevelRequest); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_daemon_proto_msgTypes[29].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*SetLogLevelResponse); i {
case 0:
return &v.state
@@ -2640,7 +2764,7 @@ func file_daemon_proto_init() {
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_daemon_proto_rawDesc,
NumEnums: 1,
NumMessages: 28,
NumMessages: 30,
NumExtensions: 0,
NumServices: 1,
},

View File

@@ -40,6 +40,9 @@ service DaemonService {
// DebugBundle creates a debug bundle
rpc DebugBundle(DebugBundleRequest) returns (DebugBundleResponse) {}
// GetLogLevel gets the log level of the daemon
rpc GetLogLevel(GetLogLevelRequest) returns (GetLogLevelResponse) {}
// SetLogLevel sets the log level of the daemon
rpc SetLogLevel(SetLogLevelRequest) returns (SetLogLevelResponse) {}
};
@@ -256,6 +259,13 @@ enum LogLevel {
TRACE = 7;
}
message GetLogLevelRequest {
}
message GetLogLevelResponse {
LogLevel level = 1;
}
message SetLogLevelRequest {
LogLevel level = 1;
}

View File

@@ -39,6 +39,8 @@ type DaemonServiceClient interface {
DeselectRoutes(ctx context.Context, in *SelectRoutesRequest, opts ...grpc.CallOption) (*SelectRoutesResponse, error)
// DebugBundle creates a debug bundle
DebugBundle(ctx context.Context, in *DebugBundleRequest, opts ...grpc.CallOption) (*DebugBundleResponse, error)
// GetLogLevel gets the log level of the daemon
GetLogLevel(ctx context.Context, in *GetLogLevelRequest, opts ...grpc.CallOption) (*GetLogLevelResponse, error)
// SetLogLevel sets the log level of the daemon
SetLogLevel(ctx context.Context, in *SetLogLevelRequest, opts ...grpc.CallOption) (*SetLogLevelResponse, error)
}
@@ -141,6 +143,15 @@ func (c *daemonServiceClient) DebugBundle(ctx context.Context, in *DebugBundleRe
return out, nil
}
func (c *daemonServiceClient) GetLogLevel(ctx context.Context, in *GetLogLevelRequest, opts ...grpc.CallOption) (*GetLogLevelResponse, error) {
out := new(GetLogLevelResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/GetLogLevel", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *daemonServiceClient) SetLogLevel(ctx context.Context, in *SetLogLevelRequest, opts ...grpc.CallOption) (*SetLogLevelResponse, error) {
out := new(SetLogLevelResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/SetLogLevel", in, out, opts...)
@@ -175,6 +186,8 @@ type DaemonServiceServer interface {
DeselectRoutes(context.Context, *SelectRoutesRequest) (*SelectRoutesResponse, error)
// DebugBundle creates a debug bundle
DebugBundle(context.Context, *DebugBundleRequest) (*DebugBundleResponse, error)
// GetLogLevel gets the log level of the daemon
GetLogLevel(context.Context, *GetLogLevelRequest) (*GetLogLevelResponse, error)
// SetLogLevel sets the log level of the daemon
SetLogLevel(context.Context, *SetLogLevelRequest) (*SetLogLevelResponse, error)
mustEmbedUnimplementedDaemonServiceServer()
@@ -214,6 +227,9 @@ func (UnimplementedDaemonServiceServer) DeselectRoutes(context.Context, *SelectR
func (UnimplementedDaemonServiceServer) DebugBundle(context.Context, *DebugBundleRequest) (*DebugBundleResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method DebugBundle not implemented")
}
func (UnimplementedDaemonServiceServer) GetLogLevel(context.Context, *GetLogLevelRequest) (*GetLogLevelResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method GetLogLevel not implemented")
}
func (UnimplementedDaemonServiceServer) SetLogLevel(context.Context, *SetLogLevelRequest) (*SetLogLevelResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method SetLogLevel not implemented")
}
@@ -410,6 +426,24 @@ func _DaemonService_DebugBundle_Handler(srv interface{}, ctx context.Context, de
return interceptor(ctx, in, info, handler)
}
func _DaemonService_GetLogLevel_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(GetLogLevelRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DaemonServiceServer).GetLogLevel(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/daemon.DaemonService/GetLogLevel",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).GetLogLevel(ctx, req.(*GetLogLevelRequest))
}
return interceptor(ctx, in, info, handler)
}
func _DaemonService_SetLogLevel_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(SetLogLevelRequest)
if err := dec(in); err != nil {
@@ -475,6 +509,10 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{
MethodName: "DebugBundle",
Handler: _DaemonService_DebugBundle_Handler,
},
{
MethodName: "GetLogLevel",
Handler: _DaemonService_GetLogLevel_Handler,
},
{
MethodName: "SetLogLevel",
Handler: _DaemonService_SetLogLevel_Handler,

View File

@@ -121,6 +121,12 @@ func (s *Server) anonymize(reader io.Reader, writer io.WriteCloser, errChan chan
}
}
// GetLogLevel gets the current logging level for the server.
func (s *Server) GetLogLevel(_ context.Context, _ *proto.GetLogLevelRequest) (*proto.GetLogLevelResponse, error) {
level := ParseLogLevel(log.GetLevel().String())
return &proto.GetLogLevelResponse{Level: level}, nil
}
// SetLogLevel sets the logging level for the server.
func (s *Server) SetLogLevel(_ context.Context, req *proto.SetLogLevelRequest) (*proto.SetLogLevelResponse, error) {
level, err := log.ParseLevel(req.Level.String())

28
client/server/log.go Normal file
View File

@@ -0,0 +1,28 @@
package server
import (
"strings"
"github.com/netbirdio/netbird/client/proto"
)
func ParseLogLevel(level string) proto.LogLevel {
switch strings.ToLower(level) {
case "panic":
return proto.LogLevel_PANIC
case "fatal":
return proto.LogLevel_FATAL
case "error":
return proto.LogLevel_ERROR
case "warn":
return proto.LogLevel_WARN
case "info":
return proto.LogLevel_INFO
case "debug":
return proto.LogLevel_DEBUG
case "trace":
return proto.LogLevel_TRACE
default:
return proto.LogLevel_UNKNOWN
}
}

View File

@@ -36,7 +36,7 @@ const (
maxRetryIntervalVar = "NB_CONN_MAX_RETRY_INTERVAL_TIME"
maxRetryTimeVar = "NB_CONN_MAX_RETRY_TIME_TIME"
retryMultiplierVar = "NB_CONN_RETRY_MULTIPLIER"
defaultInitialRetryTime = 14 * 24 * time.Hour
defaultInitialRetryTime = 30 * time.Minute
defaultMaxRetryInterval = 60 * time.Minute
defaultMaxRetryTime = 14 * 24 * time.Hour
defaultRetryMultiplier = 1.7

View File

@@ -106,10 +106,11 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
return nil, "", err
}
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
store, err := server.NewStoreFromJson(config.Datadir, nil)
store, cleanUp, err := server.NewTestStoreFromJson(config.Datadir)
if err != nil {
return nil, "", err
}
t.Cleanup(cleanUp)
peersUpdateManager := server.NewPeersUpdateManager(nil)
eventStore := &activity.InMemoryEventStore{}

View File

@@ -20,6 +20,9 @@ const OsVersionCtxKey = "OsVersion"
// OsNameCtxKey context key for operating system name
const OsNameCtxKey = "OsName"
// UiVersionCtxKey context key for user UI version
const UiVersionCtxKey = "user-agent"
type NetworkAddress struct {
NetIP netip.Prefix
Mac string

View File

@@ -28,10 +28,18 @@ func GetInfo(ctx context.Context) *Info {
kernelVersion = osInfo[2]
}
gio := &Info{Kernel: kernel, Platform: "unknown", OS: "android", OSVersion: osVersion(), GoOS: runtime.GOOS, CPUs: runtime.NumCPU(), KernelVersion: kernelVersion}
gio.Hostname = extractDeviceName(ctx, "android")
gio.WiretrusteeVersion = version.NetbirdVersion()
gio.UIVersion = extractUserAgent(ctx)
gio := &Info{
GoOS: runtime.GOOS,
Kernel: kernel,
Platform: "unknown",
OS: "android",
OSVersion: osVersion(),
Hostname: extractDeviceName(ctx, "android"),
CPUs: runtime.NumCPU(),
WiretrusteeVersion: version.NetbirdVersion(),
UIVersion: extractUIVersion(ctx),
KernelVersion: kernelVersion,
}
return gio
}
@@ -45,6 +53,14 @@ func osVersion() string {
return run("/system/bin/getprop", "ro.build.version.release")
}
func extractUIVersion(ctx context.Context) string {
v, ok := ctx.Value(UiVersionCtxKey).(string)
if !ok {
return ""
}
return v
}
func run(name string, arg ...string) string {
cmd := exec.Command(name, arg...)
cmd.Stdin = strings.NewReader("some")

View File

@@ -399,6 +399,7 @@ func (s *serviceClient) updateStatus() error {
status, err := conn.Status(s.ctx, &proto.StatusRequest{})
if err != nil {
log.Errorf("get service status: %v", err)
s.setDisconnectedStatus()
return err
}
@@ -426,17 +427,7 @@ func (s *serviceClient) updateStatus() error {
s.mRoutes.Enable()
systrayIconState = true
} else if status.Status != string(internal.StatusConnected) && s.mUp.Disabled() {
s.connected = false
if s.isUpdateIconActive {
systray.SetIcon(s.icUpdateDisconnected)
} else {
systray.SetIcon(s.icDisconnected)
}
systray.SetTooltip("NetBird (Disconnected)")
s.mStatus.SetTitle("Disconnected")
s.mDown.Disable()
s.mUp.Enable()
s.mRoutes.Disable()
s.setDisconnectedStatus()
systrayIconState = false
}
@@ -481,6 +472,20 @@ func (s *serviceClient) updateStatus() error {
return nil
}
func (s *serviceClient) setDisconnectedStatus() {
s.connected = false
if s.isUpdateIconActive {
systray.SetIcon(s.icUpdateDisconnected)
} else {
systray.SetIcon(s.icDisconnected)
}
systray.SetTooltip("NetBird (Disconnected)")
s.mStatus.SetTitle("Disconnected")
s.mDown.Disable()
s.mUp.Enable()
s.mRoutes.Disable()
}
func (s *serviceClient) onTrayReady() {
systray.SetIcon(s.icDisconnected)
systray.SetTooltip("NetBird")

150
go.mod
View File

@@ -1,33 +1,31 @@
module github.com/netbirdio/netbird
go 1.21
toolchain go1.21.0
go 1.21.0
require (
cunicu.li/go-rosenpass v0.4.0
github.com/cenkalti/backoff/v4 v4.1.3
github.com/cenkalti/backoff/v4 v4.3.0
github.com/cloudflare/circl v1.3.3 // indirect
github.com/golang-jwt/jwt v3.2.2+incompatible
github.com/golang/protobuf v1.5.3
github.com/golang/protobuf v1.5.4
github.com/google/uuid v1.6.0
github.com/gorilla/mux v1.8.0
github.com/kardianos/service v1.2.1-0.20210728001519-a323c3813bc7
github.com/onsi/ginkgo v1.16.5
github.com/onsi/gomega v1.18.1
github.com/onsi/gomega v1.27.6
github.com/pion/ice/v3 v3.0.2
github.com/rs/cors v1.8.0
github.com/sirupsen/logrus v1.9.0
github.com/sirupsen/logrus v1.9.3
github.com/spf13/cobra v1.7.0
github.com/spf13/pflag v1.0.5
github.com/vishvananda/netlink v1.1.1-0.20211118161826-650dca95af54
golang.org/x/crypto v0.21.0
golang.org/x/sys v0.18.0
github.com/vishvananda/netlink v1.2.1-beta.2
golang.org/x/crypto v0.23.0
golang.org/x/sys v0.20.0
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
golang.zx2c4.com/wireguard/windows v0.5.3
google.golang.org/grpc v1.56.3
google.golang.org/protobuf v1.31.0
google.golang.org/grpc v1.64.0
google.golang.org/protobuf v1.34.1
gopkg.in/natefinch/lumberjack.v2 v2.0.0
)
@@ -35,7 +33,7 @@ require (
fyne.io/fyne/v2 v2.1.4
github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible
github.com/c-robinson/iplib v1.0.3
github.com/cilium/ebpf v0.11.0
github.com/cilium/ebpf v0.15.0
github.com/coreos/go-iptables v0.7.0
github.com/creack/pty v1.1.18
github.com/eko/gocache/v3 v3.1.1
@@ -44,23 +42,24 @@ require (
github.com/gliderlabs/ssh v0.3.4
github.com/godbus/dbus/v5 v5.1.0
github.com/golang/mock v1.6.0
github.com/google/go-cmp v0.5.9
github.com/google/go-cmp v0.6.0
github.com/google/gopacket v1.1.19
github.com/google/martian/v3 v3.0.0
github.com/google/nftables v0.0.0-20220808154552-2eca00135732
github.com/gopacket/gopacket v1.1.1
github.com/gorilla/websocket v1.5.1
github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357
github.com/hashicorp/go-multierror v1.1.1
github.com/hashicorp/go-secure-stdlib/base62 v0.1.2
github.com/hashicorp/go-version v1.6.0
github.com/libp2p/go-netroute v0.2.1
github.com/magiconair/properties v1.8.5
github.com/magiconair/properties v1.8.7
github.com/mattn/go-sqlite3 v1.14.19
github.com/mdlayher/socket v0.4.1
github.com/miekg/dns v1.1.43
github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/nadoo/ipset v0.5.0
github.com/netbirdio/management-integrations/integrations v0.0.0-20240415094251-369eb33c9b01
github.com/netbirdio/management-integrations/integrations v0.0.0-20240524104853-69c6d89826cd
github.com/okta/okta-sdk-golang/v2 v2.18.0
github.com/oschwald/maxminddb-golang v1.12.0
github.com/patrickmn/go-cache v2.1.0+incompatible
@@ -68,42 +67,60 @@ require (
github.com/pion/stun/v2 v2.0.0
github.com/pion/transport/v3 v3.0.1
github.com/pion/turn/v3 v3.0.1
github.com/prometheus/client_golang v1.14.0
github.com/prometheus/client_golang v1.19.1
github.com/quic-go/quic-go v0.45.0
github.com/rs/xid v1.3.0
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
github.com/stretchr/testify v1.8.4
github.com/stretchr/testify v1.9.0
github.com/testcontainers/testcontainers-go v0.31.0
github.com/testcontainers/testcontainers-go/modules/postgres v0.31.0
github.com/things-go/go-socks5 v0.0.4
github.com/yusufpapurcu/wmi v1.2.3
github.com/yusufpapurcu/wmi v1.2.4
github.com/zcalusic/sysinfo v1.0.2
go.opentelemetry.io/otel v1.11.1
go.opentelemetry.io/otel/exporters/prometheus v0.33.0
go.opentelemetry.io/otel/metric v0.33.0
go.opentelemetry.io/otel/sdk/metric v0.33.0
go.opentelemetry.io/otel v1.26.0
go.opentelemetry.io/otel/exporters/prometheus v0.48.0
go.opentelemetry.io/otel/metric v1.26.0
go.opentelemetry.io/otel/sdk/metric v1.26.0
goauthentik.io/api/v3 v3.2023051.3
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842
golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028
golang.org/x/net v0.23.0
golang.org/x/oauth2 v0.8.0
golang.org/x/sync v0.3.0
golang.org/x/term v0.18.0
google.golang.org/api v0.126.0
golang.org/x/net v0.25.0
golang.org/x/oauth2 v0.19.0
golang.org/x/sync v0.7.0
golang.org/x/term v0.20.0
google.golang.org/api v0.177.0
gopkg.in/yaml.v3 v3.0.1
gorm.io/driver/postgres v1.5.7
gorm.io/driver/sqlite v1.5.3
gorm.io/gorm v1.25.4
gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde
nhooyr.io/websocket v1.8.11
)
require (
cloud.google.com/go/compute v1.19.3 // indirect
cloud.google.com/go/compute/metadata v0.2.3 // indirect
github.com/BurntSushi/toml v1.2.1 // indirect
cloud.google.com/go/auth v0.3.0 // indirect
cloud.google.com/go/auth/oauth2adapt v0.2.2 // indirect
cloud.google.com/go/compute/metadata v0.3.0 // indirect
dario.cat/mergo v1.0.0 // indirect
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect
github.com/BurntSushi/toml v1.3.2 // indirect
github.com/Microsoft/go-winio v0.6.2 // indirect
github.com/Microsoft/hcsshim v0.12.3 // indirect
github.com/XiaoMi/pegasus-go-client v0.0.0-20210427083443-f3b6b08bc4c2 // indirect
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/bradfitz/gomemcache v0.0.0-20220106215444-fb4bf637b56d // indirect
github.com/cespare/xxhash/v2 v2.2.0 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/containerd/containerd v1.7.16 // indirect
github.com/containerd/log v0.1.0 // indirect
github.com/cpuguy83/dockercfg v0.3.1 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dgraph-io/ristretto v0.1.1 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/distribution/reference v0.6.0 // indirect
github.com/docker/docker v26.1.3+incompatible // indirect
github.com/docker/go-connections v0.5.0 // indirect
github.com/docker/go-units v0.5.0 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/fredbi/uri v0.0.0-20181227131451-3dcfdacbaaf3 // indirect
github.com/getlantern/context v0.0.0-20190109183933-c447772a6520 // indirect
github.com/getlantern/errors v0.0.0-20190325191628-abdb3e3e36f7 // indirect
@@ -113,59 +130,84 @@ require (
github.com/getlantern/ops v0.0.0-20190325191751-d70cb0d6f85f // indirect
github.com/go-gl/gl v0.0.0-20210813123233-e4099ee2221f // indirect
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20211024062804-40e447a793be // indirect
github.com/go-logr/logr v1.2.3 // indirect
github.com/go-logr/logr v1.4.1 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-ole/go-ole v1.2.6 // indirect
github.com/go-ole/go-ole v1.3.0 // indirect
github.com/go-redis/redis/v8 v8.11.5 // indirect
github.com/go-stack/stack v1.8.0 // indirect
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/goki/freetype v0.0.0-20181231101311-fa8a33aabaff // indirect
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/google/btree v1.0.1 // indirect
github.com/google/s2a-go v0.1.4 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.2.3 // indirect
github.com/googleapis/gax-go/v2 v2.10.0 // indirect
github.com/hashicorp/errwrap v1.0.0 // indirect
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 // indirect
github.com/google/s2a-go v0.1.7 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect
github.com/googleapis/gax-go/v2 v2.12.3 // indirect
github.com/hashicorp/errwrap v1.1.0 // indirect
github.com/hashicorp/go-uuid v1.0.2 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
github.com/jackc/pgx/v5 v5.5.5 // indirect
github.com/jackc/puddle/v2 v2.2.1 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/josharian/native v1.1.0 // indirect
github.com/kelseyhightower/envconfig v1.4.0 // indirect
github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect
github.com/klauspost/compress v1.17.8 // indirect
github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae // indirect
github.com/mdlayher/genetlink v1.3.2 // indirect
github.com/mdlayher/netlink v1.7.2 // indirect
github.com/moby/docker-image-spec v1.3.1 // indirect
github.com/moby/patternmatcher v0.6.0 // indirect
github.com/moby/sys/sequential v0.5.0 // indirect
github.com/moby/sys/user v0.1.0 // indirect
github.com/moby/term v0.5.0 // indirect
github.com/morikuni/aec v1.0.0 // indirect
github.com/nxadm/tail v1.4.8 // indirect
github.com/onsi/ginkgo/v2 v2.9.5 // indirect
github.com/opencontainers/go-digest v1.0.0 // indirect
github.com/opencontainers/image-spec v1.1.0 // indirect
github.com/oxtoacart/bpool v0.0.0-20190530202638-03653db5a59c // indirect
github.com/pegasus-kv/thrift v0.13.0 // indirect
github.com/pion/dtls/v2 v2.2.10 // indirect
github.com/pion/mdns v0.0.12 // indirect
github.com/pion/randutil v0.1.0 // indirect
github.com/pion/transport/v2 v2.2.4 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/prometheus/client_model v0.3.0 // indirect
github.com/prometheus/common v0.37.0 // indirect
github.com/prometheus/procfs v0.8.0 // indirect
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect
github.com/prometheus/client_model v0.6.1 // indirect
github.com/prometheus/common v0.53.0 // indirect
github.com/prometheus/procfs v0.15.0 // indirect
github.com/shirou/gopsutil/v3 v3.24.4 // indirect
github.com/shoenig/go-m1cpu v0.1.6 // indirect
github.com/spf13/cast v1.5.0 // indirect
github.com/srwiley/oksvg v0.0.0-20200311192757-870daf9aa564 // indirect
github.com/srwiley/rasterx v0.0.0-20200120212402-85cb7272f5e9 // indirect
github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74 // indirect
github.com/tklauser/go-sysconf v0.3.14 // indirect
github.com/tklauser/numcpus v0.8.0 // indirect
github.com/vishvananda/netns v0.0.4 // indirect
github.com/yuin/goldmark v1.4.13 // indirect
go.opencensus.io v0.24.0 // indirect
go.opentelemetry.io/otel/sdk v1.11.1 // indirect
go.opentelemetry.io/otel/trace v1.11.1 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 // indirect
go.opentelemetry.io/otel/sdk v1.26.0 // indirect
go.opentelemetry.io/otel/trace v1.26.0 // indirect
go.uber.org/mock v0.4.0 // indirect
golang.org/x/image v0.10.0 // indirect
golang.org/x/text v0.14.0 // indirect
golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 // indirect
golang.org/x/mod v0.17.0 // indirect
golang.org/x/text v0.15.0 // indirect
golang.org/x/time v0.5.0 // indirect
golang.org/x/tools v0.21.0 // indirect
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
google.golang.org/appengine v1.6.7 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20230530153820-e85fd2cbaebc // indirect
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20240509183442-62759503f434 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240515191416-fc5f0ca64291 // indirect
gopkg.in/square/go-jose.v2 v2.6.0 // indirect
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect
gopkg.in/tomb.v2 v2.0.0-20161208151619-d5d1b5820637 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 // indirect
k8s.io/apimachinery v0.23.16 // indirect
k8s.io/apimachinery v0.26.2 // indirect
)
replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0

663
go.sum

File diff suppressed because it is too large Load Diff

View File

@@ -3,6 +3,7 @@ package iface
import (
"fmt"
"net"
"net/netip"
"testing"
"time"
@@ -79,8 +80,19 @@ func TestWGIface_UpdateAddr(t *testing.T) {
t.Error(err)
}
assert.Equal(t, addr, addrs[0].String())
var found bool
for _, a := range addrs {
prefix, err := netip.ParsePrefix(a.String())
assert.NoError(t, err)
if prefix.Addr().Is4() {
found = true
assert.Equal(t, addr, prefix.String())
}
}
if !found {
t.Fatal("v4 address not found")
}
}
func getIfaceAddrs(ifaceName string) ([]net.Addr, error) {

View File

@@ -1,5 +1,4 @@
//go:build !ios
// +build !ios
package iface
@@ -121,13 +120,19 @@ func (t *tunDevice) Wrapper() *DeviceWrapper {
func (t *tunDevice) assignAddr() error {
cmd := exec.Command("ifconfig", t.name, "inet", t.address.IP.String(), t.address.IP.String())
if out, err := cmd.CombinedOutput(); err != nil {
log.Infof(`adding address command "%v" failed with output %s and error: `, cmd.String(), out)
log.Errorf("adding address command '%v' failed with output: %s", cmd.String(), out)
return err
}
// dummy ipv6 so routing works
cmd = exec.Command("ifconfig", t.name, "inet6", "fe80::/64")
if out, err := cmd.CombinedOutput(); err != nil {
log.Debugf("adding address command '%v' failed with output: %s", cmd.String(), out)
}
routeCmd := exec.Command("route", "add", "-net", t.address.Network.String(), "-interface", t.name)
if out, err := routeCmd.CombinedOutput(); err != nil {
log.Printf(`adding route command "%v" failed with output %s and error: `, routeCmd.String(), out)
log.Errorf("adding route command '%v' failed with output: %s", routeCmd.String(), out)
return err
}
return nil

View File

@@ -62,10 +62,11 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
t.Fatal(err)
}
s := grpc.NewServer()
store, err := mgmt.NewStoreFromJson(config.Datadir, nil)
store, cleanUp, err := mgmt.NewTestStoreFromJson(config.Datadir)
if err != nil {
t.Fatal(err)
}
t.Cleanup(cleanUp)
peersUpdateManager := mgmt.NewPeersUpdateManager(nil)
eventStore := &activity.InMemoryEventStore{}

View File

@@ -7,10 +7,11 @@ import (
"os"
"path"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/util"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/util"
)
var shortDown = "Rollback SQLite store to JSON file store. Please make a backup of the SQLite file before running this command."
@@ -39,16 +40,16 @@ var downCmd = &cobra.Command{
return fmt.Errorf("%s already exists, couldn't continue the operation", fileStorePath)
}
sqlstore, err := server.NewSqliteStore(mgmtDataDir, nil)
sqlStore, err := server.NewSqliteStore(mgmtDataDir, nil)
if err != nil {
return fmt.Errorf("failed creating file store: %s: %v", mgmtDataDir, err)
}
sqliteStoreAccounts := len(sqlstore.GetAllAccounts())
sqliteStoreAccounts := len(sqlStore.GetAllAccounts())
log.Infof("%d account will be migrated from sqlite store %s to file store %s",
sqliteStoreAccounts, sqliteStorePath, fileStorePath)
store, err := server.NewFilestoreFromSqliteStore(sqlstore, mgmtDataDir, nil)
store, err := server.NewFilestoreFromSqliteStore(sqlStore, mgmtDataDir, nil)
if err != nil {
return fmt.Errorf("failed creating file store: %s: %v", mgmtDataDir, err)
}

View File

@@ -132,6 +132,7 @@ type AccountManager interface {
GetValidatedPeers(account *Account) (map[string]struct{}, error)
SyncAndMarkPeer(peerPubKey string, realIP net.IP) (*nbpeer.Peer, *NetworkMap, error)
CancelPeerRoutines(peer *nbpeer.Peer) error
FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
}
type DefaultAccountManager struct {
@@ -241,6 +242,11 @@ type Account struct {
Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"`
}
// Subclass used in gorm to only load settings and not whole account
type AccountSettings struct {
Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"`
}
type UserPermissions struct {
DashboardView string `json:"dashboard_view"`
}
@@ -1768,6 +1774,8 @@ func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.Authorizat
//
// Existing user + Existing account + Existing domain reclassified Domain as private -> Nothing changes (index domain)
func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(claims jwtclaims.AuthorizationClaims) (*Account, error) {
log.Tracef("getting account with authorization claims. User ID: \"%s\", Account ID: \"%s\", Domain: \"%s\", Domain Category: \"%s\"",
claims.UserId, claims.AccountId, claims.Domain, claims.DomainCategory)
if claims.UserId == "" {
return nil, fmt.Errorf("user ID is empty")
}
@@ -1788,8 +1796,10 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(claims jwtcla
}
}
start := time.Now()
unlock := am.Store.AcquireGlobalLock()
defer unlock()
log.Debugf("Acquired global lock in %s for user %s", time.Since(start), claims.UserId)
// We checked if the domain has a primary account already
domainAccount, err := am.Store.GetAccountByPrivateDomain(claims.Domain)
@@ -1840,6 +1850,9 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(claims jwtcla
func (am *DefaultAccountManager) SyncAndMarkPeer(peerPubKey string, realIP net.IP) (*nbpeer.Peer, *NetworkMap, error) {
accountID, err := am.Store.GetAccountIDByPeerPubKey(peerPubKey)
if err != nil {
if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound {
return nil, nil, status.Errorf(status.Unauthenticated, "peer not registered")
}
return nil, nil, err
}
@@ -1853,7 +1866,7 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(peerPubKey string, realIP net.I
peer, netMap, err := am.SyncPeer(PeerSync{WireGuardPubKey: peerPubKey}, account)
if err != nil {
return nil, nil, mapError(err)
return nil, nil, err
}
err = am.MarkPeerConnected(peerPubKey, true, realIP, account)
@@ -1867,6 +1880,9 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(peerPubKey string, realIP net.I
func (am *DefaultAccountManager) CancelPeerRoutines(peer *nbpeer.Peer) error {
accountID, err := am.Store.GetAccountIDByPeerPubKey(peer.Key)
if err != nil {
if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound {
return status.Errorf(status.Unauthenticated, "peer not registered")
}
return err
}
@@ -1951,6 +1967,10 @@ func (am *DefaultAccountManager) onPeersInvalidated(accountID string) {
am.updateAccountPeers(updatedAccount)
}
func (am *DefaultAccountManager) FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) {
return am.Store.GetPostureCheckByChecksDefinition(accountID, checks)
}
// addAllGroup to account object if it doesn't exist
func addAllGroup(account *Account) error {
if len(account.Groups) == 0 {

View File

@@ -48,8 +48,8 @@ func (MocIntegratedValidator) PreparePeer(accountID string, peer *nbpeer.Peer, p
return peer
}
func (MocIntegratedValidator) IsNotValidPeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool) {
return false, false
func (MocIntegratedValidator) IsNotValidPeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error) {
return false, false, nil
}
func (MocIntegratedValidator) PeerDeleted(_, _ string) error {
@@ -1294,6 +1294,7 @@ func TestAccountManager_DeletePeer(t *testing.T) {
t.Fatal(err)
return
}
userID := "account_creator"
account, err := createAccount(manager, "test_account", userID, "netbird.cloud")
if err != nil {
@@ -1655,6 +1656,7 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) {
func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
_, err = manager.GetAccountByUserOrAccountID(userID, "", "")
require.NoError(t, err, "unable to create an account")
@@ -1707,6 +1709,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
account, err := manager.GetAccountByUserOrAccountID(userID, "", "")
require.NoError(t, err, "unable to create an account")
@@ -1750,6 +1753,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
_, err = manager.GetAccountByUserOrAccountID(userID, "", "")
require.NoError(t, err, "unable to create an account")
@@ -2267,21 +2271,29 @@ func TestAccount_UserGroupsRemoveFromPeers(t *testing.T) {
func createManager(t *testing.T) (*DefaultAccountManager, error) {
t.Helper()
store, err := createStore(t)
if err != nil {
return nil, err
}
eventStore := &activity.InMemoryEventStore{}
return BuildManager(store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{})
manager, err := BuildManager(store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{})
if err != nil {
return nil, err
}
return manager, nil
}
func createStore(t *testing.T) (Store, error) {
t.Helper()
dataDir := t.TempDir()
store, err := NewStoreFromJson(dataDir, nil)
store, cleanUp, err := NewTestStoreFromJson(dataDir)
if err != nil {
return nil, err
}
t.Cleanup(cleanUp)
return store, nil
}

View File

@@ -32,7 +32,7 @@ func TestGetDNSSettings(t *testing.T) {
account, err := initTestDNSAccount(t, am)
if err != nil {
t.Error("failed to init testing account")
t.Fatal("failed to init testing account")
}
dnsSettings, err := am.GetDNSSettings(account.Id, dnsAdminUserID)
@@ -200,10 +200,11 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
func createDNSStore(t *testing.T) (Store, error) {
t.Helper()
dataDir := t.TempDir()
store, err := NewStoreFromJson(dataDir, nil)
store, cleanUp, err := NewTestStoreFromJson(dataDir)
if err != nil {
return nil, err
}
t.Cleanup(cleanUp)
return store, nil
}

View File

@@ -12,6 +12,7 @@ import (
nbgroup "github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/telemetry"
@@ -57,18 +58,18 @@ func NewFileStore(dataDir string, metrics telemetry.AppMetrics) (*FileStore, err
}
// NewFilestoreFromSqliteStore restores a store from Sqlite and stores to Filestore json in the file located in datadir
func NewFilestoreFromSqliteStore(sqlitestore *SqliteStore, dataDir string, metrics telemetry.AppMetrics) (*FileStore, error) {
func NewFilestoreFromSqliteStore(sqlStore *SqlStore, dataDir string, metrics telemetry.AppMetrics) (*FileStore, error) {
store, err := NewFileStore(dataDir, metrics)
if err != nil {
return nil, err
}
err = store.SaveInstallationID(sqlitestore.GetInstallationID())
err = store.SaveInstallationID(sqlStore.GetInstallationID())
if err != nil {
return nil, err
}
for _, account := range sqlitestore.GetAllAccounts() {
for _, account := range sqlStore.GetAllAccounts() {
store.Accounts[account.Id] = account
}
@@ -508,7 +509,7 @@ func (s *FileStore) GetAccountByUser(userID string) (*Account, error) {
accountID, ok := s.UserID2AccountID[userID]
if !ok {
return nil, status.Errorf(status.NotFound, "account not found")
return nil, status.NewUserNotFoundError(userID)
}
account, err := s.getAccount(accountID)
@@ -539,7 +540,7 @@ func (s *FileStore) GetAccountByPeerID(peerID string) (*Account, error) {
if _, ok := account.Peers[peerID]; !ok {
delete(s.PeerID2AccountID, peerID)
log.Warnf("removed stale peerID %s to accountID %s index", peerID, accountID)
return nil, status.Errorf(status.NotFound, "provided peer doesn't exists %s", peerID)
return nil, status.NewPeerNotFoundError(peerID)
}
return account.Copy(), nil
@@ -552,7 +553,7 @@ func (s *FileStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) {
accountID, ok := s.PeerKeyID2AccountID[peerKey]
if !ok {
return nil, status.Errorf(status.NotFound, "provided peer key doesn't exists %s", peerKey)
return nil, status.NewPeerNotFoundError(peerKey)
}
account, err := s.getAccount(accountID)
@@ -572,7 +573,7 @@ func (s *FileStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) {
if stale {
delete(s.PeerKeyID2AccountID, peerKey)
log.Warnf("removed stale peerKey %s to accountID %s index", peerKey, accountID)
return nil, status.Errorf(status.NotFound, "provided peer doesn't exists %s", peerKey)
return nil, status.NewPeerNotFoundError(peerKey)
}
return account.Copy(), nil
@@ -584,12 +585,71 @@ func (s *FileStore) GetAccountIDByPeerPubKey(peerKey string) (string, error) {
accountID, ok := s.PeerKeyID2AccountID[peerKey]
if !ok {
return "", status.Errorf(status.NotFound, "provided peer key doesn't exists %s", peerKey)
return "", status.NewPeerNotFoundError(peerKey)
}
return accountID, nil
}
func (s *FileStore) GetAccountIDByUserID(userID string) (string, error) {
s.mux.Lock()
defer s.mux.Unlock()
accountID, ok := s.UserID2AccountID[userID]
if !ok {
return "", status.NewUserNotFoundError(userID)
}
return accountID, nil
}
func (s *FileStore) GetAccountIDBySetupKey(setupKey string) (string, error) {
s.mux.Lock()
defer s.mux.Unlock()
accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKey)]
if !ok {
return "", status.Errorf(status.NotFound, "account not found: provided setup key doesn't exists")
}
return accountID, nil
}
func (s *FileStore) GetPeerByPeerPubKey(peerKey string) (*nbpeer.Peer, error) {
s.mux.Lock()
defer s.mux.Unlock()
accountID, ok := s.PeerKeyID2AccountID[peerKey]
if !ok {
return nil, status.NewPeerNotFoundError(peerKey)
}
account, err := s.getAccount(accountID)
if err != nil {
return nil, err
}
for _, peer := range account.Peers {
if peer.Key == peerKey {
return peer.Copy(), nil
}
}
return nil, status.NewPeerNotFoundError(peerKey)
}
func (s *FileStore) GetAccountSettings(accountID string) (*Settings, error) {
s.mux.Lock()
defer s.mux.Unlock()
account, err := s.getAccount(accountID)
if err != nil {
return nil, err
}
return account.Settings.Copy(), nil
}
// GetInstallationID returns the installation ID from the store
func (s *FileStore) GetInstallationID() string {
return s.InstallationID
@@ -667,6 +727,10 @@ func (s *FileStore) SaveUserLastLogin(accountID, userID string, lastLogin time.T
return nil
}
func (s *FileStore) GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) {
return nil, status.Errorf(status.Internal, "GetPostureCheckByChecksDefinition is not implemented")
}
// Close the FileStore persisting data to disk
func (s *FileStore) Close() error {
s.mux.Lock()

View File

@@ -59,6 +59,7 @@ func TestStalePeerIndices(t *testing.T) {
func TestNewStore(t *testing.T) {
store := newStore(t)
defer store.Close()
if store.Accounts == nil || len(store.Accounts) != 0 {
t.Errorf("expected to create a new empty Accounts map when creating a new FileStore")
@@ -87,6 +88,7 @@ func TestNewStore(t *testing.T) {
func TestSaveAccount(t *testing.T) {
store := newStore(t)
defer store.Close()
account := newAccountWithId("account_id", "testuser", "")
setupKey := GenerateDefaultSetupKey()
@@ -135,6 +137,8 @@ func TestDeleteAccount(t *testing.T) {
if err != nil {
t.Fatal(err)
}
defer store.Close()
var account *Account
for _, a := range store.Accounts {
account = a
@@ -179,6 +183,7 @@ func TestDeleteAccount(t *testing.T) {
func TestStore(t *testing.T) {
store := newStore(t)
defer store.Close()
account := newAccountWithId("account_id", "testuser", "")
account.Peers["testpeer"] = &nbpeer.Peer{
@@ -436,6 +441,7 @@ func TestFileStore_GetTokenIDByHashedToken(t *testing.T) {
func TestFileStore_DeleteHashedPAT2TokenIDIndex(t *testing.T) {
store := newStore(t)
defer store.Close()
store.HashedPAT2TokenID["someHashedToken"] = "someTokenId"
err := store.DeleteHashedPAT2TokenIDIndex("someHashedToken")

View File

@@ -245,6 +245,11 @@ func (am *DefaultAccountManager) DeleteGroup(accountId, userId, groupID string)
return &GroupLinkError{"route", string(r.NetID)}
}
}
for _, g := range r.PeerGroups {
if g == groupID {
return &GroupLinkError{"route", string(r.NetID)}
}
}
}
// check DNS links

View File

@@ -70,6 +70,11 @@ func TestDefaultAccountManager_DeleteGroup(t *testing.T) {
"grp-for-route",
"route",
},
{
"route with peer groups",
"grp-for-route2",
"route",
},
{
"name server groups",
"grp-for-name-server-grp",
@@ -138,6 +143,14 @@ func initTestGroupAccount(am *DefaultAccountManager) (*Account, error) {
Peers: make([]string, 0),
}
groupForRoute2 := &nbgroup.Group{
ID: "grp-for-route2",
AccountID: "account-id",
Name: "Group for route",
Issued: nbgroup.GroupIssuedAPI,
Peers: make([]string, 0),
}
groupForNameServerGroups := &nbgroup.Group{
ID: "grp-for-name-server-grp",
AccountID: "account-id",
@@ -183,6 +196,11 @@ func initTestGroupAccount(am *DefaultAccountManager) (*Account, error) {
Groups: []string{groupForRoute.ID},
}
routePeerGroupResource := &route.Route{
ID: "example route with peer groups",
PeerGroups: []string{groupForRoute2.ID},
}
nameServerGroup := &nbdns.NameServerGroup{
ID: "example name server group",
Groups: []string{groupForNameServerGroups.ID},
@@ -209,6 +227,7 @@ func initTestGroupAccount(am *DefaultAccountManager) (*Account, error) {
}
account := newAccountWithId(accountID, groupAdminUserID, domain)
account.Routes[routeResource.ID] = routeResource
account.Routes[routePeerGroupResource.ID] = routePeerGroupResource
account.NameServerGroups[nameServerGroup.ID] = nameServerGroup
account.Policies = append(account.Policies, policy)
account.SetupKeys[setupKey.Id] = setupKey
@@ -220,6 +239,7 @@ func initTestGroupAccount(am *DefaultAccountManager) (*Account, error) {
}
_ = am.SaveGroup(accountID, groupAdminUserID, groupForRoute)
_ = am.SaveGroup(accountID, groupAdminUserID, groupForRoute2)
_ = am.SaveGroup(accountID, groupAdminUserID, groupForNameServerGroups)
_ = am.SaveGroup(accountID, groupAdminUserID, groupForPolicies)
_ = am.SaveGroup(accountID, groupAdminUserID, groupForSetupKeys)

View File

@@ -136,7 +136,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
peer, netMap, err := s.accountManager.SyncAndMarkPeer(peerKey.String(), realIP)
if err != nil {
return err
return mapError(err)
}
err = s.sendInitialSync(peerKey, peer, netMap, srv)
@@ -368,7 +368,7 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
})
if err != nil {
log.Warnf("failed logging in peer %s", peerKey)
log.Warnf("failed logging in peer %s: %s", peerKey, err)
return nil, mapError(err)
}

View File

@@ -3,12 +3,10 @@ package http
import (
"encoding/json"
"net/http"
"net/netip"
"regexp"
"slices"
"github.com/gorilla/mux"
"github.com/rs/xid"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/geolocation"
@@ -59,7 +57,7 @@ func (p *PostureChecksHandler) GetAllPostureChecks(w http.ResponseWriter, r *htt
postureChecks := []*api.PostureCheck{}
for _, postureCheck := range accountPostureChecks {
postureChecks = append(postureChecks, toPostureChecksResponse(postureCheck))
postureChecks = append(postureChecks, postureCheck.ToAPIResponse())
}
util.WriteJSONObject(w, postureChecks)
@@ -130,7 +128,7 @@ func (p *PostureChecksHandler) GetPostureCheck(w http.ResponseWriter, r *http.Re
return
}
util.WriteJSONObject(w, toPostureChecksResponse(postureChecks))
util.WriteJSONObject(w, postureChecks.ToAPIResponse())
}
// DeletePostureCheck handles posture check deletion request
@@ -178,55 +176,26 @@ func (p *PostureChecksHandler) savePostureChecks(
return
}
if postureChecksID == "" {
postureChecksID = xid.New().String()
}
postureChecks := posture.Checks{
ID: postureChecksID,
Name: req.Name,
Description: req.Description,
}
if nbVersionCheck := req.Checks.NbVersionCheck; nbVersionCheck != nil {
postureChecks.Checks.NBVersionCheck = &posture.NBVersionCheck{
MinVersion: nbVersionCheck.MinVersion,
}
}
if osVersionCheck := req.Checks.OsVersionCheck; osVersionCheck != nil {
postureChecks.Checks.OSVersionCheck = &posture.OSVersionCheck{
Android: (*posture.MinVersionCheck)(osVersionCheck.Android),
Darwin: (*posture.MinVersionCheck)(osVersionCheck.Darwin),
Ios: (*posture.MinVersionCheck)(osVersionCheck.Ios),
Linux: (*posture.MinKernelVersionCheck)(osVersionCheck.Linux),
Windows: (*posture.MinKernelVersionCheck)(osVersionCheck.Windows),
}
}
if geoLocationCheck := req.Checks.GeoLocationCheck; geoLocationCheck != nil {
if p.geolocationManager == nil {
// TODO: update error message to include geo db self hosted doc link when ready
util.WriteError(status.Errorf(status.PreconditionFailed, "Geo location database is not initialized"), w)
return
}
postureChecks.Checks.GeoLocationCheck = toPostureGeoLocationCheck(geoLocationCheck)
}
if peerNetworkRangeCheck := req.Checks.PeerNetworkRangeCheck; peerNetworkRangeCheck != nil {
postureChecks.Checks.PeerNetworkRangeCheck, err = toPeerNetworkRangeCheck(peerNetworkRangeCheck)
if err != nil {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid network prefix"), w)
return
}
}
if err := p.accountManager.SavePostureChecks(account.Id, user.Id, &postureChecks); err != nil {
postureChecks, err := posture.NewChecksFromAPIPostureCheckUpdate(req, postureChecksID)
if err != nil {
util.WriteError(err, w)
return
}
util.WriteJSONObject(w, toPostureChecksResponse(&postureChecks))
if err := p.accountManager.SavePostureChecks(account.Id, user.Id, postureChecks); err != nil {
util.WriteError(err, w)
return
}
util.WriteJSONObject(w, postureChecks.ToAPIResponse())
}
func validatePostureChecksUpdate(req api.PostureCheckUpdate) error {
@@ -294,105 +263,3 @@ func validatePostureChecksUpdate(req api.PostureCheckUpdate) error {
return nil
}
func toPostureChecksResponse(postureChecks *posture.Checks) *api.PostureCheck {
var checks api.Checks
if postureChecks.Checks.NBVersionCheck != nil {
checks.NbVersionCheck = &api.NBVersionCheck{
MinVersion: postureChecks.Checks.NBVersionCheck.MinVersion,
}
}
if postureChecks.Checks.OSVersionCheck != nil {
checks.OsVersionCheck = &api.OSVersionCheck{
Android: (*api.MinVersionCheck)(postureChecks.Checks.OSVersionCheck.Android),
Darwin: (*api.MinVersionCheck)(postureChecks.Checks.OSVersionCheck.Darwin),
Ios: (*api.MinVersionCheck)(postureChecks.Checks.OSVersionCheck.Ios),
Linux: (*api.MinKernelVersionCheck)(postureChecks.Checks.OSVersionCheck.Linux),
Windows: (*api.MinKernelVersionCheck)(postureChecks.Checks.OSVersionCheck.Windows),
}
}
if postureChecks.Checks.GeoLocationCheck != nil {
checks.GeoLocationCheck = toGeoLocationCheckResponse(postureChecks.Checks.GeoLocationCheck)
}
if postureChecks.Checks.PeerNetworkRangeCheck != nil {
checks.PeerNetworkRangeCheck = toPeerNetworkRangeCheckResponse(postureChecks.Checks.PeerNetworkRangeCheck)
}
return &api.PostureCheck{
Id: postureChecks.ID,
Name: postureChecks.Name,
Description: &postureChecks.Description,
Checks: checks,
}
}
func toGeoLocationCheckResponse(geoLocationCheck *posture.GeoLocationCheck) *api.GeoLocationCheck {
locations := make([]api.Location, 0, len(geoLocationCheck.Locations))
for _, loc := range geoLocationCheck.Locations {
l := loc // make G601 happy
var cityName *string
if loc.CityName != "" {
cityName = &l.CityName
}
locations = append(locations, api.Location{
CityName: cityName,
CountryCode: loc.CountryCode,
})
}
return &api.GeoLocationCheck{
Action: api.GeoLocationCheckAction(geoLocationCheck.Action),
Locations: locations,
}
}
func toPostureGeoLocationCheck(apiGeoLocationCheck *api.GeoLocationCheck) *posture.GeoLocationCheck {
locations := make([]posture.Location, 0, len(apiGeoLocationCheck.Locations))
for _, loc := range apiGeoLocationCheck.Locations {
cityName := ""
if loc.CityName != nil {
cityName = *loc.CityName
}
locations = append(locations, posture.Location{
CountryCode: loc.CountryCode,
CityName: cityName,
})
}
return &posture.GeoLocationCheck{
Action: string(apiGeoLocationCheck.Action),
Locations: locations,
}
}
func toPeerNetworkRangeCheckResponse(check *posture.PeerNetworkRangeCheck) *api.PeerNetworkRangeCheck {
netPrefixes := make([]string, 0, len(check.Ranges))
for _, netPrefix := range check.Ranges {
netPrefixes = append(netPrefixes, netPrefix.String())
}
return &api.PeerNetworkRangeCheck{
Ranges: netPrefixes,
Action: api.PeerNetworkRangeCheckAction(check.Action),
}
}
func toPeerNetworkRangeCheck(check *api.PeerNetworkRangeCheck) (*posture.PeerNetworkRangeCheck, error) {
prefixes := make([]netip.Prefix, 0)
for _, prefix := range check.Ranges {
parsedPrefix, err := netip.ParsePrefix(prefix)
if err != nil {
return nil, err
}
prefixes = append(prefixes, parsedPrefix)
}
return &posture.PeerNetworkRangeCheck{
Ranges: prefixes,
Action: string(check.Action),
}, nil
}

View File

@@ -154,7 +154,7 @@ func (zc *ZitadelCredentials) requestJWTToken() (*http.Response, error) {
data.Set("client_id", zc.clientConfig.ClientID)
data.Set("client_secret", zc.clientConfig.ClientSecret)
data.Set("grant_type", zc.clientConfig.GrantType)
data.Set("scope", "urn:zitadel:iam:org:project:id:zitadel:aud")
data.Set("scope", "openid urn:zitadel:iam:org:project:id:zitadel:aud")
payload := strings.NewReader(data.Encode())
req, err := http.NewRequest(http.MethodPost, zc.clientConfig.TokenEndpoint, payload)

View File

@@ -11,7 +11,7 @@ type IntegratedValidator interface {
ValidateExtraSettings(newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error
ValidatePeer(update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, error)
PreparePeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer
IsNotValidPeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool)
IsNotValidPeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error)
GetValidatedPeers(accountID string, groups map[string]*nbgroup.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error)
PeerDeleted(accountID, peerID string) error
SetPeerInvalidationListener(fn func(accountID string))

View File

@@ -405,10 +405,12 @@ func startManagement(t *testing.T, config *Config) (*grpc.Server, string, error)
return nil, "", err
}
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
store, err := NewStoreFromJson(config.Datadir, nil)
store, cleanUp, err := NewTestStoreFromJson(config.Datadir)
if err != nil {
return nil, "", err
}
t.Cleanup(cleanUp)
peersUpdateManager := NewPeersUpdateManager(nil)
eventStore := &activity.InMemoryEventStore{}
accountManager, err := BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted",

View File

@@ -469,8 +469,8 @@ func (MocIntegratedValidator) PreparePeer(accountID string, peer *nbpeer.Peer, p
return peer
}
func (MocIntegratedValidator) IsNotValidPeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool) {
return false, false
func (MocIntegratedValidator) IsNotValidPeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error) {
return false, false, nil
}
func (MocIntegratedValidator) PeerDeleted(_, _ string) error {
@@ -532,10 +532,11 @@ func startServer(config *server.Config) (*grpc.Server, net.Listener) {
Expect(err).NotTo(HaveOccurred())
s := grpc.NewServer()
store, err := server.NewStoreFromJson(config.Datadir, nil)
store, _, err := server.NewTestStoreFromJson(config.Datadir)
if err != nil {
log.Fatalf("failed creating a store: %s: %v", config.Datadir, err)
}
peersUpdateManager := server.NewPeersUpdateManager(nil)
eventStore := &activity.InMemoryEventStore{}
accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted",

View File

@@ -26,7 +26,7 @@ const (
// defaultPushInterval default interval to push metrics
defaultPushInterval = 24 * time.Hour
// requestTimeout http request timeout
requestTimeout = 30 * time.Second
requestTimeout = 45 * time.Second
)
type getTokenResponse struct {
@@ -98,10 +98,7 @@ func (w *Worker) Run() {
}
func (w *Worker) sendMetrics() error {
ctx, cancel := context.WithTimeout(w.ctx, requestTimeout)
defer cancel()
apiKey, err := getAPIKey(ctx)
apiKey, err := getAPIKey(w.ctx)
if err != nil {
return err
}
@@ -115,7 +112,7 @@ func (w *Worker) sendMetrics() error {
httpClient := http.Client{}
exportJobReq, err := createPostRequest(ctx, payloadEndpoint+"/capture/", payloadString)
exportJobReq, err := createPostRequest(w.ctx, payloadEndpoint+"/capture/", payloadString)
if err != nil {
return fmt.Errorf("unable to create metrics post request %v", err)
}
@@ -328,6 +325,8 @@ func (w *Worker) generateProperties() properties {
}
func getAPIKey(ctx context.Context) (string, error) {
ctx, cancel := context.WithTimeout(ctx, requestTimeout)
defer cancel()
httpClient := http.Client{}
@@ -375,6 +374,8 @@ func buildMetricsPayload(payload pushPayload) (string, error) {
}
func createPostRequest(ctx context.Context, endpoint string, payloadStr string) (*http.Request, error) {
ctx, cancel := context.WithTimeout(ctx, requestTimeout)
defer cancel()
reqURL := endpoint
payload := strings.NewReader(payloadStr)

View File

@@ -95,6 +95,7 @@ type MockAccountManager struct {
GetIdpManagerFunc func() idp.Manager
UpdateIntegratedValidatorGroupsFunc func(accountID string, userID string, groups []string) error
GroupValidationFunc func(accountId string, groups []string) (bool, error)
FindExistingPostureCheckFunc func(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
}
func (am *MockAccountManager) SyncAndMarkPeer(peerPubKey string, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, error) {
@@ -734,3 +735,11 @@ func (am *MockAccountManager) GroupValidation(accountId string, groups []string)
}
return false, status.Errorf(codes.Unimplemented, "method GroupValidation is not implemented")
}
// FindExistingPostureCheck mocks FindExistingPostureCheck of the AccountManager interface
func (am *MockAccountManager) FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) {
if am.FindExistingPostureCheckFunc != nil {
return am.FindExistingPostureCheckFunc(accountID, checks)
}
return nil, status.Errorf(codes.Unimplemented, "method FindExistingPostureCheck is not implemented")
}

View File

@@ -766,10 +766,11 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) {
func createNSStore(t *testing.T) (Store, error) {
t.Helper()
dataDir := t.TempDir()
store, err := NewStoreFromJson(dataDir, nil)
store, cleanUp, err := NewTestStoreFromJson(dataDir)
if err != nil {
return nil, err
}
t.Cleanup(cleanUp)
return store, nil
}

View File

@@ -335,24 +335,29 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P
}
upperKey := strings.ToUpper(setupKey)
var account *Account
var accountID string
var err error
addedByUser := false
if len(userID) > 0 {
addedByUser = true
account, err = am.Store.GetAccountByUser(userID)
accountID, err = am.Store.GetAccountIDByUserID(userID)
} else {
account, err = am.Store.GetAccountBySetupKey(setupKey)
accountID, err = am.Store.GetAccountIDBySetupKey(setupKey)
}
if err != nil {
return nil, nil, status.Errorf(status.NotFound, "failed adding new peer: account not found")
}
unlock := am.Store.AcquireAccountWriteLock(account.Id)
defer unlock()
unlock := am.Store.AcquireAccountWriteLock(accountID)
defer func() {
if unlock != nil {
unlock()
}
}()
var account *Account
// ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account)
account, err = am.Store.GetAccount(account.Id)
account, err = am.Store.GetAccount(accountID)
if err != nil {
return nil, nil, err
}
@@ -485,6 +490,10 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P
return nil, nil, err
}
// Account is saved, we can release the lock
unlock()
unlock = nil
opEvent.TargetID = newPeer.ID
opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain())
if !addedByUser {
@@ -507,7 +516,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P
func (am *DefaultAccountManager) SyncPeer(sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, error) {
peer, err := account.FindPeerByPubKey(sync.WireGuardPubKey)
if err != nil {
return nil, nil, status.Errorf(status.Unauthenticated, "peer is not registered")
return nil, nil, status.NewPeerNotRegisteredError()
}
err = checkIfPeerOwnerIsBlocked(peer, account)
@@ -515,11 +524,15 @@ func (am *DefaultAccountManager) SyncPeer(sync PeerSync, account *Account) (*nbp
return nil, nil, err
}
if peerLoginExpired(peer, account) {
if peerLoginExpired(peer, account.Settings) {
return nil, nil, status.Errorf(status.PermissionDenied, "peer login has expired, please log in once more")
}
peerNotValid, isStatusChanged := am.integratedPeerValidator.IsNotValidPeer(account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra)
peerNotValid, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra)
if err != nil {
return nil, nil, err
}
if peerNotValid {
emptyMap := &NetworkMap{
Network: account.Network.Copy(),
@@ -541,7 +554,7 @@ func (am *DefaultAccountManager) SyncPeer(sync PeerSync, account *Account) (*nbp
// LoginPeer logs in or registers a peer.
// If peer doesn't exist the function checks whether a setup key or a user is present and registers a new peer if so.
func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *NetworkMap, error) {
account, err := am.Store.GetAccountByPeerPubKey(login.WireGuardPubKey)
accountID, err := am.Store.GetAccountIDByPeerPubKey(login.WireGuardPubKey)
if err != nil {
if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound {
// we couldn't find this peer by its public key which can mean that peer hasn't been registered yet.
@@ -570,19 +583,59 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *Netw
return nil, nil, status.Errorf(status.Internal, "failed while logging in peer")
}
// we found the peer, and we follow a normal login flow
unlock := am.Store.AcquireAccountWriteLock(account.Id)
defer unlock()
peer, err := am.Store.GetPeerByPeerPubKey(login.WireGuardPubKey)
if err != nil {
return nil, nil, status.NewPeerNotRegisteredError()
}
accSettings, err := am.Store.GetAccountSettings(accountID)
if err != nil {
return nil, nil, status.Errorf(status.Internal, "failed to get account settings: %s", err)
}
var isWriteLock bool
// duplicated logic from after the lock to have an early exit
expired := peerLoginExpired(peer, accSettings)
switch {
case expired:
if err := checkAuth(login.UserID, peer); err != nil {
return nil, nil, err
}
isWriteLock = true
log.Debugf("peer login expired, acquiring write lock")
case peer.UpdateMetaIfNew(login.Meta):
isWriteLock = true
log.Debugf("peer changed meta, acquiring write lock")
default:
isWriteLock = false
log.Debugf("peer meta is the same, acquiring read lock")
}
var unlock func()
if isWriteLock {
unlock = am.Store.AcquireAccountWriteLock(accountID)
} else {
unlock = am.Store.AcquireAccountReadLock(accountID)
}
defer func() {
if unlock != nil {
unlock()
}
}()
// fetch the account from the store once more after acquiring lock to avoid concurrent updates inconsistencies
account, err = am.Store.GetAccount(account.Id)
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, nil, err
}
peer, err := account.FindPeerByPubKey(login.WireGuardPubKey)
peer, err = account.FindPeerByPubKey(login.WireGuardPubKey)
if err != nil {
return nil, nil, status.Errorf(status.Unauthenticated, "peer is not registered")
return nil, nil, status.NewPeerNotRegisteredError()
}
err = checkIfPeerOwnerIsBlocked(peer, account)
@@ -593,7 +646,7 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *Netw
// this flag prevents unnecessary calls to the persistent store.
shouldStoreAccount := false
updateRemotePeers := false
if peerLoginExpired(peer, account) {
if peerLoginExpired(peer, account.Settings) {
err = checkAuth(login.UserID, peer)
if err != nil {
return nil, nil, err
@@ -614,7 +667,10 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *Netw
am.StoreEvent(login.UserID, peer.ID, account.Id, activity.UserLoggedInPeer, peer.EventMeta(am.GetDNSDomain()))
}
isRequiresApproval, isStatusChanged := am.integratedPeerValidator.IsNotValidPeer(account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra)
isRequiresApproval, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra)
if err != nil {
return nil, nil, err
}
peer, updated := updatePeerMeta(peer, login.Meta, account)
if updated {
shouldStoreAccount = true
@@ -626,11 +682,17 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *Netw
}
if shouldStoreAccount {
if !isWriteLock {
log.Errorf("account %s should be stored but is not write locked", accountID)
return nil, nil, status.Errorf(status.Internal, "account should be stored but is not write locked")
}
err = am.Store.SaveAccount(account)
if err != nil {
return nil, nil, err
}
}
unlock()
unlock = nil
if updateRemotePeers || isStatusChanged {
am.updateAccountPeers(account)
@@ -676,9 +738,9 @@ func checkAuth(loginUserID string, peer *nbpeer.Peer) error {
return nil
}
func peerLoginExpired(peer *nbpeer.Peer, account *Account) bool {
expired, expiresIn := peer.LoginExpired(account.Settings.PeerLoginExpiration)
expired = account.Settings.PeerLoginExpirationEnabled && expired
func peerLoginExpired(peer *nbpeer.Peer, settings *Settings) bool {
expired, expiresIn := peer.LoginExpired(settings.PeerLoginExpiration)
expired = settings.PeerLoginExpirationEnabled && expired
if expired || peer.Status.LoginExpired {
log.Debugf("peer's %s login expired %v ago", peer.ID, expiresIn)
return true

View File

@@ -216,7 +216,9 @@ func (p *Peer) FQDN(dnsDomain string) string {
// EventMeta returns activity event meta related to the peer
func (p *Peer) EventMeta(dnsDomain string) map[string]any {
return map[string]any{"name": p.Name, "fqdn": p.FQDN(dnsDomain), "ip": p.IP, "created_at": p.CreatedAt}
return map[string]any{"name": p.Name, "fqdn": p.FQDN(dnsDomain), "ip": p.IP, "created_at": p.CreatedAt,
"location_city_name": p.Location.CityName, "location_country_code": p.Location.CountryCode,
"location_geo_name_id": p.Location.GeoNameID, "location_connection_ip": p.Location.ConnectionIP}
}
// Copy PeerStatus

View File

@@ -148,7 +148,7 @@ type Policy struct {
Enabled bool
// Rules of the policy
Rules []*PolicyRule `gorm:"foreignKey:PolicyID;references:id"`
Rules []*PolicyRule `gorm:"foreignKey:PolicyID;references:id;constraint:OnDelete:CASCADE;"`
// SourcePostureChecks are ID references to Posture checks for policy source groups
SourcePostureChecks []string `gorm:"serializer:json"`

View File

@@ -5,8 +5,11 @@ import (
"net/netip"
"github.com/hashicorp/go-version"
"github.com/rs/xid"
"github.com/netbirdio/netbird/management/server/http/api"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status"
)
const (
@@ -136,6 +139,96 @@ func (pc *Checks) GetChecks() []Check {
return checks
}
func NewChecksFromAPIPostureCheck(source api.PostureCheck) (*Checks, error) {
description := ""
if source.Description != nil {
description = *source.Description
}
return buildPostureCheck(source.Id, source.Name, description, source.Checks)
}
func NewChecksFromAPIPostureCheckUpdate(source api.PostureCheckUpdate, postureChecksID string) (*Checks, error) {
return buildPostureCheck(postureChecksID, source.Name, source.Description, *source.Checks)
}
func buildPostureCheck(postureChecksID string, name string, description string, checks api.Checks) (*Checks, error) {
if postureChecksID == "" {
postureChecksID = xid.New().String()
}
postureChecks := Checks{
ID: postureChecksID,
Name: name,
Description: description,
}
if nbVersionCheck := checks.NbVersionCheck; nbVersionCheck != nil {
postureChecks.Checks.NBVersionCheck = &NBVersionCheck{
MinVersion: nbVersionCheck.MinVersion,
}
}
if osVersionCheck := checks.OsVersionCheck; osVersionCheck != nil {
postureChecks.Checks.OSVersionCheck = &OSVersionCheck{
Android: (*MinVersionCheck)(osVersionCheck.Android),
Darwin: (*MinVersionCheck)(osVersionCheck.Darwin),
Ios: (*MinVersionCheck)(osVersionCheck.Ios),
Linux: (*MinKernelVersionCheck)(osVersionCheck.Linux),
Windows: (*MinKernelVersionCheck)(osVersionCheck.Windows),
}
}
if geoLocationCheck := checks.GeoLocationCheck; geoLocationCheck != nil {
postureChecks.Checks.GeoLocationCheck = toPostureGeoLocationCheck(geoLocationCheck)
}
var err error
if peerNetworkRangeCheck := checks.PeerNetworkRangeCheck; peerNetworkRangeCheck != nil {
postureChecks.Checks.PeerNetworkRangeCheck, err = toPeerNetworkRangeCheck(peerNetworkRangeCheck)
if err != nil {
return nil, status.Errorf(status.InvalidArgument, "invalid network prefix")
}
}
return &postureChecks, nil
}
func (pc *Checks) ToAPIResponse() *api.PostureCheck {
var checks api.Checks
if pc.Checks.NBVersionCheck != nil {
checks.NbVersionCheck = &api.NBVersionCheck{
MinVersion: pc.Checks.NBVersionCheck.MinVersion,
}
}
if pc.Checks.OSVersionCheck != nil {
checks.OsVersionCheck = &api.OSVersionCheck{
Android: (*api.MinVersionCheck)(pc.Checks.OSVersionCheck.Android),
Darwin: (*api.MinVersionCheck)(pc.Checks.OSVersionCheck.Darwin),
Ios: (*api.MinVersionCheck)(pc.Checks.OSVersionCheck.Ios),
Linux: (*api.MinKernelVersionCheck)(pc.Checks.OSVersionCheck.Linux),
Windows: (*api.MinKernelVersionCheck)(pc.Checks.OSVersionCheck.Windows),
}
}
if pc.Checks.GeoLocationCheck != nil {
checks.GeoLocationCheck = toGeoLocationCheckResponse(pc.Checks.GeoLocationCheck)
}
if pc.Checks.PeerNetworkRangeCheck != nil {
checks.PeerNetworkRangeCheck = toPeerNetworkRangeCheckResponse(pc.Checks.PeerNetworkRangeCheck)
}
return &api.PostureCheck{
Id: pc.ID,
Name: pc.Name,
Description: &pc.Description,
Checks: checks,
}
}
func (pc *Checks) Validate() error {
if check := pc.Checks.NBVersionCheck; check != nil {
if !isVersionValid(check.MinVersion) {
@@ -192,3 +285,70 @@ func isVersionValid(ver string) bool {
return false
}
func toGeoLocationCheckResponse(geoLocationCheck *GeoLocationCheck) *api.GeoLocationCheck {
locations := make([]api.Location, 0, len(geoLocationCheck.Locations))
for _, loc := range geoLocationCheck.Locations {
l := loc // make G601 happy
var cityName *string
if loc.CityName != "" {
cityName = &l.CityName
}
locations = append(locations, api.Location{
CityName: cityName,
CountryCode: loc.CountryCode,
})
}
return &api.GeoLocationCheck{
Action: api.GeoLocationCheckAction(geoLocationCheck.Action),
Locations: locations,
}
}
func toPostureGeoLocationCheck(apiGeoLocationCheck *api.GeoLocationCheck) *GeoLocationCheck {
locations := make([]Location, 0, len(apiGeoLocationCheck.Locations))
for _, loc := range apiGeoLocationCheck.Locations {
cityName := ""
if loc.CityName != nil {
cityName = *loc.CityName
}
locations = append(locations, Location{
CountryCode: loc.CountryCode,
CityName: cityName,
})
}
return &GeoLocationCheck{
Action: string(apiGeoLocationCheck.Action),
Locations: locations,
}
}
func toPeerNetworkRangeCheckResponse(check *PeerNetworkRangeCheck) *api.PeerNetworkRangeCheck {
netPrefixes := make([]string, 0, len(check.Ranges))
for _, netPrefix := range check.Ranges {
netPrefixes = append(netPrefixes, netPrefix.String())
}
return &api.PeerNetworkRangeCheck{
Ranges: netPrefixes,
Action: api.PeerNetworkRangeCheckAction(check.Action),
}
}
func toPeerNetworkRangeCheck(check *api.PeerNetworkRangeCheck) (*PeerNetworkRangeCheck, error) {
prefixes := make([]netip.Prefix, 0)
for _, prefix := range check.Ranges {
parsedPrefix, err := netip.ParsePrefix(prefix)
if err != nil {
return nil, err
}
prefixes = append(prefixes, parsedPrefix)
}
return &PeerNetworkRangeCheck{
Ranges: prefixes,
Action: string(check.Action),
}, nil
}

View File

@@ -3,8 +3,9 @@ package server
import (
"testing"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server/posture"
)
const (

View File

@@ -1021,10 +1021,11 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, error) {
func createRouterStore(t *testing.T) (Store, error) {
t.Helper()
dataDir := t.TempDir()
store, err := NewStoreFromJson(dataDir, nil)
store, cleanUp, err := NewTestStoreFromJson(dataDir)
if err != nil {
return nil, err
}
t.Cleanup(cleanUp)
return store, nil
}

View File

@@ -1,10 +1,9 @@
package server
import (
"encoding/json"
"errors"
"fmt"
"net"
"net/netip"
"path/filepath"
"runtime"
"strings"
@@ -12,6 +11,7 @@ import (
"time"
log "github.com/sirupsen/logrus"
"gorm.io/driver/postgres"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/clause"
@@ -20,7 +20,6 @@ import (
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/account"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/migration"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/status"
@@ -28,14 +27,14 @@ import (
"github.com/netbirdio/netbird/route"
)
// SqliteStore represents an account storage backed by a Sqlite DB persisted to disk
type SqliteStore struct {
// SqlStore represents an account storage backed by a Sql DB persisted to disk
type SqlStore struct {
db *gorm.DB
storeFile string
accountLocks sync.Map
globalAccountLock sync.Mutex
metrics telemetry.AppMetrics
installationPK int
storeEngine StoreEngine
}
type installation struct {
@@ -45,24 +44,8 @@ type installation struct {
type migrationFunc func(*gorm.DB) error
// NewSqliteStore restores a store from the file located in the datadir
func NewSqliteStore(dataDir string, metrics telemetry.AppMetrics) (*SqliteStore, error) {
storeStr := "store.db?cache=shared"
if runtime.GOOS == "windows" {
// Vo avoid `The process cannot access the file because it is being used by another process` on Windows
storeStr = "store.db"
}
file := filepath.Join(dataDir, storeStr)
db, err := gorm.Open(sqlite.Open(file), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
CreateBatchSize: 400,
PrepareStmt: true,
})
if err != nil {
return nil, err
}
// NewSqlStore creates a new SqlStore instance.
func NewSqlStore(db *gorm.DB, storeEngine StoreEngine, metrics telemetry.AppMetrics) (*SqlStore, error) {
sql, err := db.DB()
if err != nil {
return nil, err
@@ -82,33 +65,11 @@ func NewSqliteStore(dataDir string, metrics telemetry.AppMetrics) (*SqliteStore,
return nil, fmt.Errorf("auto migrate: %w", err)
}
return &SqliteStore{db: db, storeFile: file, metrics: metrics, installationPK: 1}, nil
}
// NewSqliteStoreFromFileStore restores a store from FileStore and stores SQLite DB in the file located in datadir
func NewSqliteStoreFromFileStore(filestore *FileStore, dataDir string, metrics telemetry.AppMetrics) (*SqliteStore, error) {
store, err := NewSqliteStore(dataDir, metrics)
if err != nil {
return nil, err
}
err = store.SaveInstallationID(filestore.InstallationID)
if err != nil {
return nil, err
}
for _, account := range filestore.GetAllAccounts() {
err := store.SaveAccount(account)
if err != nil {
return nil, err
}
}
return store, nil
return &SqlStore{db: db, storeEngine: storeEngine, metrics: metrics, installationPK: 1}, nil
}
// AcquireGlobalLock acquires global lock across all the accounts and returns a function that releases the lock
func (s *SqliteStore) AcquireGlobalLock() (unlock func()) {
func (s *SqlStore) AcquireGlobalLock() (unlock func()) {
log.Tracef("acquiring global lock")
start := time.Now()
s.globalAccountLock.Lock()
@@ -127,7 +88,7 @@ func (s *SqliteStore) AcquireGlobalLock() (unlock func()) {
return unlock
}
func (s *SqliteStore) AcquireAccountWriteLock(accountID string) (unlock func()) {
func (s *SqlStore) AcquireAccountWriteLock(accountID string) (unlock func()) {
log.Tracef("acquiring write lock for account %s", accountID)
start := time.Now()
@@ -143,7 +104,7 @@ func (s *SqliteStore) AcquireAccountWriteLock(accountID string) (unlock func())
return unlock
}
func (s *SqliteStore) AcquireAccountReadLock(accountID string) (unlock func()) {
func (s *SqlStore) AcquireAccountReadLock(accountID string) (unlock func()) {
log.Tracef("acquiring read lock for account %s", accountID)
start := time.Now()
@@ -159,7 +120,7 @@ func (s *SqliteStore) AcquireAccountReadLock(accountID string) (unlock func()) {
return unlock
}
func (s *SqliteStore) SaveAccount(account *Account) error {
func (s *SqlStore) SaveAccount(account *Account) error {
start := time.Now()
for _, key := range account.SetupKeys {
@@ -225,12 +186,12 @@ func (s *SqliteStore) SaveAccount(account *Account) error {
if s.metrics != nil {
s.metrics.StoreMetrics().CountPersistenceDuration(took)
}
log.Debugf("took %d ms to persist an account to the SQLite", took.Milliseconds())
log.Debugf("took %d ms to persist an account to the store", took.Milliseconds())
return err
}
func (s *SqliteStore) DeleteAccount(account *Account) error {
func (s *SqlStore) DeleteAccount(account *Account) error {
start := time.Now()
err := s.db.Transaction(func(tx *gorm.DB) error {
@@ -256,19 +217,19 @@ func (s *SqliteStore) DeleteAccount(account *Account) error {
if s.metrics != nil {
s.metrics.StoreMetrics().CountPersistenceDuration(took)
}
log.Debugf("took %d ms to delete an account to the SQLite", took.Milliseconds())
log.Debugf("took %d ms to delete an account to the store", took.Milliseconds())
return err
}
func (s *SqliteStore) SaveInstallationID(ID string) error {
func (s *SqlStore) SaveInstallationID(ID string) error {
installation := installation{InstallationIDValue: ID}
installation.ID = uint(s.installationPK)
return s.db.Clauses(clause.OnConflict{UpdateAll: true}).Create(&installation).Error
}
func (s *SqliteStore) GetInstallationID() string {
func (s *SqlStore) GetInstallationID() string {
var installation installation
if result := s.db.First(&installation, "id = ?", s.installationPK); result.Error != nil {
@@ -278,7 +239,7 @@ func (s *SqliteStore) GetInstallationID() string {
return installation.InstallationIDValue
}
func (s *SqliteStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error {
func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error {
var peerCopy nbpeer.Peer
peerCopy.Status = &peerStatus
result := s.db.Model(&nbpeer.Peer{}).
@@ -296,7 +257,7 @@ func (s *SqliteStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer
return nil
}
func (s *SqliteStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.Peer) error {
func (s *SqlStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.Peer) error {
// To maintain data integrity, we create a copy of the peer's location to prevent unintended updates to other fields.
var peerCopy nbpeer.Peer
// Since the location field has been migrated to JSON serialization,
@@ -318,17 +279,17 @@ func (s *SqliteStore) SavePeerLocation(accountID string, peerWithLocation *nbpee
return nil
}
// DeleteHashedPAT2TokenIDIndex is noop in Sqlite
func (s *SqliteStore) DeleteHashedPAT2TokenIDIndex(hashedToken string) error {
// DeleteHashedPAT2TokenIDIndex is noop in SqlStore
func (s *SqlStore) DeleteHashedPAT2TokenIDIndex(hashedToken string) error {
return nil
}
// DeleteTokenID2UserIDIndex is noop in Sqlite
func (s *SqliteStore) DeleteTokenID2UserIDIndex(tokenID string) error {
// DeleteTokenID2UserIDIndex is noop in SqlStore
func (s *SqlStore) DeleteTokenID2UserIDIndex(tokenID string) error {
return nil
}
func (s *SqliteStore) GetAccountByPrivateDomain(domain string) (*Account, error) {
func (s *SqlStore) GetAccountByPrivateDomain(domain string) (*Account, error) {
var account Account
result := s.db.First(&account, "domain = ? and is_domain_primary_account = ? and domain_category = ?",
@@ -345,7 +306,7 @@ func (s *SqliteStore) GetAccountByPrivateDomain(domain string) (*Account, error)
return s.GetAccount(account.Id)
}
func (s *SqliteStore) GetAccountBySetupKey(setupKey string) (*Account, error) {
func (s *SqlStore) GetAccountBySetupKey(setupKey string) (*Account, error) {
var key SetupKey
result := s.db.Select("account_id").First(&key, "key = ?", strings.ToUpper(setupKey))
if result.Error != nil {
@@ -363,7 +324,7 @@ func (s *SqliteStore) GetAccountBySetupKey(setupKey string) (*Account, error) {
return s.GetAccount(key.AccountID)
}
func (s *SqliteStore) GetTokenIDByHashedToken(hashedToken string) (string, error) {
func (s *SqlStore) GetTokenIDByHashedToken(hashedToken string) (string, error) {
var token PersonalAccessToken
result := s.db.First(&token, "hashed_token = ?", hashedToken)
if result.Error != nil {
@@ -377,7 +338,7 @@ func (s *SqliteStore) GetTokenIDByHashedToken(hashedToken string) (string, error
return token.ID, nil
}
func (s *SqliteStore) GetUserByTokenID(tokenID string) (*User, error) {
func (s *SqlStore) GetUserByTokenID(tokenID string) (*User, error) {
var token PersonalAccessToken
result := s.db.First(&token, "id = ?", tokenID)
if result.Error != nil {
@@ -406,7 +367,7 @@ func (s *SqliteStore) GetUserByTokenID(tokenID string) (*User, error) {
return &user, nil
}
func (s *SqliteStore) GetAllAccounts() (all []*Account) {
func (s *SqlStore) GetAllAccounts() (all []*Account) {
var accounts []Account
result := s.db.Find(&accounts)
if result.Error != nil {
@@ -422,7 +383,7 @@ func (s *SqliteStore) GetAllAccounts() (all []*Account) {
return all
}
func (s *SqliteStore) GetAccount(accountID string) (*Account, error) {
func (s *SqlStore) GetAccount(accountID string) (*Account, error) {
var account Account
result := s.db.Model(&account).
@@ -430,7 +391,7 @@ func (s *SqliteStore) GetAccount(accountID string) (*Account, error) {
Preload(clause.Associations).
First(&account, "id = ?", accountID)
if result.Error != nil {
log.Errorf("error when getting account from the store: %s", result.Error)
log.Errorf("error when getting account %s from the store: %s", accountID, result.Error)
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found")
}
@@ -490,14 +451,13 @@ func (s *SqliteStore) GetAccount(accountID string) (*Account, error) {
return &account, nil
}
func (s *SqliteStore) GetAccountByUser(userID string) (*Account, error) {
func (s *SqlStore) GetAccountByUser(userID string) (*Account, error) {
var user User
result := s.db.Select("account_id").First(&user, "id = ?", userID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
log.Errorf("error when getting user from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting account from store")
}
@@ -508,7 +468,7 @@ func (s *SqliteStore) GetAccountByUser(userID string) (*Account, error) {
return s.GetAccount(user.AccountID)
}
func (s *SqliteStore) GetAccountByPeerID(peerID string) (*Account, error) {
func (s *SqlStore) GetAccountByPeerID(peerID string) (*Account, error) {
var peer nbpeer.Peer
result := s.db.Select("account_id").First(&peer, "id = ?", peerID)
if result.Error != nil {
@@ -526,7 +486,7 @@ func (s *SqliteStore) GetAccountByPeerID(peerID string) (*Account, error) {
return s.GetAccount(peer.AccountID)
}
func (s *SqliteStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) {
func (s *SqlStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) {
var peer nbpeer.Peer
result := s.db.Select("account_id").First(&peer, "key = ?", peerKey)
@@ -545,7 +505,7 @@ func (s *SqliteStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) {
return s.GetAccount(peer.AccountID)
}
func (s *SqliteStore) GetAccountIDByPeerPubKey(peerKey string) (string, error) {
func (s *SqlStore) GetAccountIDByPeerPubKey(peerKey string) (string, error) {
var peer nbpeer.Peer
var accountID string
result := s.db.Model(&peer).Select("account_id").Where("key = ?", peerKey).First(&accountID)
@@ -560,8 +520,63 @@ func (s *SqliteStore) GetAccountIDByPeerPubKey(peerKey string) (string, error) {
return accountID, nil
}
func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) {
var user User
var accountID string
result := s.db.Model(&user).Select("account_id").Where("id = ?", userID).First(&accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
}
return "", status.Errorf(status.Internal, "issue getting account from store")
}
return accountID, nil
}
func (s *SqlStore) GetAccountIDBySetupKey(setupKey string) (string, error) {
var key SetupKey
var accountID string
result := s.db.Model(&key).Select("account_id").Where("key = ?", strings.ToUpper(setupKey)).First(&accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
}
log.Errorf("error when getting setup key from the store: %s", result.Error)
return "", status.Errorf(status.Internal, "issue getting setup key from store")
}
return accountID, nil
}
func (s *SqlStore) GetPeerByPeerPubKey(peerKey string) (*nbpeer.Peer, error) {
var peer nbpeer.Peer
result := s.db.First(&peer, "key = ?", peerKey)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "peer not found")
}
log.Errorf("error when getting peer from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting peer from store")
}
return &peer, nil
}
func (s *SqlStore) GetAccountSettings(accountID string) (*Settings, error) {
var accountSettings AccountSettings
if err := s.db.Model(&Account{}).Where("id = ?", accountID).First(&accountSettings).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "settings not found")
}
log.Errorf("error when getting settings from the store: %s", err)
return nil, status.Errorf(status.Internal, "issue getting settings from store")
}
return accountSettings.Settings, nil
}
// SaveUserLastLogin stores the last login time for a user in DB.
func (s *SqliteStore) SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error {
func (s *SqlStore) SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error {
var user User
result := s.db.First(&user, "account_id = ? and id = ?", accountID, userID)
@@ -569,7 +584,6 @@ func (s *SqliteStore) SaveUserLastLogin(accountID, userID string, lastLogin time
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.Errorf(status.NotFound, "user %s not found", userID)
}
log.Errorf("error when getting user from the store: %s", result.Error)
return status.Errorf(status.Internal, "issue getting user from store")
}
@@ -578,8 +592,23 @@ func (s *SqliteStore) SaveUserLastLogin(accountID, userID string, lastLogin time
return s.db.Save(user).Error
}
func (s *SqlStore) GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) {
definitionJSON, err := json.Marshal(checks)
if err != nil {
return nil, err
}
var postureCheck posture.Checks
err = s.db.Where("account_id = ? AND checks = ?", accountID, string(definitionJSON)).First(&postureCheck).Error
if err != nil {
return nil, err
}
return &postureCheck, nil
}
// Close closes the underlying DB connection
func (s *SqliteStore) Close() error {
func (s *SqlStore) Close() error {
sql, err := s.db.DB()
if err != nil {
return fmt.Errorf("get db: %w", err)
@@ -587,40 +616,85 @@ func (s *SqliteStore) Close() error {
return sql.Close()
}
// GetStoreEngine returns SqliteStoreEngine
func (s *SqliteStore) GetStoreEngine() StoreEngine {
return SqliteStoreEngine
// GetStoreEngine returns underlying store engine
func (s *SqlStore) GetStoreEngine() StoreEngine {
return s.storeEngine
}
// migrate migrates the SQLite database to the latest schema
func migrate(db *gorm.DB) error {
migrations := getMigrations()
// NewSqliteStore creates a new SQLite store.
func NewSqliteStore(dataDir string, metrics telemetry.AppMetrics) (*SqlStore, error) {
storeStr := "store.db?cache=shared"
if runtime.GOOS == "windows" {
// Vo avoid `The process cannot access the file because it is being used by another process` on Windows
storeStr = "store.db"
}
for _, m := range migrations {
if err := m(db); err != nil {
return err
file := filepath.Join(dataDir, storeStr)
db, err := gorm.Open(sqlite.Open(file), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
CreateBatchSize: 400,
PrepareStmt: true,
})
if err != nil {
return nil, err
}
return NewSqlStore(db, SqliteStoreEngine, metrics)
}
// NewPostgresqlStore creates a new Postgres store.
func NewPostgresqlStore(dsn string, metrics telemetry.AppMetrics) (*SqlStore, error) {
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
PrepareStmt: true,
})
if err != nil {
return nil, err
}
return NewSqlStore(db, PostgresStoreEngine, metrics)
}
// NewSqliteStoreFromFileStore restores a store from FileStore and stores SQLite DB in the file located in datadir.
func NewSqliteStoreFromFileStore(fileStore *FileStore, dataDir string, metrics telemetry.AppMetrics) (*SqlStore, error) {
store, err := NewSqliteStore(dataDir, metrics)
if err != nil {
return nil, err
}
err = store.SaveInstallationID(fileStore.InstallationID)
if err != nil {
return nil, err
}
for _, account := range fileStore.GetAllAccounts() {
err := store.SaveAccount(account)
if err != nil {
return nil, err
}
}
return nil
return store, nil
}
func getMigrations() []migrationFunc {
return []migrationFunc{
func(db *gorm.DB) error {
return migration.MigrateFieldFromGobToJSON[Account, net.IPNet](db, "network_net")
},
func(db *gorm.DB) error {
return migration.MigrateFieldFromGobToJSON[route.Route, netip.Prefix](db, "network")
},
func(db *gorm.DB) error {
return migration.MigrateFieldFromGobToJSON[route.Route, []string](db, "peer_groups")
},
func(db *gorm.DB) error {
return migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](db, "location_connection_ip", "")
},
func(db *gorm.DB) error {
return migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](db, "ip", "idx_peers_account_id_ip")
},
// NewPostgresqlStoreFromFileStore restores a store from FileStore and stores Postgres DB.
func NewPostgresqlStoreFromFileStore(fileStore *FileStore, dsn string, metrics telemetry.AppMetrics) (*SqlStore, error) {
store, err := NewPostgresqlStore(dsn, metrics)
if err != nil {
return nil, err
}
err = store.SaveInstallationID(fileStore.InstallationID)
if err != nil {
return nil, err
}
for _, account := range fileStore.GetAllAccounts() {
err := store.SaveAccount(account)
if err != nil {
return nil, err
}
}
return store, nil
}

View File

@@ -5,6 +5,7 @@ import (
"math/rand"
"net"
"net/netip"
"os"
"path/filepath"
"runtime"
"testing"
@@ -12,6 +13,7 @@ import (
nbdns "github.com/netbirdio/netbird/dns"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/testutil"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
@@ -569,7 +571,7 @@ func TestMigrate(t *testing.T) {
require.NoError(t, err, "Migration should not fail on migrated db")
}
func newSqliteStore(t *testing.T) *SqliteStore {
func newSqliteStore(t *testing.T) *SqlStore {
t.Helper()
store, err := NewSqliteStore(t.TempDir(), nil)
@@ -579,7 +581,7 @@ func newSqliteStore(t *testing.T) *SqliteStore {
return store
}
func newSqliteStoreFromFile(t *testing.T, filename string) *SqliteStore {
func newSqliteStoreFromFile(t *testing.T, filename string) *SqlStore {
t.Helper()
storeDir := t.TempDir()
@@ -613,3 +615,298 @@ func newAccount(store Store, id int) error {
return store.SaveAccount(account)
}
func newPostgresqlStore(t *testing.T) *SqlStore {
t.Helper()
cleanUp, err := testutil.CreatePGDB()
if err != nil {
t.Fatal(err)
}
t.Cleanup(cleanUp)
postgresDsn, ok := os.LookupEnv(postgresDsnEnv)
if !ok {
t.Fatalf("could not initialize postgresql store: %s is not set", postgresDsnEnv)
}
store, err := NewPostgresqlStore(postgresDsn, nil)
if err != nil {
t.Fatalf("could not initialize postgresql store: %s", err)
}
require.NoError(t, err)
require.NotNil(t, store)
return store
}
func newPostgresqlStoreFromFile(t *testing.T, filename string) *SqlStore {
t.Helper()
storeDir := t.TempDir()
err := util.CopyFileContents(filename, filepath.Join(storeDir, "store.json"))
require.NoError(t, err)
fStore, err := NewFileStore(storeDir, nil)
require.NoError(t, err)
cleanUp, err := testutil.CreatePGDB()
if err != nil {
t.Fatal(err)
}
t.Cleanup(cleanUp)
postgresDsn, ok := os.LookupEnv(postgresDsnEnv)
if !ok {
t.Fatalf("could not initialize postgresql store: %s is not set", postgresDsnEnv)
}
store, err := NewPostgresqlStoreFromFileStore(fStore, postgresDsn, nil)
require.NoError(t, err)
require.NotNil(t, store)
return store
}
func TestPostgresql_NewStore(t *testing.T) {
if runtime.GOOS != "linux" {
t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS)
}
store := newPostgresqlStore(t)
if len(store.GetAllAccounts()) != 0 {
t.Errorf("expected to create a new empty Accounts map when creating a new FileStore")
}
}
func TestPostgresql_SaveAccount(t *testing.T) {
if runtime.GOOS != "linux" {
t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS)
}
store := newPostgresqlStore(t)
account := newAccountWithId("account_id", "testuser", "")
setupKey := GenerateDefaultSetupKey()
account.SetupKeys[setupKey.Key] = setupKey
account.Peers["testpeer"] = &nbpeer.Peer{
Key: "peerkey",
SetupKey: "peerkeysetupkey",
IP: net.IP{127, 0, 0, 1},
Meta: nbpeer.PeerSystemMeta{},
Name: "peer name",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
}
err := store.SaveAccount(account)
require.NoError(t, err)
account2 := newAccountWithId("account_id2", "testuser2", "")
setupKey = GenerateDefaultSetupKey()
account2.SetupKeys[setupKey.Key] = setupKey
account2.Peers["testpeer2"] = &nbpeer.Peer{
Key: "peerkey2",
SetupKey: "peerkeysetupkey2",
IP: net.IP{127, 0, 0, 2},
Meta: nbpeer.PeerSystemMeta{},
Name: "peer name 2",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
}
err = store.SaveAccount(account2)
require.NoError(t, err)
if len(store.GetAllAccounts()) != 2 {
t.Errorf("expecting 2 Accounts to be stored after SaveAccount()")
}
a, err := store.GetAccount(account.Id)
if a == nil {
t.Errorf("expecting Account to be stored after SaveAccount(): %v", err)
}
if a != nil && len(a.Policies) != 1 {
t.Errorf("expecting Account to have one policy stored after SaveAccount(), got %d", len(a.Policies))
}
if a != nil && len(a.Policies[0].Rules) != 1 {
t.Errorf("expecting Account to have one policy rule stored after SaveAccount(), got %d", len(a.Policies[0].Rules))
return
}
if a, err := store.GetAccountByPeerPubKey("peerkey"); a == nil {
t.Errorf("expecting PeerKeyID2AccountID index updated after SaveAccount(): %v", err)
}
if a, err := store.GetAccountByUser("testuser"); a == nil {
t.Errorf("expecting UserID2AccountID index updated after SaveAccount(): %v", err)
}
if a, err := store.GetAccountByPeerID("testpeer"); a == nil {
t.Errorf("expecting PeerID2AccountID index updated after SaveAccount(): %v", err)
}
if a, err := store.GetAccountBySetupKey(setupKey.Key); a == nil {
t.Errorf("expecting SetupKeyID2AccountID index updated after SaveAccount(): %v", err)
}
}
func TestPostgresql_DeleteAccount(t *testing.T) {
if runtime.GOOS != "linux" {
t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS)
}
store := newPostgresqlStore(t)
testUserID := "testuser"
user := NewAdminUser(testUserID)
user.PATs = map[string]*PersonalAccessToken{"testtoken": {
ID: "testtoken",
Name: "test token",
}}
account := newAccountWithId("account_id", testUserID, "")
setupKey := GenerateDefaultSetupKey()
account.SetupKeys[setupKey.Key] = setupKey
account.Peers["testpeer"] = &nbpeer.Peer{
Key: "peerkey",
SetupKey: "peerkeysetupkey",
IP: net.IP{127, 0, 0, 1},
Meta: nbpeer.PeerSystemMeta{},
Name: "peer name",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
}
account.Users[testUserID] = user
err := store.SaveAccount(account)
require.NoError(t, err)
if len(store.GetAllAccounts()) != 1 {
t.Errorf("expecting 1 Accounts to be stored after SaveAccount()")
}
err = store.DeleteAccount(account)
require.NoError(t, err)
if len(store.GetAllAccounts()) != 0 {
t.Errorf("expecting 0 Accounts to be stored after DeleteAccount()")
}
_, err = store.GetAccountByPeerPubKey("peerkey")
require.Error(t, err, "expecting error after removing DeleteAccount when getting account by peer public key")
_, err = store.GetAccountByUser("testuser")
require.Error(t, err, "expecting error after removing DeleteAccount when getting account by user")
_, err = store.GetAccountByPeerID("testpeer")
require.Error(t, err, "expecting error after removing DeleteAccount when getting account by peer id")
_, err = store.GetAccountBySetupKey(setupKey.Key)
require.Error(t, err, "expecting error after removing DeleteAccount when getting account by setup key")
_, err = store.GetAccount(account.Id)
require.Error(t, err, "expecting error after removing DeleteAccount when getting account by id")
for _, policy := range account.Policies {
var rules []*PolicyRule
err = store.db.Model(&PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error
require.NoError(t, err, "expecting no error after removing DeleteAccount when searching for policy rules")
require.Len(t, rules, 0, "expecting no policy rules to be found after removing DeleteAccount")
}
for _, accountUser := range account.Users {
var pats []*PersonalAccessToken
err = store.db.Model(&PersonalAccessToken{}).Find(&pats, "user_id = ?", accountUser.Id).Error
require.NoError(t, err, "expecting no error after removing DeleteAccount when searching for personal access token")
require.Len(t, pats, 0, "expecting no personal access token to be found after removing DeleteAccount")
}
}
func TestPostgresql_SavePeerStatus(t *testing.T) {
if runtime.GOOS != "linux" {
t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS)
}
store := newPostgresqlStoreFromFile(t, "testdata/store.json")
account, err := store.GetAccount("bf1c8084-ba50-4ce7-9439-34653001fc3b")
require.NoError(t, err)
// save status of non-existing peer
newStatus := nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}
err = store.SavePeerStatus(account.Id, "non-existing-peer", newStatus)
assert.Error(t, err)
// save new status of existing peer
account.Peers["testpeer"] = &nbpeer.Peer{
Key: "peerkey",
ID: "testpeer",
SetupKey: "peerkeysetupkey",
IP: net.IP{127, 0, 0, 1},
Meta: nbpeer.PeerSystemMeta{},
Name: "peer name",
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now().UTC()},
}
err = store.SaveAccount(account)
require.NoError(t, err)
err = store.SavePeerStatus(account.Id, "testpeer", newStatus)
require.NoError(t, err)
account, err = store.GetAccount(account.Id)
require.NoError(t, err)
actual := account.Peers["testpeer"].Status
assert.Equal(t, newStatus.Connected, actual.Connected)
}
func TestPostgresql_TestGetAccountByPrivateDomain(t *testing.T) {
if runtime.GOOS != "linux" {
t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS)
}
store := newPostgresqlStoreFromFile(t, "testdata/store.json")
existingDomain := "test.com"
account, err := store.GetAccountByPrivateDomain(existingDomain)
require.NoError(t, err, "should found account")
require.Equal(t, existingDomain, account.Domain, "domains should match")
_, err = store.GetAccountByPrivateDomain("missing-domain.com")
require.Error(t, err, "should return error on domain lookup")
}
func TestPostgresql_GetTokenIDByHashedToken(t *testing.T) {
if runtime.GOOS != "linux" {
t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS)
}
store := newPostgresqlStoreFromFile(t, "testdata/store.json")
hashed := "SoMeHaShEdToKeN"
id := "9dj38s35-63fb-11ec-90d6-0242ac120003"
token, err := store.GetTokenIDByHashedToken(hashed)
require.NoError(t, err)
require.Equal(t, id, token)
}
func TestPostgresql_GetUserByTokenID(t *testing.T) {
if runtime.GOOS != "linux" {
t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS)
}
store := newPostgresqlStoreFromFile(t, "testdata/store.json")
id := "9dj38s35-63fb-11ec-90d6-0242ac120003"
user, err := store.GetUserByTokenID(id)
require.NoError(t, err)
require.Equal(t, id, user.PATs[id].ID)
}

View File

@@ -75,3 +75,23 @@ func FromError(err error) (s *Error, ok bool) {
}
return nil, false
}
// NewPeerNotFoundError creates a new Error with NotFound type for a missing peer
func NewPeerNotFoundError(peerKey string) error {
return Errorf(NotFound, "peer not found: %s", peerKey)
}
// NewAccountNotFoundError creates a new Error with NotFound type for a missing account
func NewAccountNotFoundError(accountKey string) error {
return Errorf(NotFound, "account not found: %s", accountKey)
}
// NewUserNotFoundError creates a new Error with NotFound type for a missing user
func NewUserNotFoundError(userKey string) error {
return Errorf(NotFound, "user not found: %s", userKey)
}
// NewPeerNotRegisteredError creates a new Error with NotFound type for a missing peer
func NewPeerNotRegisteredError() error {
return Errorf(Unauthenticated, "peer is not registered")
}

View File

@@ -2,15 +2,22 @@ package server
import (
"fmt"
"net"
"net/netip"
"os"
"path/filepath"
"strings"
"time"
log "github.com/sirupsen/logrus"
"gorm.io/gorm"
"github.com/netbirdio/netbird/management/server/migration"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/testutil"
"github.com/netbirdio/netbird/route"
)
type Store interface {
@@ -20,11 +27,14 @@ type Store interface {
GetAccountByUser(userID string) (*Account, error)
GetAccountByPeerPubKey(peerKey string) (*Account, error)
GetAccountIDByPeerPubKey(peerKey string) (string, error)
GetAccountIDByUserID(peerKey string) (string, error)
GetAccountIDBySetupKey(peerKey string) (string, error)
GetAccountByPeerID(peerID string) (*Account, error)
GetAccountBySetupKey(setupKey string) (*Account, error) // todo use key hash later
GetAccountByPrivateDomain(domain string) (*Account, error)
GetTokenIDByHashedToken(secret string) (string, error)
GetUserByTokenID(tokenID string) (*User, error)
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
SaveAccount(account *Account) error
DeleteHashedPAT2TokenIDIndex(hashedToken string) error
DeleteTokenID2UserIDIndex(tokenID string) error
@@ -44,13 +54,18 @@ type Store interface {
// GetStoreEngine should return StoreEngine of the current store implementation.
// This is also a method of metrics.DataSource interface.
GetStoreEngine() StoreEngine
GetPeerByPeerPubKey(peerKey string) (*nbpeer.Peer, error)
GetAccountSettings(accountID string) (*Settings, error)
}
type StoreEngine string
const (
FileStoreEngine StoreEngine = "jsonfile"
SqliteStoreEngine StoreEngine = "sqlite"
FileStoreEngine StoreEngine = "jsonfile"
SqliteStoreEngine StoreEngine = "sqlite"
PostgresStoreEngine StoreEngine = "postgres"
postgresDsnEnv = "NETBIRD_STORE_ENGINE_POSTGRES_DSN"
)
func getStoreEngineFromEnv() StoreEngine {
@@ -61,8 +76,7 @@ func getStoreEngineFromEnv() StoreEngine {
}
value := StoreEngine(strings.ToLower(kind))
if value == FileStoreEngine || value == SqliteStoreEngine {
if value == FileStoreEngine || value == SqliteStoreEngine || value == PostgresStoreEngine {
return value
}
@@ -94,18 +108,60 @@ func NewStore(kind StoreEngine, dataDir string, metrics telemetry.AppMetrics) (S
case SqliteStoreEngine:
log.Info("using SQLite store engine")
return NewSqliteStore(dataDir, metrics)
case PostgresStoreEngine:
log.Info("using Postgres store engine")
dsn, ok := os.LookupEnv(postgresDsnEnv)
if !ok {
return nil, fmt.Errorf("%s is not set", postgresDsnEnv)
}
return NewPostgresqlStore(dsn, metrics)
default:
return nil, fmt.Errorf("unsupported kind of store %s", kind)
}
}
// NewStoreFromJson is only used in tests
func NewStoreFromJson(dataDir string, metrics telemetry.AppMetrics) (Store, error) {
// migrate migrates the SQLite database to the latest schema
func migrate(db *gorm.DB) error {
migrations := getMigrations()
for _, m := range migrations {
if err := m(db); err != nil {
return err
}
}
return nil
}
func getMigrations() []migrationFunc {
return []migrationFunc{
func(db *gorm.DB) error {
return migration.MigrateFieldFromGobToJSON[Account, net.IPNet](db, "network_net")
},
func(db *gorm.DB) error {
return migration.MigrateFieldFromGobToJSON[route.Route, netip.Prefix](db, "network")
},
func(db *gorm.DB) error {
return migration.MigrateFieldFromGobToJSON[route.Route, []string](db, "peer_groups")
},
func(db *gorm.DB) error {
return migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](db, "location_connection_ip", "")
},
func(db *gorm.DB) error {
return migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](db, "ip", "idx_peers_account_id_ip")
},
}
}
// NewTestStoreFromJson is only used in tests
func NewTestStoreFromJson(dataDir string) (Store, func(), error) {
fstore, err := NewFileStore(dataDir, nil)
if err != nil {
return nil, err
return nil, nil, err
}
cleanUp := func() {}
// if store engine is not set in the config we first try to evaluate NETBIRD_STORE_ENGINE
kind := getStoreEngineFromEnv()
if kind == "" {
@@ -115,10 +171,34 @@ func NewStoreFromJson(dataDir string, metrics telemetry.AppMetrics) (Store, erro
switch kind {
case FileStoreEngine:
return fstore, nil
return fstore, cleanUp, nil
case SqliteStoreEngine:
return NewSqliteStoreFromFileStore(fstore, dataDir, metrics)
store, err := NewSqliteStoreFromFileStore(fstore, dataDir, nil)
if err != nil {
return nil, nil, err
}
return store, cleanUp, nil
case PostgresStoreEngine:
cleanUp, err = testutil.CreatePGDB()
if err != nil {
return nil, nil, err
}
dsn, ok := os.LookupEnv(postgresDsnEnv)
if !ok {
return nil, nil, fmt.Errorf("%s is not set", postgresDsnEnv)
}
store, err := NewPostgresqlStoreFromFileStore(fstore, dsn, nil)
if err != nil {
return nil, nil, err
}
return store, cleanUp, nil
default:
return NewSqliteStoreFromFileStore(fstore, dataDir, metrics)
store, err := NewSqliteStoreFromFileStore(fstore, dataDir, nil)
if err != nil {
return nil, nil, err
}
return store, cleanUp, nil
}
}

View File

@@ -5,50 +5,49 @@ import (
"time"
"go.opentelemetry.io/otel/metric"
"go.opentelemetry.io/otel/metric/instrument"
"go.opentelemetry.io/otel/metric/instrument/asyncint64"
"go.opentelemetry.io/otel/metric/instrument/syncint64"
)
// GRPCMetrics are gRPC server metrics
type GRPCMetrics struct {
meter metric.Meter
syncRequestsCounter syncint64.Counter
loginRequestsCounter syncint64.Counter
getKeyRequestsCounter syncint64.Counter
activeStreamsGauge asyncint64.Gauge
syncRequestDuration syncint64.Histogram
loginRequestDuration syncint64.Histogram
channelQueueLength syncint64.Histogram
syncRequestsCounter metric.Int64Counter
loginRequestsCounter metric.Int64Counter
getKeyRequestsCounter metric.Int64Counter
activeStreamsGauge metric.Int64ObservableGauge
syncRequestDuration metric.Int64Histogram
loginRequestDuration metric.Int64Histogram
channelQueueLength metric.Int64Histogram
ctx context.Context
}
// NewGRPCMetrics creates new GRPCMetrics struct and registers common metrics of the gRPC server
func NewGRPCMetrics(ctx context.Context, meter metric.Meter) (*GRPCMetrics, error) {
syncRequestsCounter, err := meter.SyncInt64().Counter("management.grpc.sync.request.counter", instrument.WithUnit("1"))
if err != nil {
return nil, err
}
loginRequestsCounter, err := meter.SyncInt64().Counter("management.grpc.login.request.counter", instrument.WithUnit("1"))
if err != nil {
return nil, err
}
getKeyRequestsCounter, err := meter.SyncInt64().Counter("management.grpc.key.request.counter", instrument.WithUnit("1"))
syncRequestsCounter, err := meter.Int64Counter("management.grpc.sync.request.counter", metric.WithUnit("1"))
if err != nil {
return nil, err
}
activeStreamsGauge, err := meter.AsyncInt64().Gauge("management.grpc.connected.streams", instrument.WithUnit("1"))
loginRequestsCounter, err := meter.Int64Counter("management.grpc.login.request.counter", metric.WithUnit("1"))
if err != nil {
return nil, err
}
syncRequestDuration, err := meter.SyncInt64().Histogram("management.grpc.sync.request.duration.ms", instrument.WithUnit("milliseconds"))
getKeyRequestsCounter, err := meter.Int64Counter("management.grpc.key.request.counter", metric.WithUnit("1"))
if err != nil {
return nil, err
}
loginRequestDuration, err := meter.SyncInt64().Histogram("management.grpc.login.request.duration.ms", instrument.WithUnit("milliseconds"))
activeStreamsGauge, err := meter.Int64ObservableGauge("management.grpc.connected.streams", metric.WithUnit("1"))
if err != nil {
return nil, err
}
syncRequestDuration, err := meter.Int64Histogram("management.grpc.sync.request.duration.ms", metric.WithUnit("milliseconds"))
if err != nil {
return nil, err
}
loginRequestDuration, err := meter.Int64Histogram("management.grpc.login.request.duration.ms", metric.WithUnit("milliseconds"))
if err != nil {
return nil, err
}
@@ -56,10 +55,10 @@ func NewGRPCMetrics(ctx context.Context, meter metric.Meter) (*GRPCMetrics, erro
// We use histogram here as we have multiple channel at the same time and we want to see a slice at any given time
// Then we should be able to extract min, manx, mean and the percentiles.
// TODO(yury): This needs custom bucketing as we are interested in the values from 0 to server.channelBufferSize (100)
channelQueue, err := meter.SyncInt64().Histogram(
channelQueue, err := meter.Int64Histogram(
"management.grpc.updatechannel.queue",
instrument.WithDescription("Number of update messages in the channel queue"),
instrument.WithUnit("length"),
metric.WithDescription("Number of update messages in the channel queue"),
metric.WithUnit("length"),
)
if err != nil {
return nil, err
@@ -105,14 +104,14 @@ func (grpcMetrics *GRPCMetrics) CountSyncRequestDuration(duration time.Duration)
// RegisterConnectedStreams registers a function that collects number of active streams and feeds it to the metrics gauge.
func (grpcMetrics *GRPCMetrics) RegisterConnectedStreams(producer func() int64) error {
return grpcMetrics.meter.RegisterCallback(
[]instrument.Asynchronous{
grpcMetrics.activeStreamsGauge,
},
func(ctx context.Context) {
grpcMetrics.activeStreamsGauge.Observe(ctx, producer())
_, err := grpcMetrics.meter.RegisterCallback(
func(ctx context.Context, observer metric.Observer) error {
observer.ObserveInt64(grpcMetrics.activeStreamsGauge, producer())
return nil
},
grpcMetrics.activeStreamsGauge,
)
return err
}
// UpdateChannelQueueLength update the histogram that keep distribution of the update messages channel queue

View File

@@ -6,13 +6,11 @@ import (
"hash/fnv"
"net/http"
"strings"
time "time"
"time"
log "github.com/sirupsen/logrus"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/metric"
"go.opentelemetry.io/otel/metric/instrument"
"go.opentelemetry.io/otel/metric/instrument/syncint64"
)
const (
@@ -56,51 +54,44 @@ type HTTPMiddleware struct {
meter metric.Meter
ctx context.Context
// all HTTP requests by endpoint & method
httpRequestCounters map[string]syncint64.Counter
httpRequestCounters map[string]metric.Int64Counter
// all HTTP responses by endpoint & method & status code
httpResponseCounters map[string]syncint64.Counter
httpResponseCounters map[string]metric.Int64Counter
// all HTTP requests
totalHTTPRequestsCounter syncint64.Counter
totalHTTPRequestsCounter metric.Int64Counter
// all HTTP responses
totalHTTPResponseCounter syncint64.Counter
totalHTTPResponseCounter metric.Int64Counter
// all HTTP responses by status code
totalHTTPResponseCodeCounters map[int]syncint64.Counter
totalHTTPResponseCodeCounters map[int]metric.Int64Counter
// all HTTP requests durations by endpoint and method
httpRequestDurations map[string]syncint64.Histogram
httpRequestDurations map[string]metric.Int64Histogram
// all HTTP requests durations
totalHTTPRequestDuration syncint64.Histogram
totalHTTPRequestDuration metric.Int64Histogram
}
// NewMetricsMiddleware creates a new HTTPMiddleware
func NewMetricsMiddleware(ctx context.Context, meter metric.Meter) (*HTTPMiddleware, error) {
totalHTTPRequestsCounter, err := meter.SyncInt64().Counter(
fmt.Sprintf("%s_total", httpRequestCounterPrefix),
instrument.WithUnit("1"))
if err != nil {
return nil, err
}
totalHTTPResponseCounter, err := meter.SyncInt64().Counter(
fmt.Sprintf("%s_total", httpResponseCounterPrefix),
instrument.WithUnit("1"))
totalHTTPRequestsCounter, err := meter.Int64Counter(fmt.Sprintf("%s_total", httpRequestCounterPrefix), metric.WithUnit("1"))
if err != nil {
return nil, err
}
totalHTTPRequestDuration, err := meter.SyncInt64().Histogram(
fmt.Sprintf("%s_total", httpRequestDurationPrefix),
instrument.WithUnit("milliseconds"))
totalHTTPResponseCounter, err := meter.Int64Counter(fmt.Sprintf("%s_total", httpResponseCounterPrefix), metric.WithUnit("1"))
if err != nil {
return nil, err
}
totalHTTPRequestDuration, err := meter.Int64Histogram(fmt.Sprintf("%s_total", httpRequestDurationPrefix), metric.WithUnit("milliseconds"))
if err != nil {
return nil, err
}
return &HTTPMiddleware{
ctx: ctx,
httpRequestCounters: map[string]syncint64.Counter{},
httpResponseCounters: map[string]syncint64.Counter{},
httpRequestDurations: map[string]syncint64.Histogram{},
totalHTTPResponseCodeCounters: map[int]syncint64.Counter{},
httpRequestCounters: map[string]metric.Int64Counter{},
httpResponseCounters: map[string]metric.Int64Counter{},
httpRequestDurations: map[string]metric.Int64Histogram{},
totalHTTPResponseCodeCounters: map[int]metric.Int64Counter{},
meter: meter,
totalHTTPRequestsCounter: totalHTTPRequestsCounter,
totalHTTPResponseCounter: totalHTTPResponseCounter,
@@ -113,28 +104,30 @@ func NewMetricsMiddleware(ctx context.Context, meter metric.Meter) (*HTTPMiddlew
// Creates one request counter and multiple response counters (one per http response status code).
func (m *HTTPMiddleware) AddHTTPRequestResponseCounter(endpoint string, method string) error {
meterKey := getRequestCounterKey(endpoint, method)
httpReqCounter, err := m.meter.SyncInt64().Counter(meterKey, instrument.WithUnit("1"))
httpReqCounter, err := m.meter.Int64Counter(meterKey, metric.WithUnit("1"))
if err != nil {
return err
}
m.httpRequestCounters[meterKey] = httpReqCounter
durationKey := getRequestDurationKey(endpoint, method)
requestDuration, err := m.meter.SyncInt64().Histogram(durationKey, instrument.WithUnit("milliseconds"))
requestDuration, err := m.meter.Int64Histogram(durationKey, metric.WithUnit("milliseconds"))
if err != nil {
return err
}
m.httpRequestDurations[durationKey] = requestDuration
respCodes := []int{200, 204, 400, 401, 403, 404, 500, 502, 503}
for _, code := range respCodes {
meterKey = getResponseCounterKey(endpoint, method, code)
httpRespCounter, err := m.meter.SyncInt64().Counter(meterKey, instrument.WithUnit("1"))
httpRespCounter, err := m.meter.Int64Counter(meterKey, metric.WithUnit("1"))
if err != nil {
return err
}
m.httpResponseCounters[meterKey] = httpRespCounter
meterKey = fmt.Sprintf("%s_%d_total", httpResponseCounterPrefix, code)
totalHTTPResponseCodeCounter, err := m.meter.SyncInt64().Counter(meterKey, instrument.WithUnit("1"))
totalHTTPResponseCodeCounter, err := m.meter.Int64Counter(meterKey, metric.WithUnit("1"))
if err != nil {
return err
}
@@ -144,19 +137,26 @@ func (m *HTTPMiddleware) AddHTTPRequestResponseCounter(endpoint string, method s
return nil
}
func replaceEndpointChars(endpoint string) string {
endpoint = strings.ReplaceAll(endpoint, "/", "_")
endpoint = strings.ReplaceAll(endpoint, "{", "")
endpoint = strings.ReplaceAll(endpoint, "}", "")
return endpoint
}
func getRequestCounterKey(endpoint, method string) string {
return fmt.Sprintf("%s%s_%s", httpRequestCounterPrefix,
strings.ReplaceAll(endpoint, "/", "_"), method)
endpoint = replaceEndpointChars(endpoint)
return fmt.Sprintf("%s%s_%s", httpRequestCounterPrefix, endpoint, method)
}
func getRequestDurationKey(endpoint, method string) string {
return fmt.Sprintf("%s%s_%s", httpRequestDurationPrefix,
strings.ReplaceAll(endpoint, "/", "_"), method)
endpoint = replaceEndpointChars(endpoint)
return fmt.Sprintf("%s%s_%s", httpRequestDurationPrefix, endpoint, method)
}
func getResponseCounterKey(endpoint, method string, status int) string {
return fmt.Sprintf("%s%s_%s_%d", httpResponseCounterPrefix,
strings.ReplaceAll(endpoint, "/", "_"), method, status)
endpoint = replaceEndpointChars(endpoint)
return fmt.Sprintf("%s%s_%s_%d", httpResponseCounterPrefix, endpoint, method, status)
}
// Handler logs every request and response and adds the, to metrics.
@@ -201,9 +201,11 @@ func (m *HTTPMiddleware) Handler(h http.Handler) http.Handler {
log.Debugf("request %s %s took %d ms and finished with status %d", r.Method, r.URL.Path, reqTook.Milliseconds(), w.Status())
if w.Status() == 200 && (r.Method == http.MethodPut || r.Method == http.MethodPost || r.Method == http.MethodDelete) {
m.totalHTTPRequestDuration.Record(m.ctx, reqTook.Milliseconds(), attribute.String("type", "write"))
opts := metric.WithAttributeSet(attribute.NewSet(attribute.String("type", "write")))
m.totalHTTPRequestDuration.Record(m.ctx, reqTook.Milliseconds(), opts)
} else {
m.totalHTTPRequestDuration.Record(m.ctx, reqTook.Milliseconds(), attribute.String("type", "read"))
opts := metric.WithAttributeSet(attribute.NewSet(attribute.String("type", "read")))
m.totalHTTPRequestDuration.Record(m.ctx, reqTook.Milliseconds(), opts)
}
}

View File

@@ -4,64 +4,62 @@ import (
"context"
"go.opentelemetry.io/otel/metric"
"go.opentelemetry.io/otel/metric/instrument"
"go.opentelemetry.io/otel/metric/instrument/syncint64"
)
// IDPMetrics is common IdP metrics
type IDPMetrics struct {
metaUpdateCounter syncint64.Counter
getUserByEmailCounter syncint64.Counter
getAllAccountsCounter syncint64.Counter
createUserCounter syncint64.Counter
deleteUserCounter syncint64.Counter
getAccountCounter syncint64.Counter
getUserByIDCounter syncint64.Counter
authenticateRequestCounter syncint64.Counter
requestErrorCounter syncint64.Counter
requestStatusErrorCounter syncint64.Counter
metaUpdateCounter metric.Int64Counter
getUserByEmailCounter metric.Int64Counter
getAllAccountsCounter metric.Int64Counter
createUserCounter metric.Int64Counter
deleteUserCounter metric.Int64Counter
getAccountCounter metric.Int64Counter
getUserByIDCounter metric.Int64Counter
authenticateRequestCounter metric.Int64Counter
requestErrorCounter metric.Int64Counter
requestStatusErrorCounter metric.Int64Counter
ctx context.Context
}
// NewIDPMetrics creates new IDPMetrics struct and registers common
func NewIDPMetrics(ctx context.Context, meter metric.Meter) (*IDPMetrics, error) {
metaUpdateCounter, err := meter.SyncInt64().Counter("management.idp.update.user.meta.counter", instrument.WithUnit("1"))
metaUpdateCounter, err := meter.Int64Counter("management.idp.update.user.meta.counter", metric.WithUnit("1"))
if err != nil {
return nil, err
}
getUserByEmailCounter, err := meter.SyncInt64().Counter("management.idp.get.user.by.email.counter", instrument.WithUnit("1"))
getUserByEmailCounter, err := meter.Int64Counter("management.idp.get.user.by.email.counter", metric.WithUnit("1"))
if err != nil {
return nil, err
}
getAllAccountsCounter, err := meter.SyncInt64().Counter("management.idp.get.accounts.counter", instrument.WithUnit("1"))
getAllAccountsCounter, err := meter.Int64Counter("management.idp.get.accounts.counter", metric.WithUnit("1"))
if err != nil {
return nil, err
}
createUserCounter, err := meter.SyncInt64().Counter("management.idp.create.user.counter", instrument.WithUnit("1"))
createUserCounter, err := meter.Int64Counter("management.idp.create.user.counter", metric.WithUnit("1"))
if err != nil {
return nil, err
}
deleteUserCounter, err := meter.SyncInt64().Counter("management.idp.delete.user.counter", instrument.WithUnit("1"))
deleteUserCounter, err := meter.Int64Counter("management.idp.delete.user.counter", metric.WithUnit("1"))
if err != nil {
return nil, err
}
getAccountCounter, err := meter.SyncInt64().Counter("management.idp.get.account.counter", instrument.WithUnit("1"))
getAccountCounter, err := meter.Int64Counter("management.idp.get.account.counter", metric.WithUnit("1"))
if err != nil {
return nil, err
}
getUserByIDCounter, err := meter.SyncInt64().Counter("management.idp.get.user.by.id.counter", instrument.WithUnit("1"))
getUserByIDCounter, err := meter.Int64Counter("management.idp.get.user.by.id.counter", metric.WithUnit("1"))
if err != nil {
return nil, err
}
authenticateRequestCounter, err := meter.SyncInt64().Counter("management.idp.authenticate.request.counter", instrument.WithUnit("1"))
authenticateRequestCounter, err := meter.Int64Counter("management.idp.authenticate.request.counter", metric.WithUnit("1"))
if err != nil {
return nil, err
}
requestErrorCounter, err := meter.SyncInt64().Counter("management.idp.request.error.counter", instrument.WithUnit("1"))
requestErrorCounter, err := meter.Int64Counter("management.idp.request.error.counter", metric.WithUnit("1"))
if err != nil {
return nil, err
}
requestStatusErrorCounter, err := meter.SyncInt64().Counter("management.idp.request.status.error.counter", instrument.WithUnit("1"))
requestStatusErrorCounter, err := meter.Int64Counter("management.idp.request.status.error.counter", metric.WithUnit("1"))
if err != nil {
return nil, err
}

View File

@@ -5,39 +5,37 @@ import (
"time"
"go.opentelemetry.io/otel/metric"
"go.opentelemetry.io/otel/metric/instrument"
"go.opentelemetry.io/otel/metric/instrument/syncint64"
)
// StoreMetrics represents all metrics related to the Store
type StoreMetrics struct {
globalLockAcquisitionDurationMicro syncint64.Histogram
globalLockAcquisitionDurationMs syncint64.Histogram
persistenceDurationMicro syncint64.Histogram
persistenceDurationMs syncint64.Histogram
globalLockAcquisitionDurationMicro metric.Int64Histogram
globalLockAcquisitionDurationMs metric.Int64Histogram
persistenceDurationMicro metric.Int64Histogram
persistenceDurationMs metric.Int64Histogram
ctx context.Context
}
// NewStoreMetrics creates an instance of StoreMetrics
func NewStoreMetrics(ctx context.Context, meter metric.Meter) (*StoreMetrics, error) {
globalLockAcquisitionDurationMicro, err := meter.SyncInt64().Histogram("management.store.global.lock.acquisition.duration.micro",
instrument.WithUnit("microseconds"))
globalLockAcquisitionDurationMicro, err := meter.Int64Histogram("management.store.global.lock.acquisition.duration.micro",
metric.WithUnit("microseconds"))
if err != nil {
return nil, err
}
globalLockAcquisitionDurationMs, err := meter.SyncInt64().Histogram("management.store.global.lock.acquisition.duration.ms")
globalLockAcquisitionDurationMs, err := meter.Int64Histogram("management.store.global.lock.acquisition.duration.ms")
if err != nil {
return nil, err
}
persistenceDurationMicro, err := meter.SyncInt64().Histogram("management.store.persistence.duration.micro",
instrument.WithUnit("microseconds"))
persistenceDurationMicro, err := meter.Int64Histogram("management.store.persistence.duration.micro",
metric.WithUnit("microseconds"))
if err != nil {
return nil, err
}
persistenceDurationMs, err := meter.SyncInt64().Histogram("management.store.persistence.duration.ms")
persistenceDurationMs, err := meter.Int64Histogram("management.store.persistence.duration.ms")
if err != nil {
return nil, err
}

View File

@@ -6,60 +6,59 @@ import (
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/metric"
"go.opentelemetry.io/otel/metric/instrument/syncint64"
)
// UpdateChannelMetrics represents all metrics related to the UpdateChannel
type UpdateChannelMetrics struct {
createChannelDurationMicro syncint64.Histogram
closeChannelDurationMicro syncint64.Histogram
closeChannelsDurationMicro syncint64.Histogram
closeChannels syncint64.Histogram
sendUpdateDurationMicro syncint64.Histogram
getAllConnectedPeersDurationMicro syncint64.Histogram
getAllConnectedPeers syncint64.Histogram
hasChannelDurationMicro syncint64.Histogram
createChannelDurationMicro metric.Int64Histogram
closeChannelDurationMicro metric.Int64Histogram
closeChannelsDurationMicro metric.Int64Histogram
closeChannels metric.Int64Histogram
sendUpdateDurationMicro metric.Int64Histogram
getAllConnectedPeersDurationMicro metric.Int64Histogram
getAllConnectedPeers metric.Int64Histogram
hasChannelDurationMicro metric.Int64Histogram
ctx context.Context
}
// NewUpdateChannelMetrics creates an instance of UpdateChannel
func NewUpdateChannelMetrics(ctx context.Context, meter metric.Meter) (*UpdateChannelMetrics, error) {
createChannelDurationMicro, err := meter.SyncInt64().Histogram("management.updatechannel.create.duration.micro")
createChannelDurationMicro, err := meter.Int64Histogram("management.updatechannel.create.duration.micro")
if err != nil {
return nil, err
}
closeChannelDurationMicro, err := meter.SyncInt64().Histogram("management.updatechannel.close.one.duration.micro")
closeChannelDurationMicro, err := meter.Int64Histogram("management.updatechannel.close.one.duration.micro")
if err != nil {
return nil, err
}
closeChannelsDurationMicro, err := meter.SyncInt64().Histogram("management.updatechannel.close.multiple.duration.micro")
closeChannelsDurationMicro, err := meter.Int64Histogram("management.updatechannel.close.multiple.duration.micro")
if err != nil {
return nil, err
}
closeChannels, err := meter.SyncInt64().Histogram("management.updatechannel.close.multiple.channels")
closeChannels, err := meter.Int64Histogram("management.updatechannel.close.multiple.channels")
if err != nil {
return nil, err
}
sendUpdateDurationMicro, err := meter.SyncInt64().Histogram("management.updatechannel.send.duration.micro")
sendUpdateDurationMicro, err := meter.Int64Histogram("management.updatechannel.send.duration.micro")
if err != nil {
return nil, err
}
getAllConnectedPeersDurationMicro, err := meter.SyncInt64().Histogram("management.updatechannel.get.all.duration.micro")
getAllConnectedPeersDurationMicro, err := meter.Int64Histogram("management.updatechannel.get.all.duration.micro")
if err != nil {
return nil, err
}
getAllConnectedPeers, err := meter.SyncInt64().Histogram("management.updatechannel.get.all.peers")
getAllConnectedPeers, err := meter.Int64Histogram("management.updatechannel.get.all.peers")
if err != nil {
return nil, err
}
hasChannelDurationMicro, err := meter.SyncInt64().Histogram("management.updatechannel.haschannel.duration.micro")
hasChannelDurationMicro, err := meter.Int64Histogram("management.updatechannel.haschannel.duration.micro")
if err != nil {
return nil, err
}
@@ -80,7 +79,8 @@ func NewUpdateChannelMetrics(ctx context.Context, meter metric.Meter) (*UpdateCh
// CountCreateChannelDuration counts the duration of the CreateChannel method,
// closed indicates if existing channel was closed before creation of a new one
func (metrics *UpdateChannelMetrics) CountCreateChannelDuration(duration time.Duration, closed bool) {
metrics.createChannelDurationMicro.Record(metrics.ctx, duration.Microseconds(), attribute.Bool("closed", closed))
opts := metric.WithAttributeSet(attribute.NewSet(attribute.Bool("closed", closed)))
metrics.createChannelDurationMicro.Record(metrics.ctx, duration.Microseconds(), opts)
}
// CountCloseChannelDuration counts the duration of the CloseChannel method
@@ -97,8 +97,8 @@ func (metrics *UpdateChannelMetrics) CountCloseChannelsDuration(duration time.Du
// CountSendUpdateDuration counts the duration of the SendUpdate method
// found indicates if peer had channel, dropped indicates if the message was dropped due channel buffer overload
func (metrics *UpdateChannelMetrics) CountSendUpdateDuration(duration time.Duration, found, dropped bool) {
attrs := []attribute.KeyValue{attribute.Bool("found", found), attribute.Bool("dropped", dropped)}
metrics.sendUpdateDurationMicro.Record(metrics.ctx, duration.Microseconds(), attrs...)
opts := metric.WithAttributeSet(attribute.NewSet(attribute.Bool("found", found), attribute.Bool("dropped", dropped)))
metrics.sendUpdateDurationMicro.Record(metrics.ctx, duration.Microseconds(), opts)
}
// CountGetAllConnectedPeersDuration counts the duration of the GetAllConnectedPeers method and the number of peers have been returned

View File

@@ -0,0 +1,45 @@
//go:build !ios
// +build !ios
package testutil
import (
"context"
"os"
"time"
log "github.com/sirupsen/logrus"
"github.com/testcontainers/testcontainers-go"
"github.com/testcontainers/testcontainers-go/modules/postgres"
"github.com/testcontainers/testcontainers-go/wait"
)
func CreatePGDB() (func(), error) {
ctx := context.Background()
c, err := postgres.RunContainer(ctx,
testcontainers.WithImage("postgres:alpine"),
postgres.WithDatabase("test"),
postgres.WithUsername("postgres"),
postgres.WithPassword("postgres"),
testcontainers.WithWaitStrategy(
wait.ForLog("database system is ready to accept connections").
WithOccurrence(2).WithStartupTimeout(15*time.Second)),
)
if err != nil {
return nil, err
}
cleanup := func() {
timeout := 10 * time.Second
err = c.Stop(ctx, &timeout)
if err != nil {
log.Warnf("failed to stop container: %s", err)
}
}
talksConn, err := c.ConnectionString(ctx)
if err != nil {
return cleanup, err
}
return cleanup, os.Setenv("NETBIRD_STORE_ENGINE_POSTGRES_DSN", talksConn)
}

View File

@@ -0,0 +1,6 @@
//go:build ios
// +build ios
package testutil
func CreatePGDB() (func(), error) { return func() {}, nil }

View File

@@ -910,8 +910,10 @@ func (am *DefaultAccountManager) SaveOrAddUser(accountID, initiatorUserID string
// GetOrCreateAccountByUser returns an existing account for a given user id or creates a new one if doesn't exist
func (am *DefaultAccountManager) GetOrCreateAccountByUser(userID, domain string) (*Account, error) {
start := time.Now()
unlock := am.Store.AcquireGlobalLock()
defer unlock()
log.Debugf("Acquired global lock in %s for user %s", time.Since(start), userID)
lowerDomain := strings.ToLower(domain)

View File

@@ -39,6 +39,7 @@ const (
func TestUser_CreatePAT_ForSameUser(t *testing.T) {
store := newStore(t)
defer store.Close()
account := newAccountWithId(mockAccountID, mockUserID, "")
err := store.SaveAccount(account)
@@ -76,6 +77,7 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) {
func TestUser_CreatePAT_ForDifferentUser(t *testing.T) {
store := newStore(t)
defer store.Close()
account := newAccountWithId(mockAccountID, mockUserID, "")
account.Users[mockTargetUserId] = &User{
Id: mockTargetUserId,
@@ -97,6 +99,7 @@ func TestUser_CreatePAT_ForDifferentUser(t *testing.T) {
func TestUser_CreatePAT_ForServiceUser(t *testing.T) {
store := newStore(t)
defer store.Close()
account := newAccountWithId(mockAccountID, mockUserID, "")
account.Users[mockTargetUserId] = &User{
Id: mockTargetUserId,
@@ -122,6 +125,7 @@ func TestUser_CreatePAT_ForServiceUser(t *testing.T) {
func TestUser_CreatePAT_WithWrongExpiration(t *testing.T) {
store := newStore(t)
defer store.Close()
account := newAccountWithId(mockAccountID, mockUserID, "")
err := store.SaveAccount(account)
@@ -140,6 +144,7 @@ func TestUser_CreatePAT_WithWrongExpiration(t *testing.T) {
func TestUser_CreatePAT_WithEmptyName(t *testing.T) {
store := newStore(t)
defer store.Close()
account := newAccountWithId(mockAccountID, mockUserID, "")
err := store.SaveAccount(account)
@@ -158,6 +163,7 @@ func TestUser_CreatePAT_WithEmptyName(t *testing.T) {
func TestUser_DeletePAT(t *testing.T) {
store := newStore(t)
defer store.Close()
account := newAccountWithId(mockAccountID, mockUserID, "")
account.Users[mockUserID] = &User{
Id: mockUserID,
@@ -190,6 +196,7 @@ func TestUser_DeletePAT(t *testing.T) {
func TestUser_GetPAT(t *testing.T) {
store := newStore(t)
defer store.Close()
account := newAccountWithId(mockAccountID, mockUserID, "")
account.Users[mockUserID] = &User{
Id: mockUserID,
@@ -221,6 +228,7 @@ func TestUser_GetPAT(t *testing.T) {
func TestUser_GetAllPATs(t *testing.T) {
store := newStore(t)
defer store.Close()
account := newAccountWithId(mockAccountID, mockUserID, "")
account.Users[mockUserID] = &User{
Id: mockUserID,
@@ -322,6 +330,7 @@ func validateStruct(s interface{}) (err error) {
func TestUser_CreateServiceUser(t *testing.T) {
store := newStore(t)
defer store.Close()
account := newAccountWithId(mockAccountID, mockUserID, "")
err := store.SaveAccount(account)
@@ -359,6 +368,7 @@ func TestUser_CreateServiceUser(t *testing.T) {
func TestUser_CreateUser_ServiceUser(t *testing.T) {
store := newStore(t)
defer store.Close()
account := newAccountWithId(mockAccountID, mockUserID, "")
err := store.SaveAccount(account)
@@ -397,6 +407,7 @@ func TestUser_CreateUser_ServiceUser(t *testing.T) {
func TestUser_CreateUser_RegularUser(t *testing.T) {
store := newStore(t)
defer store.Close()
account := newAccountWithId(mockAccountID, mockUserID, "")
err := store.SaveAccount(account)
@@ -421,6 +432,7 @@ func TestUser_CreateUser_RegularUser(t *testing.T) {
func TestUser_InviteNewUser(t *testing.T) {
store := newStore(t)
defer store.Close()
account := newAccountWithId(mockAccountID, mockUserID, "")
err := store.SaveAccount(account)
@@ -549,6 +561,7 @@ func TestUser_DeleteUser_ServiceUser(t *testing.T) {
func TestUser_DeleteUser_SelfDelete(t *testing.T) {
store := newStore(t)
defer store.Close()
account := newAccountWithId(mockAccountID, mockUserID, "")
err := store.SaveAccount(account)
@@ -569,6 +582,7 @@ func TestUser_DeleteUser_SelfDelete(t *testing.T) {
func TestUser_DeleteUser_regularUser(t *testing.T) {
store := newStore(t)
defer store.Close()
account := newAccountWithId(mockAccountID, mockUserID, "")
targetId := "user2"
@@ -650,6 +664,7 @@ func TestUser_DeleteUser_regularUser(t *testing.T) {
func TestDefaultAccountManager_GetUser(t *testing.T) {
store := newStore(t)
defer store.Close()
account := newAccountWithId(mockAccountID, mockUserID, "")
err := store.SaveAccount(account)
@@ -678,6 +693,7 @@ func TestDefaultAccountManager_GetUser(t *testing.T) {
func TestDefaultAccountManager_ListUsers(t *testing.T) {
store := newStore(t)
defer store.Close()
account := newAccountWithId(mockAccountID, mockUserID, "")
account.Users["normal_user1"] = NewRegularUser("normal_user1")
account.Users["normal_user2"] = NewRegularUser("normal_user2")
@@ -790,6 +806,7 @@ func TestDefaultAccountManager_ListUsers_DashboardPermissions(t *testing.T) {
func TestDefaultAccountManager_ExternalCache(t *testing.T) {
store := newStore(t)
defer store.Close()
account := newAccountWithId(mockAccountID, mockUserID, "")
externalUser := &User{
Id: "externalUser",
@@ -853,6 +870,7 @@ func TestUser_IsAdmin(t *testing.T) {
func TestUser_GetUsersFromAccount_ForAdmin(t *testing.T) {
store := newStore(t)
defer store.Close()
account := newAccountWithId(mockAccountID, mockUserID, "")
account.Users[mockServiceUserID] = &User{
Id: mockServiceUserID,
@@ -880,6 +898,8 @@ func TestUser_GetUsersFromAccount_ForAdmin(t *testing.T) {
func TestUser_GetUsersFromAccount_ForUser(t *testing.T) {
store := newStore(t)
defer store.Close()
account := newAccountWithId(mockAccountID, mockUserID, "")
account.Users[mockServiceUserID] = &User{
Id: mockServiceUserID,

419
relay/client/client.go Normal file
View File

@@ -0,0 +1,419 @@
package client
import (
"context"
"fmt"
"io"
"net"
"sync"
"time"
log "github.com/sirupsen/logrus"
ws "github.com/netbirdio/netbird/relay/client/dialer/wsnhooyr"
"github.com/netbirdio/netbird/relay/messages"
)
const (
bufferSize = 8820
serverResponseTimeout = 8 * time.Second
)
// Msg carry the payload from the server to the client. With this sturct, the net.Conn can free the buffer.
type Msg struct {
Payload []byte
bufPool *sync.Pool
bufPtr *[]byte
}
func (m *Msg) Free() {
m.bufPool.Put(m.bufPtr)
}
type connContainer struct {
conn *Conn
messages chan Msg
msgChanLock sync.Mutex
closed bool // flag to check if channel is closed
}
func newConnContainer(conn *Conn, messages chan Msg) *connContainer {
return &connContainer{
conn: conn,
messages: messages,
}
}
func (cc *connContainer) writeMsg(msg Msg) {
cc.msgChanLock.Lock()
defer cc.msgChanLock.Unlock()
if cc.closed {
return
}
cc.messages <- msg
}
func (cc *connContainer) close() {
cc.msgChanLock.Lock()
defer cc.msgChanLock.Unlock()
if cc.closed {
return
}
close(cc.messages)
cc.closed = true
}
// Client is a client for the relay server. It is responsible for establishing a connection to the relay server and
// managing connections to other peers. All exported functions are safe to call concurrently. After close the connection,
// the client can be reused by calling Connect again. When the client is closed, all connections are closed too.
// While the Connect is in progress, the OpenConn function will block until the connection is established.
type Client struct {
log *log.Entry
parentCtx context.Context
ctxCancel context.CancelFunc
serverAddress string
hashedID []byte
bufPool *sync.Pool
relayConn net.Conn
conns map[string]*connContainer
serviceIsRunning bool
mu sync.Mutex
readLoopMutex sync.Mutex
wgReadLoop sync.WaitGroup
remoteAddr net.Addr
onDisconnectListener func()
listenerMutex sync.Mutex
}
// NewClient creates a new client for the relay server. The client is not connected to the server until the Connect
func NewClient(ctx context.Context, serverAddress, peerID string) *Client {
hashedID, hashedStringId := messages.HashID(peerID)
return &Client{
log: log.WithField("client_id", hashedStringId),
parentCtx: ctx,
ctxCancel: func() {},
serverAddress: serverAddress,
hashedID: hashedID,
bufPool: &sync.Pool{
New: func() any {
buf := make([]byte, bufferSize)
return &buf
},
},
conns: make(map[string]*connContainer),
}
}
// Connect establishes a connection to the relay server. It blocks until the connection is established or an error occurs.
func (c *Client) Connect() error {
c.readLoopMutex.Lock()
defer c.readLoopMutex.Unlock()
c.mu.Lock()
defer c.mu.Unlock()
if c.serviceIsRunning {
return nil
}
err := c.connect()
if err != nil {
return err
}
c.serviceIsRunning = true
var ctx context.Context
ctx, c.ctxCancel = context.WithCancel(c.parentCtx)
context.AfterFunc(ctx, func() {
cErr := c.close(false)
if cErr != nil {
log.Errorf("failed to close relay connection: %s", cErr)
}
})
c.wgReadLoop.Add(1)
go c.readLoop(c.relayConn)
return nil
}
// OpenConn create a new net.Conn for the destination peer ID. In case if the connection is in progress
// to the relay server, the function will block until the connection is established or timed out. Otherwise,
// it will return immediately.
// todo: what should happen if call with the same peerID with multiple times?
func (c *Client) OpenConn(dstPeerID string) (net.Conn, error) {
c.mu.Lock()
defer c.mu.Unlock()
if !c.serviceIsRunning {
return nil, fmt.Errorf("relay connection is not established")
}
hashedID, hashedStringID := messages.HashID(dstPeerID)
log.Infof("open connection to peer: %s", hashedStringID)
msgChannel := make(chan Msg, 2)
conn := NewConn(c, hashedID, hashedStringID, msgChannel)
c.conns[hashedStringID] = newConnContainer(conn, msgChannel)
return conn, nil
}
// RelayRemoteAddress returns the IP address of the relay server. It could change after the close and reopen the connection.
func (c *Client) RelayRemoteAddress() (net.Addr, error) {
c.mu.Lock()
defer c.mu.Unlock()
if c.remoteAddr == nil {
return nil, fmt.Errorf("relay connection is not established")
}
return c.remoteAddr, nil
}
// SetOnDisconnectListener sets a function that will be called when the connection to the relay server is closed.
func (c *Client) SetOnDisconnectListener(fn func()) {
c.listenerMutex.Lock()
defer c.listenerMutex.Unlock()
c.onDisconnectListener = fn
}
func (c *Client) HasConns() bool {
c.mu.Lock()
defer c.mu.Unlock()
return len(c.conns) > 0
}
// Close closes the connection to the relay server and all connections to other peers.
func (c *Client) Close() error {
return c.close(false)
}
func (c *Client) close(byServer bool) error {
c.readLoopMutex.Lock()
defer c.readLoopMutex.Unlock()
c.mu.Lock()
var err error
if !c.serviceIsRunning {
c.mu.Unlock()
return nil
}
c.serviceIsRunning = false
c.closeAllConns()
if !byServer {
c.writeCloseMsg()
err = c.relayConn.Close()
}
c.mu.Unlock()
c.wgReadLoop.Wait()
c.log.Infof("relay connection closed with: %s", c.serverAddress)
c.ctxCancel()
return err
}
func (c *Client) connect() error {
conn, err := ws.Dial(c.serverAddress)
if err != nil {
return err
}
c.relayConn = conn
err = c.handShake()
if err != nil {
cErr := conn.Close()
if cErr != nil {
log.Errorf("failed to close connection: %s", cErr)
}
c.relayConn = nil
return err
}
c.remoteAddr = conn.RemoteAddr()
return nil
}
func (c *Client) handShake() error {
defer func() {
err := c.relayConn.SetReadDeadline(time.Time{})
if err != nil {
log.Errorf("failed to reset read deadline: %s", err)
}
}()
msg, err := messages.MarshalHelloMsg(c.hashedID)
if err != nil {
log.Errorf("failed to marshal hello message: %s", err)
return err
}
_, err = c.relayConn.Write(msg)
if err != nil {
log.Errorf("failed to send hello message: %s", err)
return err
}
err = c.relayConn.SetReadDeadline(time.Now().Add(serverResponseTimeout))
if err != nil {
log.Errorf("failed to set read deadline: %s", err)
return err
}
buf := make([]byte, 1500) // todo: optimise buffer size
n, err := c.relayConn.Read(buf)
if err != nil {
log.Errorf("failed to read hello response: %s", err)
return err
}
msgType, err := messages.DetermineServerMsgType(buf[:n])
if err != nil {
log.Errorf("failed to determine message type: %s", err)
return err
}
if msgType != messages.MsgTypeHelloResponse {
log.Errorf("unexpected message type: %s", msgType)
return fmt.Errorf("unexpected message type")
}
return nil
}
func (c *Client) readLoop(relayConn net.Conn) {
var (
errExit error
n int
closedByServer bool
)
for {
bufPtr := c.bufPool.Get().(*[]byte)
buf := *bufPtr
n, errExit = relayConn.Read(buf)
if errExit != nil {
c.mu.Lock()
if c.serviceIsRunning {
c.log.Debugf("failed to read message from relay server: %s", errExit)
}
c.mu.Unlock()
goto Exit
}
msgType, err := messages.DetermineServerMsgType(buf[:n])
if err != nil {
c.log.Errorf("failed to determine message type: %s", err)
continue
}
switch msgType {
case messages.MsgTypeTransport:
peerID, payload, err := messages.UnmarshalTransportMsg(buf[:n])
if err != nil {
c.log.Errorf("failed to parse transport message: %v", err)
continue
}
stringID := messages.HashIDToString(peerID)
c.mu.Lock()
if !c.serviceIsRunning {
c.mu.Unlock()
goto Exit
}
container, ok := c.conns[stringID]
c.mu.Unlock()
if !ok {
c.log.Errorf("peer not found: %s", stringID)
continue
}
container.writeMsg(Msg{
bufPool: c.bufPool,
bufPtr: bufPtr,
Payload: payload})
case messages.MsgClose:
closedByServer = true
log.Debugf("relay connection close by server")
goto Exit
}
}
Exit:
c.notifyDisconnected()
c.wgReadLoop.Done()
_ = c.close(closedByServer)
}
// todo check by reference too, the id is not enought because the id come from the outer conn
func (c *Client) writeTo(id string, dstID []byte, payload []byte) (int, error) {
c.mu.Lock()
// conn, ok := c.conns[id]
_, ok := c.conns[id]
c.mu.Unlock()
if !ok {
return 0, io.EOF
}
/*
if conn != clientRef {
return 0, io.EOF
}
*/
msg := messages.MarshalTransportMsg(dstID, payload)
n, err := c.relayConn.Write(msg)
if err != nil {
log.Errorf("failed to write transport message: %s", err)
}
return n, err
}
func (c *Client) closeAllConns() {
for _, container := range c.conns {
container.close()
}
c.conns = make(map[string]*connContainer)
}
// todo check by reference too, the id is not enought because the id come from the outer conn
func (c *Client) closeConn(id string) error {
c.mu.Lock()
defer c.mu.Unlock()
container, ok := c.conns[id]
if !ok {
return fmt.Errorf("connection already closed")
}
container.close()
delete(c.conns, id)
return nil
}
func (c *Client) onDisconnect() {
c.listenerMutex.Lock()
defer c.listenerMutex.Unlock()
if c.onDisconnectListener == nil {
return
}
c.onDisconnectListener()
}
func (c *Client) notifyDisconnected() {
c.listenerMutex.Lock()
defer c.listenerMutex.Unlock()
if c.onDisconnectListener == nil {
return
}
go c.onDisconnectListener()
}
func (c *Client) writeCloseMsg() {
msg := messages.MarshalCloseMsg()
_, err := c.relayConn.Write(msg)
if err != nil {
c.log.Errorf("failed to send close message: %s", err)
}
}

523
relay/client/client_test.go Normal file
View File

@@ -0,0 +1,523 @@
package client
import (
"context"
"net"
"os"
"testing"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/relay/server"
)
func TestMain(m *testing.M) {
_ = util.InitLog("trace", "console")
code := m.Run()
os.Exit(code)
}
func TestClient(t *testing.T) {
ctx := context.Background()
addr := "localhost:1234"
srv := server.NewServer()
go func() {
err := srv.Listen(addr)
if err != nil {
t.Fatalf("failed to bind server: %s", err)
}
}()
defer func() {
err := srv.Close()
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}()
clientAlice := NewClient(ctx, addr, "alice")
err := clientAlice.Connect()
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
}
defer clientAlice.Close()
clientPlaceHolder := NewClient(ctx, addr, "clientPlaceHolder")
err = clientPlaceHolder.Connect()
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
}
defer clientPlaceHolder.Close()
clientBob := NewClient(ctx, addr, "bob")
err = clientBob.Connect()
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
}
defer clientBob.Close()
connAliceToBob, err := clientAlice.OpenConn("bob")
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
connBobToAlice, err := clientBob.OpenConn("alice")
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
payload := "hello bob, I am alice"
_, err = connAliceToBob.Write([]byte(payload))
if err != nil {
t.Fatalf("failed to write to channel: %s", err)
}
log.Debugf("alice sent message to bob")
buf := make([]byte, 65535)
n, err := connBobToAlice.Read(buf)
if err != nil {
t.Fatalf("failed to read from channel: %s", err)
}
log.Debugf("on new message from alice to bob")
if payload != string(buf[:n]) {
t.Fatalf("expected %s, got %s", payload, string(buf[:n]))
}
}
func TestRegistration(t *testing.T) {
ctx := context.Background()
addr := "localhost:1234"
srv := server.NewServer()
go func() {
err := srv.Listen(addr)
if err != nil {
t.Fatalf("failed to bind server: %s", err)
}
}()
clientAlice := NewClient(ctx, addr, "alice")
err := clientAlice.Connect()
if err != nil {
_ = srv.Close()
t.Fatalf("failed to connect to server: %s", err)
}
err = clientAlice.Close()
if err != nil {
t.Errorf("failed to close conn: %s", err)
}
err = srv.Close()
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}
func TestRegistrationTimeout(t *testing.T) {
ctx := context.Background()
fakeUDPListener, err := net.ListenUDP("udp", &net.UDPAddr{
Port: 1234,
IP: net.ParseIP("0.0.0.0"),
})
if err != nil {
t.Fatalf("failed to bind UDP server: %s", err)
}
defer func(fakeUDPListener *net.UDPConn) {
_ = fakeUDPListener.Close()
}(fakeUDPListener)
fakeTCPListener, err := net.ListenTCP("tcp", &net.TCPAddr{
Port: 1234,
IP: net.ParseIP("0.0.0.0"),
})
if err != nil {
t.Fatalf("failed to bind TCP server: %s", err)
}
defer func(fakeTCPListener *net.TCPListener) {
_ = fakeTCPListener.Close()
}(fakeTCPListener)
clientAlice := NewClient(ctx, "127.0.0.1:1234", "alice")
err = clientAlice.Connect()
if err == nil {
t.Errorf("failed to connect to server: %s", err)
}
log.Debugf("%s", err)
err = clientAlice.Close()
if err != nil {
t.Errorf("failed to close conn: %s", err)
}
}
func TestEcho(t *testing.T) {
ctx := context.Background()
idAlice := "alice"
idBob := "bob"
addr := "localhost:1234"
srv := server.NewServer()
go func() {
err := srv.Listen(addr)
if err != nil {
t.Fatalf("failed to bind server: %s", err)
}
}()
defer func() {
err := srv.Close()
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}()
clientAlice := NewClient(ctx, addr, idAlice)
err := clientAlice.Connect()
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
}
defer func() {
err := clientAlice.Close()
if err != nil {
t.Errorf("failed to close Alice client: %s", err)
}
}()
clientBob := NewClient(ctx, addr, idBob)
err = clientBob.Connect()
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
}
defer func() {
err := clientBob.Close()
if err != nil {
t.Errorf("failed to close Bob client: %s", err)
}
}()
connAliceToBob, err := clientAlice.OpenConn(idBob)
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
connBobToAlice, err := clientBob.OpenConn(idAlice)
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
payload := "hello bob, I am alice"
_, err = connAliceToBob.Write([]byte(payload))
if err != nil {
t.Fatalf("failed to write to channel: %s", err)
}
buf := make([]byte, 65535)
n, err := connBobToAlice.Read(buf)
if err != nil {
t.Fatalf("failed to read from channel: %s", err)
}
_, err = connBobToAlice.Write(buf[:n])
if err != nil {
t.Fatalf("failed to write to channel: %s", err)
}
n, err = connAliceToBob.Read(buf)
if err != nil {
t.Fatalf("failed to read from channel: %s", err)
}
if payload != string(buf[:n]) {
t.Fatalf("expected %s, got %s", payload, string(buf[:n]))
}
}
func TestBindToUnavailabePeer(t *testing.T) {
ctx := context.Background()
addr := "localhost:1234"
srv := server.NewServer()
go func() {
err := srv.Listen(addr)
if err != nil {
t.Fatalf("failed to bind server: %s", err)
}
}()
defer func() {
log.Infof("closing server")
err := srv.Close()
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}()
clientAlice := NewClient(ctx, addr, "alice")
err := clientAlice.Connect()
if err != nil {
t.Errorf("failed to connect to server: %s", err)
}
_, err = clientAlice.OpenConn("bob")
if err != nil {
t.Errorf("failed to bind channel: %s", err)
}
log.Infof("closing client")
err = clientAlice.Close()
if err != nil {
t.Errorf("failed to close client: %s", err)
}
}
func TestBindReconnect(t *testing.T) {
ctx := context.Background()
addr := "localhost:1234"
srv := server.NewServer()
go func() {
err := srv.Listen(addr)
if err != nil {
t.Errorf("failed to bind server: %s", err)
}
}()
defer func() {
log.Infof("closing server")
err := srv.Close()
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}()
clientAlice := NewClient(ctx, addr, "alice")
err := clientAlice.Connect()
if err != nil {
t.Errorf("failed to connect to server: %s", err)
}
_, err = clientAlice.OpenConn("bob")
if err != nil {
t.Errorf("failed to bind channel: %s", err)
}
clientBob := NewClient(ctx, addr, "bob")
err = clientBob.Connect()
if err != nil {
t.Errorf("failed to connect to server: %s", err)
}
chBob, err := clientBob.OpenConn("alice")
if err != nil {
t.Errorf("failed to bind channel: %s", err)
}
log.Infof("closing client Alice")
err = clientAlice.Close()
if err != nil {
t.Errorf("failed to close client: %s", err)
}
clientAlice = NewClient(ctx, addr, "alice")
err = clientAlice.Connect()
if err != nil {
t.Errorf("failed to connect to server: %s", err)
}
chAlice, err := clientAlice.OpenConn("bob")
if err != nil {
t.Errorf("failed to bind channel: %s", err)
}
testString := "hello alice, I am bob"
_, err = chBob.Write([]byte(testString))
if err != nil {
t.Errorf("failed to write to channel: %s", err)
}
buf := make([]byte, 65535)
n, err := chAlice.Read(buf)
if err != nil {
t.Errorf("failed to read from channel: %s", err)
}
if testString != string(buf[:n]) {
t.Errorf("expected %s, got %s", testString, string(buf[:n]))
}
log.Infof("closing client")
err = clientAlice.Close()
if err != nil {
t.Errorf("failed to close client: %s", err)
}
}
func TestCloseConn(t *testing.T) {
ctx := context.Background()
addr := "localhost:1234"
srv := server.NewServer()
go func() {
err := srv.Listen(addr)
if err != nil {
t.Errorf("failed to bind server: %s", err)
}
}()
defer func() {
log.Infof("closing server")
err := srv.Close()
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}()
clientAlice := NewClient(ctx, addr, "alice")
err := clientAlice.Connect()
if err != nil {
t.Errorf("failed to connect to server: %s", err)
}
conn, err := clientAlice.OpenConn("bob")
if err != nil {
t.Errorf("failed to bind channel: %s", err)
}
log.Infof("closing connection")
err = conn.Close()
if err != nil {
t.Errorf("failed to close connection: %s", err)
}
_, err = conn.Read(make([]byte, 1))
if err == nil {
t.Errorf("unexpected reading from closed connection")
}
_, err = conn.Write([]byte("hello"))
if err == nil {
t.Errorf("unexpected writing from closed connection")
}
}
func TestCloseRelayConn(t *testing.T) {
ctx := context.Background()
addr := "localhost:1234"
srv := server.NewServer()
go func() {
err := srv.Listen(addr)
if err != nil {
t.Errorf("failed to bind server: %s", err)
}
}()
defer func() {
err := srv.Close()
if err != nil {
log.Errorf("failed to close server: %s", err)
}
}()
clientAlice := NewClient(ctx, addr, "alice")
err := clientAlice.Connect()
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
}
conn, err := clientAlice.OpenConn("bob")
if err != nil {
t.Errorf("failed to bind channel: %s", err)
}
_ = clientAlice.relayConn.Close()
_, err = conn.Read(make([]byte, 1))
if err == nil {
t.Errorf("unexpected reading from closed connection")
}
_, err = clientAlice.OpenConn("bob")
if err == nil {
t.Errorf("unexpected opening connection to closed server")
}
}
func TestCloseByServer(t *testing.T) {
ctx := context.Background()
addr1 := "localhost:1234"
srv1 := server.NewServer()
go func() {
err := srv1.Listen(addr1)
if err != nil {
t.Fatalf("failed to bind server: %s", err)
}
}()
idAlice := "alice"
log.Debugf("connect by alice")
relayClient := NewClient(ctx, addr1, idAlice)
err := relayClient.Connect()
if err != nil {
log.Fatalf("failed to connect to server: %s", err)
}
disconnected := make(chan struct{})
relayClient.SetOnDisconnectListener(func() {
log.Infof("client disconnected")
close(disconnected)
})
err = srv1.Close()
if err != nil {
t.Fatalf("failed to close server: %s", err)
}
select {
case <-disconnected:
case <-time.After(3 * time.Second):
log.Fatalf("timeout waiting for client to disconnect")
}
_, err = relayClient.OpenConn("bob")
if err == nil {
t.Errorf("unexpected opening connection to closed server")
}
}
func TestCloseByClient(t *testing.T) {
ctx := context.Background()
addr1 := "localhost:1234"
srv := server.NewServer()
go func() {
err := srv.Listen(addr1)
if err != nil {
t.Fatalf("failed to bind server: %s", err)
}
}()
idAlice := "alice"
log.Debugf("connect by alice")
relayClient := NewClient(ctx, addr1, idAlice)
err := relayClient.Connect()
if err != nil {
log.Fatalf("failed to connect to server: %s", err)
}
err = relayClient.Close()
if err != nil {
t.Errorf("failed to close client: %s", err)
}
_, err = relayClient.OpenConn("bob")
if err == nil {
t.Errorf("unexpected opening connection to closed server")
}
err = srv.Close()
if err != nil {
t.Fatalf("failed to close server: %s", err)
}
}

67
relay/client/conn.go Normal file
View File

@@ -0,0 +1,67 @@
package client
import (
"io"
"net"
"time"
)
type Conn struct {
client *Client
dstID []byte
dstStringID string
messageChan chan Msg
}
func NewConn(client *Client, dstID []byte, dstStringID string, messageChan chan Msg) *Conn {
c := &Conn{
client: client,
dstID: dstID,
dstStringID: dstStringID,
messageChan: messageChan,
}
return c
}
func (c *Conn) Write(p []byte) (n int, err error) {
return c.client.writeTo(c.dstStringID, c.dstID, p)
}
func (c *Conn) Read(b []byte) (n int, err error) {
msg, ok := <-c.messageChan
if !ok {
return 0, io.EOF
}
n = copy(b, msg.Payload)
msg.Free()
return n, nil
}
func (c *Conn) Close() error {
return c.client.closeConn(c.dstStringID)
}
func (c *Conn) LocalAddr() net.Addr {
return c.client.relayConn.LocalAddr()
}
func (c *Conn) RemoteAddr() net.Addr {
return c.client.relayConn.RemoteAddr()
}
func (c *Conn) SetDeadline(t time.Time) error {
//TODO implement me
panic("implement me")
}
func (c *Conn) SetReadDeadline(t time.Time) error {
//TODO implement me
panic("implement me")
}
func (c *Conn) SetWriteDeadline(t time.Time) error {
//TODO implement me
panic("implement me")
}

View File

@@ -0,0 +1,52 @@
package quic
import (
"net"
"github.com/quic-go/quic-go"
log "github.com/sirupsen/logrus"
)
type Conn struct {
quic.Stream
qConn quic.Connection
}
func NewConn(stream quic.Stream, qConn quic.Connection) net.Conn {
return &Conn{
Stream: stream,
qConn: qConn,
}
}
func (q *Conn) Write(b []byte) (n int, err error) {
log.Debugf("writing: %d, %x\n", len(b), b)
n, err = q.Stream.Write(b)
if n != len(b) {
log.Errorf("failed to write out the full message")
}
return
}
func (q *Conn) Close() error {
err := q.Stream.Close()
if err != nil {
log.Errorf("failed to close stream: %s", err)
return err
}
err = q.qConn.CloseWithError(0, "")
if err != nil {
log.Errorf("failed to close connection: %s", err)
return err
}
return err
}
func (c *Conn) LocalAddr() net.Addr {
return c.qConn.LocalAddr()
}
func (c *Conn) RemoteAddr() net.Addr {
return c.qConn.RemoteAddr()
}

View File

@@ -0,0 +1,32 @@
package quic
import (
"context"
"crypto/tls"
"net"
"github.com/quic-go/quic-go"
log "github.com/sirupsen/logrus"
)
func Dial(address string) (net.Conn, error) {
tlsConf := &tls.Config{
InsecureSkipVerify: true,
NextProtos: []string{"quic-echo-example"},
}
qConn, err := quic.DialAddr(context.Background(), address, tlsConf, &quic.Config{
EnableDatagrams: true,
})
if err != nil {
log.Errorf("dial quic address %s failed: %s", address, err)
return nil, err
}
stream, err := qConn.OpenStreamSync(context.Background())
if err != nil {
return nil, err
}
conn := NewConn(stream, qConn)
return conn, nil
}

View File

@@ -0,0 +1,7 @@
package tcp
import "net"
func Dial(address string) (net.Conn, error) {
return net.Dial("tcp", address)
}

View File

@@ -0,0 +1,14 @@
package udp
import (
"net"
)
func Dial(address string) (net.Conn, error) {
udpAddr, err := net.ResolveUDPAddr("udp", address)
if err != nil {
return nil, err
}
return net.DialUDP("udp", nil, udpAddr)
}

View File

@@ -0,0 +1,59 @@
package ws
import (
"fmt"
"net"
"sync"
"time"
"github.com/gorilla/websocket"
)
type Conn struct {
*websocket.Conn
mu sync.Mutex
}
func NewConn(wsConn *websocket.Conn) net.Conn {
return &Conn{
Conn: wsConn,
}
}
func (c *Conn) Read(b []byte) (n int, err error) {
t, r, err := c.NextReader()
if err != nil {
return 0, err
}
if t != websocket.BinaryMessage {
return 0, fmt.Errorf("unexpected message type")
}
return r.Read(b)
}
func (c *Conn) Write(b []byte) (int, error) {
c.mu.Lock()
err := c.WriteMessage(websocket.BinaryMessage, b)
c.mu.Unlock()
return len(b), err
}
func (c *Conn) SetDeadline(t time.Time) error {
errR := c.SetReadDeadline(t)
errW := c.SetWriteDeadline(t)
if errR != nil {
return errR
}
if errW != nil {
return errW
}
return nil
}
func (c *Conn) Close() error {
return c.Conn.Close()
}

View File

@@ -0,0 +1,22 @@
package ws
import (
"fmt"
"net"
"time"
"github.com/gorilla/websocket"
)
func Dial(address string) (net.Conn, error) {
addr := fmt.Sprintf("ws://" + address)
wsDialer := websocket.Dialer{
HandshakeTimeout: 3 * time.Second,
}
wsConn, _, err := wsDialer.Dial(addr, nil)
if err != nil {
return nil, err
}
conn := NewConn(wsConn)
return conn, nil
}

View File

@@ -0,0 +1,79 @@
package wsnhooyr
import (
"context"
"fmt"
"net"
"time"
"nhooyr.io/websocket"
)
type Conn struct {
*websocket.Conn
ctx context.Context
}
func NewConn(wsConn *websocket.Conn) net.Conn {
return &Conn{
Conn: wsConn,
ctx: context.Background(),
}
}
func (c *Conn) Read(b []byte) (n int, err error) {
t, ioReader, err := c.Conn.Reader(c.ctx)
if err != nil {
return 0, err
}
if t != websocket.MessageBinary {
return 0, fmt.Errorf("unexpected message type")
}
return ioReader.Read(b)
}
func (c *Conn) Write(b []byte) (n int, err error) {
err = c.Conn.Write(c.ctx, websocket.MessageBinary, b)
return len(b), err
}
func (c *Conn) RemoteAddr() net.Addr {
// todo: implement me
return nil
}
func (c *Conn) LocalAddr() net.Addr {
// todo: implement me
return nil
}
func (c *Conn) SetReadDeadline(t time.Time) error {
// todo: implement me
return nil
}
func (c *Conn) SetWriteDeadline(t time.Time) error {
// todo: implement me
return nil
}
func (c *Conn) SetDeadline(t time.Time) error {
// todo: implement me
errR := c.SetReadDeadline(t)
errW := c.SetWriteDeadline(t)
if errR != nil {
return errR
}
if errW != nil {
return errW
}
return nil
}
func (c *Conn) Close() error {
return c.Conn.CloseNow()
}

View File

@@ -0,0 +1,22 @@
package wsnhooyr
import (
"context"
"fmt"
"net"
"nhooyr.io/websocket"
)
func Dial(address string) (net.Conn, error) {
addr := fmt.Sprintf("ws://" + address)
wsConn, _, err := websocket.Dial(context.Background(), addr, nil)
if err != nil {
return nil, err
}
conn := NewConn(wsConn)
return conn, nil
}

33
relay/client/guard.go Normal file
View File

@@ -0,0 +1,33 @@
package client
import (
"context"
"time"
)
var (
reconnectingTimeout = 5 * time.Second
)
type Guard struct {
ctx context.Context
relayClient *Client
}
func NewGuard(context context.Context, relayClient *Client) *Guard {
g := &Guard{
ctx: context,
relayClient: relayClient,
}
return g
}
func (g *Guard) OnDisconnected() {
select {
case <-time.After(time.Second):
_ = g.relayClient.Connect()
case <-g.ctx.Done():
return
}
}

208
relay/client/manager.go Normal file
View File

@@ -0,0 +1,208 @@
package client
import (
"context"
"fmt"
"net"
"sync"
"time"
log "github.com/sirupsen/logrus"
)
var (
relayCleanupInterval = 60 * time.Second
)
// RelayTrack hold the relay clients for the foregin relay servers.
// With the mutex can ensure we can open new connection in case the relay connection has been established with
// the relay server.
type RelayTrack struct {
sync.RWMutex
relayClient *Client
}
func NewRelayTrack() *RelayTrack {
return &RelayTrack{}
}
// Manager is a manager for the relay client. It establish one persistent connection to the given relay server. In case
// of network error the manager will try to reconnect to the server.
// The manager also manage temproary relay connection. If a client wants to communicate with an another client on a
// different relay server, the manager will establish a new connection to the relay server. The connection with these
// relay servers will be closed if there is no active connection. Periodically the manager will check if there is any
// unused relay connection and close it.
type Manager struct {
ctx context.Context
srvAddress string
peerID string
relayClient *Client
reconnectGuard *Guard
relayClients map[string]*RelayTrack
relayClientsMutex sync.RWMutex
}
func NewManager(ctx context.Context, serverAddress string, peerID string) *Manager {
return &Manager{
ctx: ctx,
srvAddress: serverAddress,
peerID: peerID,
relayClients: make(map[string]*RelayTrack),
}
}
// Serve starts the manager. It will establish a connection to the relay server and start the relay cleanup loop.
// todo: consider to return an error if the initial connection to the relay server is not established.
func (m *Manager) Serve() {
m.relayClient = NewClient(m.ctx, m.srvAddress, m.peerID)
m.reconnectGuard = NewGuard(m.ctx, m.relayClient)
m.relayClient.SetOnDisconnectListener(m.reconnectGuard.OnDisconnected)
err := m.relayClient.Connect()
if err != nil {
log.Errorf("failed to connect to relay server, keep try to reconnect: %s", err)
return
}
m.startCleanupLoop()
return
}
// OpenConn opens a connection to the given peer key. If the peer is on the same relay server, the connection will be
// established via the relay server. If the peer is on a different relay server, the manager will establish a new
// connection to the relay server.
func (m *Manager) OpenConn(serverAddress, peerKey string) (net.Conn, error) {
if m.relayClient == nil {
return nil, fmt.Errorf("relay client not connected")
}
foreign, err := m.isForeignServer(serverAddress)
if err != nil {
return nil, err
}
if !foreign {
log.Debugf("open connection to permanent server: %s", peerKey)
return m.relayClient.OpenConn(peerKey)
} else {
log.Debugf("open connection to foreign server: %s", serverAddress)
return m.openConnVia(serverAddress, peerKey)
}
}
// RelayAddress returns the address of the permanent relay server. It could change if the network connection is lost.
// This address will be sent to the target peer to choose the common relay server for the communication.
func (m *Manager) RelayAddress() (net.Addr, error) {
if m.relayClient == nil {
return nil, fmt.Errorf("relay client not connected")
}
return m.relayClient.RelayRemoteAddress()
}
func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) {
// check if already has a connection to the desired relay server
m.relayClientsMutex.RLock()
rt, ok := m.relayClients[serverAddress]
if ok {
rt.RLock()
m.relayClientsMutex.RUnlock()
defer rt.RUnlock()
return rt.relayClient.OpenConn(peerKey)
}
m.relayClientsMutex.RUnlock()
// if not, establish a new connection but check it again (because changed the lock type) before starting the
// connection
m.relayClientsMutex.Lock()
rt, ok = m.relayClients[serverAddress]
if ok {
rt.RLock()
m.relayClientsMutex.Unlock()
defer rt.RUnlock()
return rt.relayClient.OpenConn(peerKey)
}
// create a new relay client and store it in the relayClients map
rt = NewRelayTrack()
rt.Lock()
m.relayClients[serverAddress] = rt
m.relayClientsMutex.Unlock()
relayClient := NewClient(m.ctx, serverAddress, m.peerID)
err := relayClient.Connect()
if err != nil {
rt.Unlock()
m.relayClientsMutex.Lock()
delete(m.relayClients, serverAddress)
m.relayClientsMutex.Unlock()
return nil, err
}
// if connection closed then delete the relay client from the list
relayClient.SetOnDisconnectListener(func() {
m.deleteRelayConn(serverAddress)
})
rt.relayClient = relayClient
rt.Unlock()
conn, err := relayClient.OpenConn(peerKey)
if err != nil {
return nil, err
}
return conn, nil
}
func (m *Manager) deleteRelayConn(address string) {
log.Infof("deleting relay client for %s", address)
m.relayClientsMutex.Lock()
delete(m.relayClients, address)
m.relayClientsMutex.Unlock()
}
func (m *Manager) isForeignServer(address string) (bool, error) {
rAddr, err := m.relayClient.RelayRemoteAddress()
if err != nil {
return false, fmt.Errorf("relay client not connected")
}
return rAddr.String() != address, nil
}
func (m *Manager) startCleanupLoop() {
if m.ctx.Err() != nil {
return
}
ticker := time.NewTicker(relayCleanupInterval)
go func() {
defer ticker.Stop()
for {
select {
case <-m.ctx.Done():
return
case <-ticker.C:
m.cleanUpUnusedRelays()
}
}
}()
}
func (m *Manager) cleanUpUnusedRelays() {
m.relayClientsMutex.Lock()
defer m.relayClientsMutex.Unlock()
for addr, rt := range m.relayClients {
rt.Lock()
if rt.relayClient.HasConns() {
rt.Unlock()
continue
}
rt.relayClient.SetOnDisconnectListener(nil)
go func() {
_ = rt.relayClient.Close()
}()
log.Debugf("clean up relay client: %s", addr)
delete(m.relayClients, addr)
rt.Unlock()
}
}

View File

@@ -0,0 +1,271 @@
package client
import (
"context"
"testing"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/server"
)
func TestForeignConn(t *testing.T) {
ctx := context.Background()
addr1 := "localhost:1234"
srv1 := server.NewServer()
go func() {
err := srv1.Listen(addr1)
if err != nil {
t.Fatalf("failed to bind server: %s", err)
}
}()
defer func() {
err := srv1.Close()
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}()
addr2 := "localhost:2234"
srv2 := server.NewServer()
go func() {
err := srv2.Listen(addr2)
if err != nil {
t.Fatalf("failed to bind server: %s", err)
}
}()
defer func() {
err := srv2.Close()
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}()
idAlice := "alice"
log.Debugf("connect by alice")
mCtx, cancel := context.WithCancel(ctx)
defer cancel()
clientAlice := NewManager(mCtx, addr1, idAlice)
clientAlice.Serve()
idBob := "bob"
log.Debugf("connect by bob")
clientBob := NewManager(mCtx, addr2, idBob)
clientBob.Serve()
bobsSrvAddr, err := clientBob.RelayAddress()
if err != nil {
t.Fatalf("failed to get relay address: %s", err)
}
connAliceToBob, err := clientAlice.OpenConn(bobsSrvAddr.String(), idBob)
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
connBobToAlice, err := clientBob.OpenConn(bobsSrvAddr.String(), idAlice)
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
payload := "hello bob, I am alice"
_, err = connAliceToBob.Write([]byte(payload))
if err != nil {
t.Fatalf("failed to write to channel: %s", err)
}
buf := make([]byte, 65535)
n, err := connBobToAlice.Read(buf)
if err != nil {
t.Fatalf("failed to read from channel: %s", err)
}
_, err = connBobToAlice.Write(buf[:n])
if err != nil {
t.Fatalf("failed to write to channel: %s", err)
}
n, err = connAliceToBob.Read(buf)
if err != nil {
t.Fatalf("failed to read from channel: %s", err)
}
if payload != string(buf[:n]) {
t.Fatalf("expected %s, got %s", payload, string(buf[:n]))
}
}
func TestForeginConnClose(t *testing.T) {
ctx := context.Background()
addr1 := "localhost:1234"
srv1 := server.NewServer()
go func() {
err := srv1.Listen(addr1)
if err != nil {
t.Fatalf("failed to bind server: %s", err)
}
}()
defer func() {
err := srv1.Close()
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}()
addr2 := "localhost:2234"
srv2 := server.NewServer()
go func() {
err := srv2.Listen(addr2)
if err != nil {
t.Fatalf("failed to bind server: %s", err)
}
}()
defer func() {
err := srv2.Close()
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}()
idAlice := "alice"
log.Debugf("connect by alice")
mCtx, cancel := context.WithCancel(ctx)
defer cancel()
mgr := NewManager(mCtx, addr1, idAlice)
mgr.Serve()
conn, err := mgr.OpenConn(addr2, "anotherpeer")
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
err = conn.Close()
if err != nil {
t.Fatalf("failed to close connection: %s", err)
}
}
func TestForeginAutoClose(t *testing.T) {
ctx := context.Background()
relayCleanupInterval = 1 * time.Second
addr1 := "localhost:1234"
srv1 := server.NewServer()
go func() {
t.Log("binding server 1.")
err := srv1.Listen(addr1)
if err != nil {
t.Fatalf("failed to bind server: %s", err)
}
}()
defer func() {
t.Logf("closing server 1.")
err := srv1.Close()
if err != nil {
t.Errorf("failed to close server: %s", err)
}
t.Logf("server 1. closed")
}()
addr2 := "localhost:2234"
srv2 := server.NewServer()
go func() {
t.Log("binding server 2.")
err := srv2.Listen(addr2)
if err != nil {
t.Fatalf("failed to bind server: %s", err)
}
}()
defer func() {
t.Logf("closing server 2.")
err := srv2.Close()
if err != nil {
t.Errorf("failed to close server: %s", err)
}
t.Logf("server 2 closed.")
}()
idAlice := "alice"
t.Log("connect to server 1.")
mCtx, cancel := context.WithCancel(ctx)
defer cancel()
mgr := NewManager(mCtx, addr1, idAlice)
mgr.Serve()
t.Log("open connection to another peer")
conn, err := mgr.OpenConn(addr2, "anotherpeer")
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
t.Log("close conn")
err = conn.Close()
if err != nil {
t.Fatalf("failed to close connection: %s", err)
}
t.Logf("waiting for relay cleanup: %s", relayCleanupInterval+1*time.Second)
time.Sleep(relayCleanupInterval + 1*time.Second)
if len(mgr.relayClients) != 0 {
t.Errorf("expected 0, got %d", len(mgr.relayClients))
}
t.Logf("closing manager")
}
func TestAutoReconnect(t *testing.T) {
ctx := context.Background()
reconnectingTimeout = 2 * time.Second
addr := "localhost:1234"
srv := server.NewServer()
go func() {
err := srv.Listen(addr)
if err != nil {
t.Errorf("failed to bind server: %s", err)
}
}()
defer func() {
err := srv.Close()
if err != nil {
log.Errorf("failed to close server: %s", err)
}
}()
mCtx, cancel := context.WithCancel(ctx)
defer cancel()
clientAlice := NewManager(mCtx, addr, "alice")
clientAlice.Serve()
ra, err := clientAlice.RelayAddress()
if err != nil {
t.Errorf("failed to get relay address: %s", err)
}
conn, err := clientAlice.OpenConn(ra.String(), "bob")
if err != nil {
t.Errorf("failed to bind channel: %s", err)
}
t.Log("closing client relay connection")
// todo figure out moc server
_ = clientAlice.relayClient.relayConn.Close()
t.Log("start test reading")
_, err = conn.Read(make([]byte, 1))
if err == nil {
t.Errorf("unexpected reading from closed connection")
}
log.Infof("waiting for reconnection")
time.Sleep(reconnectingTimeout + 1*time.Second)
log.Infof("reopent the connection")
_, err = clientAlice.OpenConn(ra.String(), "bob")
if err != nil {
t.Errorf("failed to open channel: %s", err)
}
}

40
relay/cmd/main.go Normal file
View File

@@ -0,0 +1,40 @@
package main
import (
"os"
"os/signal"
"syscall"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/server"
"github.com/netbirdio/netbird/util"
)
func init() {
util.InitLog("trace", "console")
}
func waitForExitSignal() {
osSigs := make(chan os.Signal, 1)
signal.Notify(osSigs, syscall.SIGINT, syscall.SIGTERM)
_ = <-osSigs
}
func main() {
address := "10.145.236.1:1235"
srv := server.NewServer()
err := srv.Listen(address)
if err != nil {
log.Errorf("failed to bind server: %s", err)
os.Exit(1)
}
waitForExitSignal()
err = srv.Close()
if err != nil {
log.Errorf("failed to close server: %s", err)
os.Exit(1)
}
}

20
relay/messages/id.go Normal file
View File

@@ -0,0 +1,20 @@
package messages
import (
"crypto/sha256"
"encoding/base64"
)
const (
IDSize = sha256.Size
)
func HashID(peerID string) ([]byte, string) {
idHash := sha256.Sum256([]byte(peerID))
idHashString := base64.StdEncoding.EncodeToString(idHash[:])
return idHash[:], idHashString
}
func HashIDToString(idHash []byte) string {
return base64.StdEncoding.EncodeToString(idHash[:])
}

137
relay/messages/message.go Normal file
View File

@@ -0,0 +1,137 @@
package messages
import (
"fmt"
log "github.com/sirupsen/logrus"
)
const (
MsgTypeHello MsgType = 0
MsgTypeHelloResponse MsgType = 1
MsgTypeTransport MsgType = 2
MsgClose MsgType = 3
)
var (
ErrInvalidMessageLength = fmt.Errorf("invalid message length")
)
type MsgType byte
func (m MsgType) String() string {
switch m {
case MsgTypeHello:
return "hello"
case MsgTypeHelloResponse:
return "hello response"
case MsgTypeTransport:
return "transport"
case MsgClose:
return "close"
default:
return "unknown"
}
}
func DetermineClientMsgType(msg []byte) (MsgType, error) {
// todo: validate magic byte
msgType := MsgType(msg[0])
switch msgType {
case MsgTypeHello:
return msgType, nil
case MsgTypeTransport:
return msgType, nil
case MsgClose:
return msgType, nil
default:
return 0, fmt.Errorf("invalid msg type, len: %d", len(msg))
}
}
func DetermineServerMsgType(msg []byte) (MsgType, error) {
// todo: validate magic byte
msgType := MsgType(msg[0])
switch msgType {
case MsgTypeHelloResponse:
return msgType, nil
case MsgTypeTransport:
return msgType, nil
case MsgClose:
return msgType, nil
default:
return 0, fmt.Errorf("invalid msg type (len: %d)", len(msg))
}
}
// MarshalHelloMsg initial hello message
func MarshalHelloMsg(peerID []byte) ([]byte, error) {
if len(peerID) != IDSize {
return nil, fmt.Errorf("invalid peerID length")
}
msg := make([]byte, 1, 1+len(peerID))
msg[0] = byte(MsgTypeHello)
msg = append(msg, peerID...)
return msg, nil
}
func UnmarshalHelloMsg(msg []byte) ([]byte, error) {
if len(msg) < 2 {
return nil, fmt.Errorf("invalid 'hello' messge")
}
return msg[1:], nil
}
func MarshalHelloResponse() []byte {
msg := make([]byte, 1)
msg[0] = byte(MsgTypeHelloResponse)
return msg
}
// Close message
func MarshalCloseMsg() []byte {
msg := make([]byte, 1)
msg[0] = byte(MsgClose)
return msg
}
// Transport message
func MarshalTransportMsg(peerID []byte, payload []byte) []byte {
if len(peerID) != IDSize {
return nil
}
msg := make([]byte, 1+IDSize, 1+IDSize+len(payload))
msg[0] = byte(MsgTypeTransport)
copy(msg[1:], peerID)
msg = append(msg, payload...)
return msg
}
func UnmarshalTransportMsg(buf []byte) ([]byte, []byte, error) {
headerSize := 1 + IDSize
if len(buf) < headerSize {
return nil, nil, ErrInvalidMessageLength
}
return buf[1:headerSize], buf[headerSize:], nil
}
func UnmarshalTransportID(buf []byte) ([]byte, error) {
headerSize := 1 + IDSize
if len(buf) < headerSize {
log.Debugf("invalid message length: %d, expected: %d, %x", len(buf), headerSize, buf)
return nil, ErrInvalidMessageLength
}
return buf[1:headerSize], nil
}
func UpdateTransportMsg(msg []byte, peerID []byte) error {
if len(msg) < 1+len(peerID) {
return ErrInvalidMessageLength
}
copy(msg[1:], peerID)
return nil
}

View File

@@ -0,0 +1,9 @@
package listener
import "net"
type Listener interface {
Listen(func(conn net.Conn)) error
Close() error
WaitForExitAcceptedConns()
}

View File

@@ -0,0 +1,36 @@
package quic
import (
"net"
"github.com/quic-go/quic-go"
log "github.com/sirupsen/logrus"
)
type QuicConn struct {
quic.Stream
qConn quic.Connection
}
func NewConn(stream quic.Stream, qConn quic.Connection) net.Conn {
return &QuicConn{
Stream: stream,
qConn: qConn,
}
}
func (q QuicConn) Write(b []byte) (n int, err error) {
n, err = q.Stream.Write(b)
if n != len(b) {
log.Errorf("failed to write out the full message")
}
return
}
func (q QuicConn) LocalAddr() net.Addr {
return q.qConn.LocalAddr()
}
func (q QuicConn) RemoteAddr() net.Addr {
return q.qConn.RemoteAddr()
}

View File

@@ -0,0 +1,111 @@
package quic
import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"math/big"
"net"
"sync"
"github.com/quic-go/quic-go"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/server/listener"
)
type Listener struct {
address string
onAcceptFn func(conn net.Conn)
listener *quic.Listener
quit chan struct{}
wg sync.WaitGroup
}
func NewListener(address string) listener.Listener {
return &Listener{
address: address,
}
}
func (l *Listener) Listen(onAcceptFn func(conn net.Conn)) error {
ql, err := quic.ListenAddr(l.address, generateTLSConfig(), &quic.Config{
EnableDatagrams: true,
})
if err != nil {
return err
}
l.listener = ql
l.quit = make(chan struct{})
log.Infof("quic server is listening on address: %s", l.address)
l.wg.Add(1)
go l.acceptLoop(onAcceptFn)
<-l.quit
return nil
}
func (l *Listener) Close() error {
close(l.quit)
err := l.listener.Close()
l.wg.Wait()
return err
}
func (l *Listener) acceptLoop(acceptFn func(conn net.Conn)) {
defer l.wg.Done()
for {
qConn, err := l.listener.Accept(context.Background())
if err != nil {
select {
case <-l.quit:
return
default:
log.Errorf("failed to accept connection: %s", err)
continue
}
}
log.Infof("new connection from: %s", qConn.RemoteAddr())
stream, err := qConn.AcceptStream(context.Background())
if err != nil {
log.Errorf("failed to open stream: %s", err)
continue
}
conn := NewConn(stream, qConn)
go acceptFn(conn)
}
}
// Setup a bare-bones TLS config for the server
func generateTLSConfig() *tls.Config {
key, err := rsa.GenerateKey(rand.Reader, 1024)
if err != nil {
panic(err)
}
template := x509.Certificate{SerialNumber: big.NewInt(1)}
certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key)
if err != nil {
panic(err)
}
keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)})
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
tlsCert, err := tls.X509KeyPair(certPEM, keyPEM)
if err != nil {
panic(err)
}
return &tls.Config{
Certificates: []tls.Certificate{tlsCert},
NextProtos: []string{"quic-echo-example"},
}
}

View File

@@ -0,0 +1,80 @@
package tcp
import (
"net"
"sync"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/server/listener"
)
// Listener
// Is it just demo code. It does not work in real life environment because the TCP is a streaming protocol, adn
// it does not handle framing.
type Listener struct {
address string
onAcceptFn func(conn net.Conn)
wg sync.WaitGroup
quit chan struct{}
listener net.Listener
lock sync.Mutex
}
func NewListener(address string) listener.Listener {
return &Listener{
address: address,
}
}
func (l *Listener) Listen(onAcceptFn func(conn net.Conn)) error {
l.lock.Lock()
l.onAcceptFn = onAcceptFn
l.quit = make(chan struct{})
li, err := net.Listen("tcp", l.address)
if err != nil {
log.Errorf("failed to listen on address: %s, %s", l.address, err)
l.lock.Unlock()
return err
}
log.Debugf("TCP server is listening on address: %s", l.address)
l.listener = li
l.wg.Add(1)
go l.acceptLoop()
l.lock.Unlock()
<-l.quit
return nil
}
// Close todo: prevent multiple call (do not close two times the channel)
func (l *Listener) Close() error {
l.lock.Lock()
defer l.lock.Unlock()
close(l.quit)
err := l.listener.Close()
l.wg.Wait()
return err
}
func (l *Listener) acceptLoop() {
defer l.wg.Done()
for {
conn, err := l.listener.Accept()
if err != nil {
select {
case <-l.quit:
return
default:
log.Errorf("failed to accept connection: %s", err)
continue
}
}
go l.onAcceptFn(conn)
}
}

View File

@@ -0,0 +1,68 @@
package udp
import (
"io"
"net"
"time"
)
type UDPConn struct {
*net.UDPConn
addr *net.UDPAddr
msgChannel chan []byte
}
func NewConn(conn *net.UDPConn, addr *net.UDPAddr) *UDPConn {
return &UDPConn{
UDPConn: conn,
addr: addr,
msgChannel: make(chan []byte),
}
}
func (u *UDPConn) Read(b []byte) (n int, err error) {
msg, ok := <-u.msgChannel
if !ok {
return 0, io.EOF
}
n = copy(b, msg)
return n, nil
}
func (u *UDPConn) Write(b []byte) (n int, err error) {
return u.UDPConn.WriteTo(b, u.addr)
}
func (u *UDPConn) Close() error {
//TODO implement me
//panic("implement me")
return nil
}
func (u *UDPConn) LocalAddr() net.Addr {
return u.UDPConn.LocalAddr()
}
func (u *UDPConn) RemoteAddr() net.Addr {
return u.addr
}
func (u *UDPConn) SetDeadline(t time.Time) error {
//TODO implement me
panic("implement me")
}
func (u *UDPConn) SetReadDeadline(t time.Time) error {
//TODO implement me
panic("implement me")
}
func (u *UDPConn) SetWriteDeadline(t time.Time) error {
//TODO implement me
panic("implement me")
}
func (u *UDPConn) onNewMsg(b []byte) {
u.msgChannel <- b
}

View File

@@ -0,0 +1,109 @@
package udp
import (
"net"
"sync"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/server/listener"
)
type Listener struct {
address string
conns map[string]*UDPConn
onAcceptFn func(conn net.Conn)
listener *net.UDPConn
wg sync.WaitGroup
quit chan struct{}
lock sync.Mutex
}
func (l *Listener) WaitForExitAcceptedConns() {
l.wg.Wait()
return
}
func NewListener(address string) listener.Listener {
return &Listener{
address: address,
conns: make(map[string]*UDPConn),
}
}
func (l *Listener) Listen(onAcceptFn func(conn net.Conn)) error {
l.lock.Lock()
l.onAcceptFn = onAcceptFn
l.quit = make(chan struct{})
addr, err := net.ResolveUDPAddr("udp", l.address)
if err != nil {
log.Errorf("invalid listen address '%s': %s", l.address, err)
l.lock.Unlock()
return err
}
li, err := net.ListenUDP("udp", addr)
if err != nil {
log.Fatalf("%s", err)
l.lock.Unlock()
return err
}
log.Debugf("udp server is listening on address: %s", addr.String())
l.listener = li
l.wg.Add(1)
go l.readLoop()
l.lock.Unlock()
<-l.quit
return nil
}
func (l *Listener) Close() error {
l.lock.Lock()
defer l.lock.Unlock()
if l.listener == nil {
return nil
}
log.Infof("closing UDP listener")
close(l.quit)
err := l.listener.Close()
l.wg.Wait()
l.listener = nil
return err
}
func (l *Listener) readLoop() {
defer l.wg.Done()
for {
buf := make([]byte, 1500)
n, addr, err := l.listener.ReadFromUDP(buf)
if err != nil {
select {
case <-l.quit:
return
default:
log.Errorf("failed to accept connection: %s", err)
continue
}
}
pConn, ok := l.conns[addr.String()]
if ok {
pConn.onNewMsg(buf[:n])
continue
}
pConn = NewConn(l.listener, addr)
log.Infof("new connection from: %s", pConn.RemoteAddr())
l.conns[addr.String()] = pConn
go l.onAcceptFn(pConn)
pConn.onNewMsg(buf[:n])
}
}

Some files were not shown because too many files have changed in this diff Show More