mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-20 16:12:26 -04:00
Compare commits
70 Commits
fix/proxy_
...
feature/re
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
38f2a59d1b | ||
|
|
9504012920 | ||
|
|
5e93d117cf | ||
|
|
8c70b7d7ff | ||
|
|
ed8def4d9b | ||
|
|
1e115e3893 | ||
|
|
deffe037aa | ||
|
|
fed9e587af | ||
|
|
983d7bafbe | ||
|
|
a40d4d2f32 | ||
|
|
4da29451d0 | ||
|
|
15818b72c6 | ||
|
|
0556dc1860 | ||
|
|
2b369cd28f | ||
|
|
9d44a476c6 | ||
|
|
9b3449753e | ||
|
|
456629811b | ||
|
|
57ddb5f262 | ||
|
|
4ced07dd8d | ||
|
|
3430b81622 | ||
|
|
fd4ad15c83 | ||
|
|
c311d0d19e | ||
|
|
521f7dd39f | ||
|
|
f9ec0a9a2e | ||
|
|
012235ff12 | ||
|
|
4ff069a102 | ||
|
|
7cc3964a4d | ||
|
|
6d627f1923 | ||
|
|
076ce69a24 | ||
|
|
f176807ebe | ||
|
|
d4c47eaf8a | ||
|
|
645a1f31a7 | ||
|
|
b4aa7e50f9 | ||
|
|
d35a79d3b5 | ||
|
|
6a2929011d | ||
|
|
173ca25dac | ||
|
|
e877c9d6c1 | ||
|
|
7a1c96ebf4 | ||
|
|
41fe9f84ec | ||
|
|
d13fb0e379 | ||
|
|
f3214527ea | ||
|
|
69048bfd34 | ||
|
|
29a2d93873 | ||
|
|
6b01b0020e | ||
|
|
9d3db68805 | ||
|
|
2e315311e0 | ||
|
|
36b2cd16cc | ||
|
|
67e2185964 | ||
|
|
89149dc6f4 | ||
|
|
5a1f8f13a2 | ||
|
|
e71059d245 | ||
|
|
91fa2e20a0 | ||
|
|
61034aaf4d | ||
|
|
0a05f8b4d4 | ||
|
|
e82c0a55a3 | ||
|
|
13eb457132 | ||
|
|
b8717b8956 | ||
|
|
1c9c9ae47e | ||
|
|
9ac5a1ed3f | ||
|
|
d4eaec5cbd | ||
|
|
6ae7a790f2 | ||
|
|
49dfbc82d9 | ||
|
|
57a89cf0cc | ||
|
|
50201d63c2 | ||
|
|
d11b39282b | ||
|
|
bd58eea8ea | ||
|
|
a5811a2d7d | ||
|
|
a680f80ed9 | ||
|
|
10fbdc2c4a | ||
|
|
1444fbe104 |
2
.github/workflows/golang-test-linux.yml
vendored
2
.github/workflows/golang-test-linux.yml
vendored
@@ -15,7 +15,7 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
arch: [ '386','amd64' ]
|
||||
store: [ 'jsonfile', 'sqlite' ]
|
||||
store: [ 'jsonfile', 'sqlite', 'postgres']
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Install Go
|
||||
|
||||
@@ -130,3 +130,10 @@ issues:
|
||||
- path: mock\.go
|
||||
linters:
|
||||
- nilnil
|
||||
# Exclude specific deprecation warnings for grpc methods
|
||||
- linters:
|
||||
- staticcheck
|
||||
text: "grpc.DialContext is deprecated"
|
||||
- linters:
|
||||
- staticcheck
|
||||
text: "grpc.WithBlock is deprecated"
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
We as members, contributors, and leaders pledge to make participation in our
|
||||
community a harassment-free experience for everyone, regardless of age, body
|
||||
size, visible or invisible disability, ethnicity, sex characteristics, gender
|
||||
identity and expression, level of experience, education, socio-economic status,
|
||||
identity and expression, level of experience, education, socioeconomic status,
|
||||
nationality, personal appearance, race, caste, color, religion, or sexual
|
||||
identity and orientation.
|
||||
|
||||
|
||||
@@ -57,15 +57,17 @@ type Client struct {
|
||||
ctxCancel context.CancelFunc
|
||||
ctxCancelLock *sync.Mutex
|
||||
deviceName string
|
||||
uiVersion string
|
||||
networkChangeListener listener.NetworkChangeListener
|
||||
}
|
||||
|
||||
// NewClient instantiate a new Client
|
||||
func NewClient(cfgFile, deviceName string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client {
|
||||
func NewClient(cfgFile, deviceName string, uiVersion string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client {
|
||||
net.SetAndroidProtectSocketFn(tunAdapter.ProtectSocket)
|
||||
return &Client{
|
||||
cfgFile: cfgFile,
|
||||
deviceName: deviceName,
|
||||
uiVersion: uiVersion,
|
||||
tunAdapter: tunAdapter,
|
||||
iFaceDiscover: iFaceDiscover,
|
||||
recorder: peer.NewRecorder(""),
|
||||
@@ -88,6 +90,9 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead
|
||||
var ctx context.Context
|
||||
//nolint
|
||||
ctxWithValues := context.WithValue(context.Background(), system.DeviceNameCtxKey, c.deviceName)
|
||||
//nolint
|
||||
ctxWithValues = context.WithValue(ctxWithValues, system.UiVersionCtxKey, c.uiVersion)
|
||||
|
||||
c.ctxCancelLock.Lock()
|
||||
ctx, c.ctxCancel = context.WithCancel(ctxWithValues)
|
||||
defer c.ctxCancel()
|
||||
|
||||
@@ -3,13 +3,14 @@ package cmd
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/client/server"
|
||||
)
|
||||
|
||||
var debugCmd = &cobra.Command{
|
||||
@@ -58,7 +59,7 @@ var forCmd = &cobra.Command{
|
||||
}
|
||||
|
||||
func debugBundle(cmd *cobra.Command, _ []string) error {
|
||||
conn, err := getClient(cmd.Context())
|
||||
conn, err := getClient(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -79,14 +80,14 @@ func debugBundle(cmd *cobra.Command, _ []string) error {
|
||||
}
|
||||
|
||||
func setLogLevel(cmd *cobra.Command, args []string) error {
|
||||
conn, err := getClient(cmd.Context())
|
||||
conn, err := getClient(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
level := parseLogLevel(args[0])
|
||||
level := server.ParseLogLevel(args[0])
|
||||
if level == proto.LogLevel_UNKNOWN {
|
||||
return fmt.Errorf("unknown log level: %s. Available levels are: panic, fatal, error, warn, info, debug, trace\n", args[0])
|
||||
}
|
||||
@@ -102,34 +103,13 @@ func setLogLevel(cmd *cobra.Command, args []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseLogLevel(level string) proto.LogLevel {
|
||||
switch strings.ToLower(level) {
|
||||
case "panic":
|
||||
return proto.LogLevel_PANIC
|
||||
case "fatal":
|
||||
return proto.LogLevel_FATAL
|
||||
case "error":
|
||||
return proto.LogLevel_ERROR
|
||||
case "warn":
|
||||
return proto.LogLevel_WARN
|
||||
case "info":
|
||||
return proto.LogLevel_INFO
|
||||
case "debug":
|
||||
return proto.LogLevel_DEBUG
|
||||
case "trace":
|
||||
return proto.LogLevel_TRACE
|
||||
default:
|
||||
return proto.LogLevel_UNKNOWN
|
||||
}
|
||||
}
|
||||
|
||||
func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
duration, err := time.ParseDuration(args[0])
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid duration format: %v", err)
|
||||
}
|
||||
|
||||
conn, err := getClient(cmd.Context())
|
||||
conn, err := getClient(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -137,18 +117,33 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
|
||||
stat, err := client.Status(cmd.Context(), &proto.StatusRequest{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get status: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
restoreUp := stat.Status == string(internal.StatusConnected) || stat.Status == string(internal.StatusConnecting)
|
||||
|
||||
initialLogLevel, err := client.GetLogLevel(cmd.Context(), &proto.GetLogLevelRequest{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get log level: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
|
||||
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
|
||||
}
|
||||
cmd.Println("Netbird down")
|
||||
|
||||
_, err = client.SetLogLevel(cmd.Context(), &proto.SetLogLevelRequest{
|
||||
Level: proto.LogLevel_TRACE,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set log level to trace: %v", status.Convert(err).Message())
|
||||
initialLevelTrace := initialLogLevel.GetLevel() >= proto.LogLevel_TRACE
|
||||
if !initialLevelTrace {
|
||||
_, err = client.SetLogLevel(cmd.Context(), &proto.SetLogLevelRequest{
|
||||
Level: proto.LogLevel_TRACE,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set log level to TRACE: %v", status.Convert(err).Message())
|
||||
}
|
||||
cmd.Println("Log level set to trace.")
|
||||
}
|
||||
cmd.Println("Log level set to trace.")
|
||||
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
@@ -175,10 +170,22 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
cmd.Println("Netbird down")
|
||||
|
||||
// TODO reset log level
|
||||
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
if restoreUp {
|
||||
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
|
||||
return fmt.Errorf("failed to up: %v", status.Convert(err).Message())
|
||||
}
|
||||
cmd.Println("Netbird up")
|
||||
}
|
||||
|
||||
if !initialLevelTrace {
|
||||
if _, err := client.SetLogLevel(cmd.Context(), &proto.SetLogLevelRequest{Level: initialLogLevel.GetLevel()}); err != nil {
|
||||
return fmt.Errorf("failed to restore log level: %v", status.Convert(err).Message())
|
||||
}
|
||||
cmd.Println("Log level restored to", initialLogLevel.GetLevel())
|
||||
}
|
||||
|
||||
cmd.Println("Creating debug bundle...")
|
||||
|
||||
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
|
||||
|
||||
@@ -2,9 +2,10 @@ package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/util"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
|
||||
@@ -353,8 +353,11 @@ func migrateToNetbird(oldPath, newPath string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func getClient(ctx context.Context) (*grpc.ClientConn, error) {
|
||||
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||
func getClient(cmd *cobra.Command) (*grpc.ClientConn, error) {
|
||||
SetFlagsFromEnvVars(rootCmd)
|
||||
cmd.SetOut(cmd.OutOrStdout())
|
||||
|
||||
conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to daemon error: %v\n"+
|
||||
"If the daemon is not running please run: "+
|
||||
|
||||
@@ -49,7 +49,7 @@ func init() {
|
||||
}
|
||||
|
||||
func routesList(cmd *cobra.Command, _ []string) error {
|
||||
conn, err := getClient(cmd.Context())
|
||||
conn, err := getClient(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -79,7 +79,7 @@ func routesList(cmd *cobra.Command, _ []string) error {
|
||||
}
|
||||
|
||||
func routesSelect(cmd *cobra.Command, args []string) error {
|
||||
conn, err := getClient(cmd.Context())
|
||||
conn, err := getClient(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -106,7 +106,7 @@ func routesSelect(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
|
||||
func routesDeselect(cmd *cobra.Command, args []string) error {
|
||||
conn, err := getClient(cmd.Context())
|
||||
conn, err := getClient(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
|
||||
clientProto "github.com/netbirdio/netbird/client/proto"
|
||||
client "github.com/netbirdio/netbird/client/server"
|
||||
mgmtProto "github.com/netbirdio/netbird/management/proto"
|
||||
@@ -69,10 +70,11 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
|
||||
t.Fatal(err)
|
||||
}
|
||||
s := grpc.NewServer()
|
||||
store, err := mgmt.NewStoreFromJson(config.Datadir, nil)
|
||||
store, cleanUp, err := mgmt.NewTestStoreFromJson(config.Datadir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Cleanup(cleanUp)
|
||||
|
||||
peersUpdateManager := mgmt.NewPeersUpdateManager(nil)
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
|
||||
@@ -42,20 +42,20 @@ func NewFirewall(context context.Context, iface IFaceMapper) (firewall.Manager,
|
||||
|
||||
switch check() {
|
||||
case IPTABLES:
|
||||
log.Debug("creating an iptables firewall manager")
|
||||
log.Info("creating an iptables firewall manager")
|
||||
fm, errFw = nbiptables.Create(context, iface)
|
||||
if errFw != nil {
|
||||
log.Errorf("failed to create iptables manager: %s", errFw)
|
||||
}
|
||||
case NFTABLES:
|
||||
log.Debug("creating an nftables firewall manager")
|
||||
log.Info("creating an nftables firewall manager")
|
||||
fm, errFw = nbnftables.Create(context, iface)
|
||||
if errFw != nil {
|
||||
log.Errorf("failed to create nftables manager: %s", errFw)
|
||||
}
|
||||
default:
|
||||
errFw = fmt.Errorf("no firewall manager found")
|
||||
log.Debug("no firewall manager found, try to use userspace packet filtering firewall")
|
||||
log.Info("no firewall manager found, trying to use userspace packet filtering firewall")
|
||||
}
|
||||
|
||||
if iface.IsUserspaceBind() {
|
||||
@@ -85,16 +85,58 @@ func NewFirewall(context context.Context, iface IFaceMapper) (firewall.Manager,
|
||||
|
||||
// check returns the firewall type based on common lib checks. It returns UNKNOWN if no firewall is found.
|
||||
func check() FWType {
|
||||
nf := nftables.Conn{}
|
||||
if _, err := nf.ListChains(); err == nil && os.Getenv(SKIP_NFTABLES_ENV) != "true" {
|
||||
return NFTABLES
|
||||
useIPTABLES := false
|
||||
var iptablesChains []string
|
||||
ip, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
if err == nil && isIptablesClientAvailable(ip) {
|
||||
major, minor, _ := ip.GetIptablesVersion()
|
||||
// use iptables when its version is lower than 1.8.0 which doesn't work well with our nftables manager
|
||||
if major < 1 || (major == 1 && minor < 8) {
|
||||
return IPTABLES
|
||||
}
|
||||
|
||||
useIPTABLES = true
|
||||
|
||||
iptablesChains, err = ip.ListChains("filter")
|
||||
if err != nil {
|
||||
log.Errorf("failed to list iptables chains: %s", err)
|
||||
useIPTABLES = false
|
||||
}
|
||||
}
|
||||
|
||||
ip, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
if err != nil {
|
||||
return UNKNOWN
|
||||
nf := nftables.Conn{}
|
||||
if chains, err := nf.ListChains(); err == nil && os.Getenv(SKIP_NFTABLES_ENV) != "true" {
|
||||
if !useIPTABLES {
|
||||
return NFTABLES
|
||||
}
|
||||
|
||||
// search for chains where table is filter
|
||||
// if we find one, we assume that nftables manager can be used with iptables
|
||||
for _, chain := range chains {
|
||||
if chain.Table.Name == "filter" {
|
||||
return NFTABLES
|
||||
}
|
||||
}
|
||||
|
||||
// check tables for the following constraints:
|
||||
// 1. there is no chain in nftables for the filter table and there is at least one chain in iptables, we assume that nftables manager can not be used
|
||||
// 2. there is no tables or more than one table, we assume that nftables manager can be used
|
||||
// 3. there is only one table and its name is filter, we assume that nftables manager can not be used, since there was no chain in it
|
||||
// 4. if we find an error we log and continue with iptables check
|
||||
nbTablesList, err := nf.ListTables()
|
||||
switch {
|
||||
case err == nil && len(iptablesChains) > 0:
|
||||
return IPTABLES
|
||||
case err == nil && len(nbTablesList) != 1:
|
||||
return NFTABLES
|
||||
case err == nil && len(nbTablesList) == 1 && nbTablesList[0].Name == "filter":
|
||||
return IPTABLES
|
||||
case err != nil:
|
||||
log.Errorf("failed to list nftables tables on fw manager discovery: %s", err)
|
||||
}
|
||||
}
|
||||
if isIptablesClientAvailable(ip) {
|
||||
|
||||
if useIPTABLES {
|
||||
return IPTABLES
|
||||
}
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"runtime"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
@@ -91,6 +92,9 @@ func (c *ConnectClient) RunOniOS(
|
||||
networkChangeListener listener.NetworkChangeListener,
|
||||
dnsManager dns.IosDnsManager,
|
||||
) error {
|
||||
// Set GC percent to 5% to reduce memory usage as iOS only allows 50MB of memory for the extension.
|
||||
debug.SetGCPercent(5)
|
||||
|
||||
mobileDependency := MobileDependency{
|
||||
FileDescriptor: fileDescriptor,
|
||||
NetworkChangeListener: networkChangeListener,
|
||||
@@ -327,6 +331,15 @@ func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.Pe
|
||||
engineConf.PreSharedKey = &preSharedKey
|
||||
}
|
||||
|
||||
port, err := freePort(config.WgPort)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if port != config.WgPort {
|
||||
log.Infof("using %d as wireguard port: %d is in use", port, config.WgPort)
|
||||
}
|
||||
engineConf.WgPort = port
|
||||
|
||||
return engineConf, nil
|
||||
}
|
||||
|
||||
@@ -376,3 +389,20 @@ func statusRecorderToSignalConnStateNotifier(statusRecorder *peer.Status) signal
|
||||
notifier, _ := sri.(signal.ConnStateNotifier)
|
||||
return notifier
|
||||
}
|
||||
|
||||
func freePort(start int) (int, error) {
|
||||
addr := net.UDPAddr{}
|
||||
if start == 0 {
|
||||
start = iface.DefaultWgPort
|
||||
}
|
||||
for x := start; x <= 65535; x++ {
|
||||
addr.Port = x
|
||||
conn, err := net.ListenUDP("udp", &addr)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
conn.Close()
|
||||
return x, nil
|
||||
}
|
||||
return 0, errors.New("no free ports")
|
||||
}
|
||||
|
||||
57
client/internal/connect_test.go
Normal file
57
client/internal/connect_test.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_freePort(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
port int
|
||||
want int
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "available",
|
||||
port: 51820,
|
||||
want: 51820,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "notavailable",
|
||||
port: 51830,
|
||||
want: 51831,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "noports",
|
||||
port: 65535,
|
||||
want: 0,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
|
||||
c1, err := net.ListenUDP("udp", &net.UDPAddr{Port: 51830})
|
||||
if err != nil {
|
||||
t.Errorf("freePort error = %v", err)
|
||||
}
|
||||
c2, err := net.ListenUDP("udp", &net.UDPAddr{Port: 65535})
|
||||
if err != nil {
|
||||
t.Errorf("freePort error = %v", err)
|
||||
}
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := freePort(tt.port)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("freePort() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Errorf("freePort() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
c1.Close()
|
||||
c2.Close()
|
||||
}
|
||||
}
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"runtime"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@@ -260,13 +259,10 @@ func (u *upstreamResolverBase) disable(err error) {
|
||||
return
|
||||
}
|
||||
|
||||
// todo test the deactivation logic, it seems to affect the client
|
||||
if runtime.GOOS != "ios" {
|
||||
log.Warnf("Upstream resolving is Disabled for %v", reactivatePeriod)
|
||||
u.deactivate(err)
|
||||
u.disabled = true
|
||||
go u.waitUntilResponse()
|
||||
}
|
||||
log.Warnf("Upstream resolving is Disabled for %v", reactivatePeriod)
|
||||
u.deactivate(err)
|
||||
u.disabled = true
|
||||
go u.waitUntilResponse()
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) testNameserver(server string) error {
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"maps"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/netip"
|
||||
@@ -117,7 +118,8 @@ type Engine struct {
|
||||
TURNs []*stun.URI
|
||||
|
||||
// clientRoutes is the most recent list of clientRoutes received from the Management Service
|
||||
clientRoutes route.HAMap
|
||||
clientRoutes route.HAMap
|
||||
clientRoutesMu sync.RWMutex
|
||||
|
||||
clientCtx context.Context
|
||||
clientCancel context.CancelFunc
|
||||
@@ -133,7 +135,7 @@ type Engine struct {
|
||||
// networkSerial is the latest CurrentSerial (state ID) of the network sent by the Management service
|
||||
networkSerial uint64
|
||||
|
||||
networkWatcher *networkmonitor.NetworkWatcher
|
||||
networkMonitor *networkmonitor.NetworkMonitor
|
||||
|
||||
sshServerFunc func(hostKeyPEM []byte, addr string) (nbssh.Server, error)
|
||||
sshServer nbssh.Server
|
||||
@@ -150,6 +152,8 @@ type Engine struct {
|
||||
signalProbe *Probe
|
||||
relayProbe *Probe
|
||||
wgProbe *Probe
|
||||
|
||||
wgConnWorker sync.WaitGroup
|
||||
}
|
||||
|
||||
// Peer is an instance of the Connection Peer
|
||||
@@ -212,7 +216,6 @@ func NewEngineWithProbes(
|
||||
networkSerial: 0,
|
||||
sshServerFunc: nbssh.DefaultSSHServer,
|
||||
statusRecorder: statusRecorder,
|
||||
networkWatcher: networkmonitor.New(),
|
||||
mgmProbe: mgmProbe,
|
||||
signalProbe: signalProbe,
|
||||
relayProbe: relayProbe,
|
||||
@@ -229,20 +232,26 @@ func (e *Engine) Stop() error {
|
||||
}
|
||||
|
||||
// stopping network monitor first to avoid starting the engine again
|
||||
e.networkWatcher.Stop()
|
||||
if e.networkMonitor != nil {
|
||||
e.networkMonitor.Stop()
|
||||
}
|
||||
log.Info("Network monitor: stopped")
|
||||
|
||||
err := e.removeAllPeers()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
e.clientRoutesMu.Lock()
|
||||
e.clientRoutes = nil
|
||||
e.clientRoutesMu.Unlock()
|
||||
|
||||
// very ugly but we want to remove peers from the WireGuard interface first before removing interface.
|
||||
// Removing peers happens in the conn.Close() asynchronously
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
e.close()
|
||||
e.wgConnWorker.Wait()
|
||||
log.Infof("stopped Netbird Engine")
|
||||
return nil
|
||||
}
|
||||
@@ -259,7 +268,7 @@ func (e *Engine) Start() error {
|
||||
}
|
||||
e.ctx, e.cancel = context.WithCancel(e.clientCtx)
|
||||
|
||||
e.wgProxyFactory = wgproxy.NewFactory(e.clientCtx, e.config.WgPort)
|
||||
e.wgProxyFactory = wgproxy.NewFactory(e.ctx, e.config.WgPort)
|
||||
|
||||
wgIface, err := e.newWgIface()
|
||||
if err != nil {
|
||||
@@ -344,20 +353,8 @@ func (e *Engine) Start() error {
|
||||
e.receiveManagementEvents()
|
||||
e.receiveProbeEvents()
|
||||
|
||||
if e.config.NetworkMonitor {
|
||||
// starting network monitor at the very last to avoid disruptions
|
||||
go e.networkWatcher.Start(e.ctx, func() {
|
||||
log.Infof("Network monitor detected network change, restarting engine")
|
||||
if err := e.Stop(); err != nil {
|
||||
log.Errorf("Failed to stop engine: %v", err)
|
||||
}
|
||||
if err := e.Start(); err != nil {
|
||||
log.Errorf("Failed to start engine: %v", err)
|
||||
}
|
||||
})
|
||||
} else {
|
||||
log.Infof("Network monitor is disabled, not starting")
|
||||
}
|
||||
// starting network monitor at the very last to avoid disruptions
|
||||
e.startNetworkMonitor()
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -745,7 +742,9 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
log.Errorf("failed to update clientRoutes, err: %v", err)
|
||||
}
|
||||
|
||||
e.clientRoutesMu.Lock()
|
||||
e.clientRoutes = clientRoutes
|
||||
e.clientRoutesMu.Unlock()
|
||||
|
||||
protoDNSConfig := networkMap.GetDNSConfig()
|
||||
if protoDNSConfig == nil {
|
||||
@@ -879,18 +878,25 @@ func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error {
|
||||
log.Warnf("error adding peer %s to status recorder, got error: %v", peerKey, err)
|
||||
}
|
||||
|
||||
e.wgConnWorker.Add(1)
|
||||
go e.connWorker(conn, peerKey)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Engine) connWorker(conn *peer.Conn, peerKey string) {
|
||||
defer e.wgConnWorker.Done()
|
||||
for {
|
||||
|
||||
// randomize starting time a bit
|
||||
min := 500
|
||||
max := 2000
|
||||
time.Sleep(time.Duration(rand.Intn(max-min)+min) * time.Millisecond)
|
||||
duration := time.Duration(rand.Intn(max-min)+min) * time.Millisecond
|
||||
select {
|
||||
case <-e.ctx.Done():
|
||||
return
|
||||
case <-time.After(duration):
|
||||
}
|
||||
|
||||
// if peer has been removed -> give up
|
||||
if !e.peerExists(peerKey) {
|
||||
@@ -977,7 +983,6 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, e
|
||||
WgConfig: wgConfig,
|
||||
LocalWgPort: e.config.WgPort,
|
||||
NATExternalIPs: e.parseNATExternalIPMappings(),
|
||||
UserspaceBind: e.wgInterface.IsUserspaceBind(),
|
||||
RosenpassPubKey: e.getRosenpassPubKey(),
|
||||
RosenpassAddr: e.getRosenpassAddr(),
|
||||
}
|
||||
@@ -1040,8 +1045,6 @@ func (e *Engine) receiveSignalEvents() {
|
||||
return err
|
||||
}
|
||||
|
||||
conn.RegisterProtoSupportMeta(msg.Body.GetFeaturesSupported())
|
||||
|
||||
var rosenpassPubKey []byte
|
||||
rosenpassAddr := ""
|
||||
if msg.GetBody().GetRosenpassConfig() != nil {
|
||||
@@ -1064,8 +1067,6 @@ func (e *Engine) receiveSignalEvents() {
|
||||
return err
|
||||
}
|
||||
|
||||
conn.RegisterProtoSupportMeta(msg.GetBody().GetFeaturesSupported())
|
||||
|
||||
var rosenpassPubKey []byte
|
||||
rosenpassAddr := ""
|
||||
if msg.GetBody().GetRosenpassConfig() != nil {
|
||||
@@ -1088,7 +1089,8 @@ func (e *Engine) receiveSignalEvents() {
|
||||
log.Errorf("failed on parsing remote candidate %s -> %s", candidate, err)
|
||||
return err
|
||||
}
|
||||
conn.OnRemoteCandidate(candidate)
|
||||
|
||||
conn.OnRemoteCandidate(candidate, e.GetClientRoutes())
|
||||
case sProto.Body_MODE:
|
||||
}
|
||||
|
||||
@@ -1282,11 +1284,17 @@ func (e *Engine) newDnsServer() ([]*route.Route, dns.Server, error) {
|
||||
|
||||
// GetClientRoutes returns the current routes from the route map
|
||||
func (e *Engine) GetClientRoutes() route.HAMap {
|
||||
return e.clientRoutes
|
||||
e.clientRoutesMu.RLock()
|
||||
defer e.clientRoutesMu.RUnlock()
|
||||
|
||||
return maps.Clone(e.clientRoutes)
|
||||
}
|
||||
|
||||
// GetClientRoutesWithNetID returns the current routes from the route map, but the keys consist of the network ID only
|
||||
func (e *Engine) GetClientRoutesWithNetID() map[route.NetID][]*route.Route {
|
||||
e.clientRoutesMu.RLock()
|
||||
defer e.clientRoutesMu.RUnlock()
|
||||
|
||||
routes := make(map[route.NetID][]*route.Route, len(e.clientRoutes))
|
||||
for id, v := range e.clientRoutes {
|
||||
routes[id.NetID()] = v
|
||||
@@ -1399,3 +1407,26 @@ func (e *Engine) probeSTUNs() []relay.ProbeResult {
|
||||
func (e *Engine) probeTURNs() []relay.ProbeResult {
|
||||
return relay.ProbeAll(e.ctx, relay.ProbeTURN, e.TURNs)
|
||||
}
|
||||
|
||||
func (e *Engine) startNetworkMonitor() {
|
||||
if !e.config.NetworkMonitor {
|
||||
log.Infof("Network monitor is disabled, not starting")
|
||||
return
|
||||
}
|
||||
|
||||
e.networkMonitor = networkmonitor.New()
|
||||
go func() {
|
||||
err := e.networkMonitor.Start(e.ctx, func() {
|
||||
log.Infof("Network monitor detected network change, restarting engine")
|
||||
if err := e.Stop(); err != nil {
|
||||
log.Errorf("Failed to stop engine: %v", err)
|
||||
}
|
||||
if err := e.Start(); err != nil {
|
||||
log.Errorf("Failed to start engine: %v", err)
|
||||
}
|
||||
})
|
||||
if err != nil && !errors.Is(err, networkmonitor.ErrStopped) {
|
||||
log.Errorf("Network monitor: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -229,6 +229,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
engine.udpMux = bind.NewUniversalUDPMuxDefault(bind.UniversalUDPMuxParams{UDPConn: conn})
|
||||
engine.ctx = ctx
|
||||
|
||||
type testCase struct {
|
||||
name string
|
||||
@@ -408,6 +409,7 @@ func TestEngine_Sync(t *testing.T) {
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"))
|
||||
engine.ctx = ctx
|
||||
|
||||
engine.dnsServer = &dns.MockServer{
|
||||
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
||||
@@ -566,6 +568,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"))
|
||||
engine.ctx = ctx
|
||||
newNet, err := stdnet.NewNet()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -735,6 +738,8 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"))
|
||||
engine.ctx = ctx
|
||||
|
||||
newNet, err := stdnet.NewNet()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -1003,7 +1008,9 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
|
||||
WgPort: wgPort,
|
||||
}
|
||||
|
||||
return NewEngine(ctx, cancel, signalClient, mgmtClient, conf, MobileDependency{}, peer.NewRecorder("https://mgm")), nil
|
||||
e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, conf, MobileDependency{}, peer.NewRecorder("https://mgm")), nil
|
||||
e.ctx = ctx
|
||||
return e, err
|
||||
}
|
||||
|
||||
func startSignal() (*grpc.Server, string, error) {
|
||||
@@ -1042,7 +1049,7 @@ func startManagement(dataDir string) (*grpc.Server, string, error) {
|
||||
return nil, "", err
|
||||
}
|
||||
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
||||
store, err := server.NewStoreFromJson(config.Datadir, nil)
|
||||
store, _, err := server.NewTestStoreFromJson(config.Datadir)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
@@ -2,14 +2,20 @@ package networkmonitor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// NetworkWatcher watches for changes in network configuration.
|
||||
type NetworkWatcher struct {
|
||||
var ErrStopped = errors.New("monitor has been stopped")
|
||||
|
||||
// NetworkMonitor watches for changes in network configuration.
|
||||
type NetworkMonitor struct {
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// New creates a new network monitor.
|
||||
func New() *NetworkWatcher {
|
||||
return &NetworkWatcher{}
|
||||
func New() *NetworkMonitor {
|
||||
return &NetworkMonitor{}
|
||||
}
|
||||
|
||||
@@ -31,7 +31,7 @@ func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interfac
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
return ErrStopped
|
||||
default:
|
||||
buf := make([]byte, 2048)
|
||||
n, err := unix.Read(fd, buf)
|
||||
@@ -63,7 +63,7 @@ func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interfac
|
||||
}
|
||||
|
||||
log.Infof("Network monitor: monitored interface (%s) is down.", ifinfo.Name)
|
||||
callback()
|
||||
go callback()
|
||||
|
||||
// handle route changes
|
||||
case unix.RTM_ADD, syscall.RTM_DELETE:
|
||||
@@ -84,11 +84,11 @@ func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interfac
|
||||
switch msg.Type {
|
||||
case unix.RTM_ADD:
|
||||
log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf)
|
||||
callback()
|
||||
go callback()
|
||||
case unix.RTM_DELETE:
|
||||
if intfv4 != nil && route.Gw.Compare(nexthopv4) == 0 || intfv6 != nil && route.Gw.Compare(nexthopv6) == 0 {
|
||||
log.Infof("Network monitor: default route removed: via %s, interface %s", route.Gw, intf)
|
||||
callback()
|
||||
go callback()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ package networkmonitor
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime/debug"
|
||||
@@ -15,20 +16,18 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
)
|
||||
|
||||
// Start begins watching for network changes and calls the callback function and stops when a change is detected.
|
||||
func (nw *NetworkWatcher) Start(ctx context.Context, callback func()) {
|
||||
if nw.cancel != nil {
|
||||
log.Warn("Network monitor: already running, stopping previous watcher")
|
||||
nw.Stop()
|
||||
}
|
||||
|
||||
// Start begins monitoring network changes. When a change is detected, it calls the callback asynchronously and returns.
|
||||
func (nw *NetworkMonitor) Start(ctx context.Context, callback func()) (err error) {
|
||||
if ctx.Err() != nil {
|
||||
log.Info("Network monitor: not starting, context is already cancelled")
|
||||
return
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
nw.mu.Lock()
|
||||
ctx, nw.cancel = context.WithCancel(ctx)
|
||||
defer nw.Stop()
|
||||
nw.mu.Unlock()
|
||||
|
||||
nw.wg.Add(1)
|
||||
defer nw.wg.Done()
|
||||
|
||||
var nexthop4, nexthop6 netip.Addr
|
||||
var intf4, intf6 *net.Interface
|
||||
@@ -56,27 +55,30 @@ func (nw *NetworkWatcher) Start(ctx context.Context, callback func()) {
|
||||
expBackOff := backoff.WithContext(backoff.NewExponentialBackOff(), ctx)
|
||||
|
||||
if err := backoff.Retry(operation, expBackOff); err != nil {
|
||||
log.Errorf("Network monitor: failed to get default next hops: %v", err)
|
||||
return
|
||||
return fmt.Errorf("failed to get default next hops: %w", err)
|
||||
}
|
||||
|
||||
// recover in case sys ops panic
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Errorf("Network monitor: panic occurred: %v, stack trace: %s", r, string(debug.Stack()))
|
||||
err = fmt.Errorf("panic occurred: %v, stack trace: %s", r, string(debug.Stack()))
|
||||
}
|
||||
}()
|
||||
|
||||
if err := checkChange(ctx, nexthop4, intf4, nexthop6, intf6, callback); err != nil && !errors.Is(err, context.Canceled) {
|
||||
log.Errorf("Network monitor: failed to start: %v", err)
|
||||
if err := checkChange(ctx, nexthop4, intf4, nexthop6, intf6, callback); err != nil {
|
||||
return fmt.Errorf("check change: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop stops the network monitor.
|
||||
func (nw *NetworkWatcher) Stop() {
|
||||
func (nw *NetworkMonitor) Stop() {
|
||||
nw.mu.Lock()
|
||||
defer nw.mu.Unlock()
|
||||
|
||||
if nw.cancel != nil {
|
||||
nw.cancel()
|
||||
nw.cancel = nil
|
||||
log.Info("Network monitor: stopped")
|
||||
nw.wg.Wait()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -36,7 +36,7 @@ func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interfac
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
return ErrStopped
|
||||
|
||||
// handle interface state changes
|
||||
case update := <-linkChan:
|
||||
@@ -47,12 +47,12 @@ func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interfac
|
||||
switch update.Header.Type {
|
||||
case syscall.RTM_DELLINK:
|
||||
log.Infof("Network monitor: monitored interface (%s) is gone", update.Link.Attrs().Name)
|
||||
callback()
|
||||
go callback()
|
||||
return nil
|
||||
case syscall.RTM_NEWLINK:
|
||||
if (update.IfInfomsg.Flags&syscall.IFF_RUNNING) == 0 && update.Link.Attrs().OperState == netlink.OperDown {
|
||||
log.Infof("Network monitor: monitored interface (%s) is down.", update.Link.Attrs().Name)
|
||||
callback()
|
||||
go callback()
|
||||
return nil
|
||||
}
|
||||
}
|
||||
@@ -67,12 +67,12 @@ func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interfac
|
||||
// triggered on added/replaced routes
|
||||
case syscall.RTM_NEWROUTE:
|
||||
log.Infof("Network monitor: default route changed: via %s, interface %d", route.Gw, route.LinkIndex)
|
||||
callback()
|
||||
go callback()
|
||||
return nil
|
||||
case syscall.RTM_DELROUTE:
|
||||
if intfv4 != nil && route.Gw.Equal(nexthopv4.AsSlice()) || intfv6 != nil && route.Gw.Equal(nexthop6.AsSlice()) {
|
||||
log.Infof("Network monitor: default route removed: via %s, interface %d", route.Gw, route.LinkIndex)
|
||||
callback()
|
||||
go callback()
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,8 +4,9 @@ package networkmonitor
|
||||
|
||||
import "context"
|
||||
|
||||
func (nw *NetworkWatcher) Start(context.Context, func()) {
|
||||
func (nw *NetworkMonitor) Start(context.Context, func()) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (nw *NetworkWatcher) Stop() {
|
||||
func (nw *NetworkMonitor) Stop() {
|
||||
}
|
||||
|
||||
@@ -48,10 +48,10 @@ func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interfac
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
return ErrStopped
|
||||
case <-ticker.C:
|
||||
if changed(nexthopv4, intfv4, neighborv4, nexthopv6, intfv6, neighborv6) {
|
||||
callback()
|
||||
go callback()
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -18,7 +19,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/wgproxy"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/iface/bind"
|
||||
signal "github.com/netbirdio/netbird/signal/client"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
sProto "github.com/netbirdio/netbird/signal/proto"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
@@ -68,9 +69,6 @@ type ConnConfig struct {
|
||||
|
||||
NATExternalIPs []string
|
||||
|
||||
// UsesBind indicates whether the WireGuard interface is userspace and uses bind.ICEBind
|
||||
UserspaceBind bool
|
||||
|
||||
// RosenpassPubKey is this peer's Rosenpass public key
|
||||
RosenpassPubKey []byte
|
||||
// RosenpassPubKey is this peer's RosenpassAddr server address (IP:port)
|
||||
@@ -133,32 +131,15 @@ type Conn struct {
|
||||
wgProxyFactory *wgproxy.Factory
|
||||
wgProxy wgproxy.Proxy
|
||||
|
||||
remoteModeCh chan ModeMessage
|
||||
meta meta
|
||||
|
||||
adapter iface.TunAdapter
|
||||
iFaceDiscover stdnet.ExternalIFaceDiscover
|
||||
sentExtraSrflx bool
|
||||
|
||||
remoteEndpoint *net.UDPAddr
|
||||
remoteConn *ice.Conn
|
||||
|
||||
connID nbnet.ConnectionID
|
||||
beforeAddPeerHooks []BeforeAddPeerHookFunc
|
||||
afterRemovePeerHooks []AfterRemovePeerHookFunc
|
||||
}
|
||||
|
||||
// meta holds meta information about a connection
|
||||
type meta struct {
|
||||
protoSupport signal.FeaturesSupport
|
||||
}
|
||||
|
||||
// ModeMessage represents a connection mode chosen by the peer
|
||||
type ModeMessage struct {
|
||||
// Direct indicates that it decided to use a direct connection
|
||||
Direct bool
|
||||
}
|
||||
|
||||
// GetConf returns the connection config
|
||||
func (conn *Conn) GetConf() ConnConfig {
|
||||
return conn.config
|
||||
@@ -185,7 +166,6 @@ func NewConn(config ConnConfig, statusRecorder *Status, wgProxyFactory *wgproxy.
|
||||
remoteOffersCh: make(chan OfferAnswer),
|
||||
remoteAnswerCh: make(chan OfferAnswer),
|
||||
statusRecorder: statusRecorder,
|
||||
remoteModeCh: make(chan ModeMessage, 1),
|
||||
wgProxyFactory: wgProxyFactory,
|
||||
adapter: adapter,
|
||||
iFaceDiscover: iFaceDiscover,
|
||||
@@ -353,7 +333,7 @@ func (conn *Conn) Open(ctx context.Context) error {
|
||||
|
||||
err = conn.agent.GatherCandidates()
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("gather candidates: %v", err)
|
||||
}
|
||||
|
||||
// will block until connection succeeded
|
||||
@@ -370,14 +350,12 @@ func (conn *Conn) Open(ctx context.Context) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// dynamically set remote WireGuard port is other side specified a different one from the default one
|
||||
// dynamically set remote WireGuard port if other side specified a different one from the default one
|
||||
remoteWgPort := iface.DefaultWgPort
|
||||
if remoteOfferAnswer.WgListenPort != 0 {
|
||||
remoteWgPort = remoteOfferAnswer.WgListenPort
|
||||
}
|
||||
|
||||
conn.remoteConn = remoteConn
|
||||
|
||||
// the ice connection has been established successfully so we are ready to start the proxy
|
||||
remoteAddr, err := conn.configureConnection(remoteConn, remoteWgPort, remoteOfferAnswer.RosenpassPubKey,
|
||||
remoteOfferAnswer.RosenpassAddr)
|
||||
@@ -435,7 +413,6 @@ func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, rem
|
||||
}
|
||||
|
||||
endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String())
|
||||
conn.remoteEndpoint = endpointUdpAddr
|
||||
log.Debugf("Conn resolved IP for %s: %s", endpoint, endpointUdpAddr.IP)
|
||||
|
||||
conn.connID = nbnet.GenerateConnID()
|
||||
@@ -487,6 +464,10 @@ func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, rem
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if runtime.GOOS == "ios" {
|
||||
runtime.GC()
|
||||
}
|
||||
|
||||
if conn.onConnected != nil {
|
||||
conn.onConnected(conn.config.Key, remoteRosenpassPubKey, ipNet.IP.String(), remoteRosenpassAddr)
|
||||
}
|
||||
@@ -617,40 +598,39 @@ func (conn *Conn) SetSendSignalMessage(handler func(message *sProto.Message) err
|
||||
// onICECandidate is a callback attached to an ICE Agent to receive new local connection candidates
|
||||
// and then signals them to the remote peer
|
||||
func (conn *Conn) onICECandidate(candidate ice.Candidate) {
|
||||
if candidate != nil {
|
||||
// TODO: reported port is incorrect for CandidateTypeHost, makes understanding ICE use via logs confusing as port is ignored
|
||||
log.Debugf("discovered local candidate %s", candidate.String())
|
||||
go func() {
|
||||
err := conn.signalCandidate(candidate)
|
||||
if err != nil {
|
||||
log.Errorf("failed signaling candidate to the remote peer %s %s", conn.config.Key, err)
|
||||
}
|
||||
|
||||
// sends an extra server reflexive candidate to the remote peer with our related port (usually the wireguard port)
|
||||
// this is useful when network has an existing port forwarding rule for the wireguard port and this peer
|
||||
if !conn.sentExtraSrflx && candidate.Type() == ice.CandidateTypeServerReflexive && candidate.Port() != candidate.RelatedAddress().Port {
|
||||
relatedAdd := candidate.RelatedAddress()
|
||||
extraSrflx, err := ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{
|
||||
Network: candidate.NetworkType().String(),
|
||||
Address: candidate.Address(),
|
||||
Port: relatedAdd.Port,
|
||||
Component: candidate.Component(),
|
||||
RelAddr: relatedAdd.Address,
|
||||
RelPort: relatedAdd.Port,
|
||||
})
|
||||
if err != nil {
|
||||
log.Errorf("failed creating extra server reflexive candidate %s", err)
|
||||
return
|
||||
}
|
||||
err = conn.signalCandidate(extraSrflx)
|
||||
if err != nil {
|
||||
log.Errorf("failed signaling the extra server reflexive candidate to the remote peer %s: %s", conn.config.Key, err)
|
||||
return
|
||||
}
|
||||
conn.sentExtraSrflx = true
|
||||
}
|
||||
}()
|
||||
// nil means candidate gathering has been ended
|
||||
if candidate == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// TODO: reported port is incorrect for CandidateTypeHost, makes understanding ICE use via logs confusing as port is ignored
|
||||
log.Debugf("discovered local candidate %s", candidate.String())
|
||||
go func() {
|
||||
err := conn.signalCandidate(candidate)
|
||||
if err != nil {
|
||||
log.Errorf("failed signaling candidate to the remote peer %s %s", conn.config.Key, err)
|
||||
}
|
||||
}()
|
||||
|
||||
if !conn.shouldSendExtraSrflxCandidate(candidate) {
|
||||
return
|
||||
}
|
||||
|
||||
// sends an extra server reflexive candidate to the remote peer with our related port (usually the wireguard port)
|
||||
// this is useful when network has an existing port forwarding rule for the wireguard port and this peer
|
||||
extraSrflx, err := extraSrflxCandidate(candidate)
|
||||
if err != nil {
|
||||
log.Errorf("failed creating extra server reflexive candidate %s", err)
|
||||
return
|
||||
}
|
||||
conn.sentExtraSrflx = true
|
||||
|
||||
go func() {
|
||||
err = conn.signalCandidate(extraSrflx)
|
||||
if err != nil {
|
||||
log.Errorf("failed signaling the extra server reflexive candidate to the remote peer %s: %s", conn.config.Key, err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (conn *Conn) onICESelectedCandidatePair(c1 ice.Candidate, c2 ice.Candidate) {
|
||||
@@ -775,7 +755,7 @@ func (conn *Conn) OnRemoteAnswer(answer OfferAnswer) bool {
|
||||
}
|
||||
|
||||
// OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer.
|
||||
func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate) {
|
||||
func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMap) {
|
||||
log.Debugf("OnRemoteCandidate from peer %s -> %s", conn.config.Key, candidate.String())
|
||||
go func() {
|
||||
conn.mu.Lock()
|
||||
@@ -785,6 +765,10 @@ func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate) {
|
||||
return
|
||||
}
|
||||
|
||||
if candidateViaRoutes(candidate, haRoutes) {
|
||||
return
|
||||
}
|
||||
|
||||
err := conn.agent.AddRemoteCandidate(candidate)
|
||||
if err != nil {
|
||||
log.Errorf("error while handling remote candidate from peer %s", conn.config.Key)
|
||||
@@ -797,8 +781,49 @@ func (conn *Conn) GetKey() string {
|
||||
return conn.config.Key
|
||||
}
|
||||
|
||||
// RegisterProtoSupportMeta register supported proto message in the connection metadata
|
||||
func (conn *Conn) RegisterProtoSupportMeta(support []uint32) {
|
||||
protoSupport := signal.ParseFeaturesSupported(support)
|
||||
conn.meta.protoSupport = protoSupport
|
||||
func (conn *Conn) shouldSendExtraSrflxCandidate(candidate ice.Candidate) bool {
|
||||
if !conn.sentExtraSrflx && candidate.Type() == ice.CandidateTypeServerReflexive && candidate.Port() != candidate.RelatedAddress().Port {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive, error) {
|
||||
relatedAdd := candidate.RelatedAddress()
|
||||
return ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{
|
||||
Network: candidate.NetworkType().String(),
|
||||
Address: candidate.Address(),
|
||||
Port: relatedAdd.Port,
|
||||
Component: candidate.Component(),
|
||||
RelAddr: relatedAdd.Address,
|
||||
RelPort: relatedAdd.Port,
|
||||
})
|
||||
}
|
||||
|
||||
func candidateViaRoutes(candidate ice.Candidate, clientRoutes route.HAMap) bool {
|
||||
var routePrefixes []netip.Prefix
|
||||
for _, routes := range clientRoutes {
|
||||
if len(routes) > 0 && routes[0] != nil {
|
||||
routePrefixes = append(routePrefixes, routes[0].Network)
|
||||
}
|
||||
}
|
||||
|
||||
addr, err := netip.ParseAddr(candidate.Address())
|
||||
if err != nil {
|
||||
log.Errorf("Failed to parse IP address %s: %v", candidate.Address(), err)
|
||||
return false
|
||||
}
|
||||
|
||||
for _, prefix := range routePrefixes {
|
||||
// default route is
|
||||
if prefix.Bits() == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if prefix.Contains(addr) {
|
||||
log.Debugf("Ignoring candidate [%s], its address is part of routed network %s", candidate.String(), prefix)
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -170,7 +170,7 @@ func ProbeAll(
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i, uri := range relays {
|
||||
ctx, cancel := context.WithTimeout(ctx, 1*time.Second)
|
||||
ctx, cancel := context.WithTimeout(ctx, 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
wg.Add(1)
|
||||
|
||||
@@ -43,11 +43,6 @@ func routeCmd(action string, prefix netip.Prefix, nexthop netip.Addr, intf *net.
|
||||
}
|
||||
if prefix.Addr().Is6() {
|
||||
inet = "-inet6"
|
||||
// Special case for IPv6 split default route, pointing to the wg interface fails
|
||||
// TODO: Remove once we have IPv6 support on the interface
|
||||
if prefix.Bits() == 1 {
|
||||
intf = &net.Interface{Name: "lo0"}
|
||||
}
|
||||
}
|
||||
|
||||
args := []string{"-n", action, inet, network}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1"/>
|
||||
<meta content="width=device-width, initial-scale=1" name="viewport"/>
|
||||
<style>
|
||||
body {
|
||||
display: flex;
|
||||
@@ -50,16 +50,17 @@
|
||||
color: black;
|
||||
}
|
||||
</style>
|
||||
<title>NetBird Login Successful</title>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<div class="logo">
|
||||
<img src="https://img.mailinblue.com/6211297/images/content_library/original/64bd4ce82e1ea753e439b6a2.png">
|
||||
<img alt="netbird_logo" src="https://img.mailinblue.com/6211297/images/content_library/original/64bd4ce82e1ea753e439b6a2.png">
|
||||
</div>
|
||||
<br>
|
||||
{{ if .Error }}
|
||||
<svg xmlns="http://www.w3.org/2000/svg" height="50" viewBox="0 0 100 100">
|
||||
<circle cx="50" cy="50" r="45" fill="none" stroke="red" stroke-width="3"/>
|
||||
<svg height="50" viewBox="0 0 100 100" xmlns="http://www.w3.org/2000/svg">
|
||||
<circle cx="50" cy="50" fill="none" r="45" stroke="red" stroke-width="3"/>
|
||||
<path d="M30 30 L70 70 M30 70 L70 30" fill="none" stroke="red" stroke-width="3"/>
|
||||
</svg>
|
||||
<div class="content">
|
||||
@@ -69,8 +70,8 @@
|
||||
{{ .Error }}.
|
||||
</div>
|
||||
{{ else }}
|
||||
<svg xmlns="http://www.w3.org/2000/svg" height="50" viewBox="0 0 100 100">
|
||||
<circle cx="50" cy="50" r="45" fill="none" stroke="#5cb85c" stroke-width="3"/>
|
||||
<svg height="50" viewBox="0 0 100 100" xmlns="http://www.w3.org/2000/svg">
|
||||
<circle cx="50" cy="50" fill="none" r="45" stroke="#5cb85c" stroke-width="3"/>
|
||||
<path d="M30 50 L45 65 L70 35" fill="none" stroke="#5cb85c" stroke-width="5"/>
|
||||
</svg>
|
||||
<div class="content">
|
||||
|
||||
@@ -109,7 +109,6 @@ func (p *WGEBPFProxy) AddTurnConn(turnConn net.Conn) (net.Addr, error) {
|
||||
|
||||
// CloseConn doing nothing because this type of proxy implementation does not store the connection
|
||||
func (p *WGEBPFProxy) CloseConn() error {
|
||||
p.cancel()
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -1806,6 +1806,91 @@ func (x *DebugBundleResponse) GetPath() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
type GetLogLevelRequest struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
}
|
||||
|
||||
func (x *GetLogLevelRequest) Reset() {
|
||||
*x = GetLogLevelRequest{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_daemon_proto_msgTypes[26]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
}
|
||||
|
||||
func (x *GetLogLevelRequest) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*GetLogLevelRequest) ProtoMessage() {}
|
||||
|
||||
func (x *GetLogLevelRequest) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[26]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use GetLogLevelRequest.ProtoReflect.Descriptor instead.
|
||||
func (*GetLogLevelRequest) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{26}
|
||||
}
|
||||
|
||||
type GetLogLevelResponse struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
Level LogLevel `protobuf:"varint,1,opt,name=level,proto3,enum=daemon.LogLevel" json:"level,omitempty"`
|
||||
}
|
||||
|
||||
func (x *GetLogLevelResponse) Reset() {
|
||||
*x = GetLogLevelResponse{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_daemon_proto_msgTypes[27]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
}
|
||||
|
||||
func (x *GetLogLevelResponse) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*GetLogLevelResponse) ProtoMessage() {}
|
||||
|
||||
func (x *GetLogLevelResponse) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[27]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use GetLogLevelResponse.ProtoReflect.Descriptor instead.
|
||||
func (*GetLogLevelResponse) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{27}
|
||||
}
|
||||
|
||||
func (x *GetLogLevelResponse) GetLevel() LogLevel {
|
||||
if x != nil {
|
||||
return x.Level
|
||||
}
|
||||
return LogLevel_UNKNOWN
|
||||
}
|
||||
|
||||
type SetLogLevelRequest struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
@@ -1817,7 +1902,7 @@ type SetLogLevelRequest struct {
|
||||
func (x *SetLogLevelRequest) Reset() {
|
||||
*x = SetLogLevelRequest{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_daemon_proto_msgTypes[26]
|
||||
mi := &file_daemon_proto_msgTypes[28]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
@@ -1830,7 +1915,7 @@ func (x *SetLogLevelRequest) String() string {
|
||||
func (*SetLogLevelRequest) ProtoMessage() {}
|
||||
|
||||
func (x *SetLogLevelRequest) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[26]
|
||||
mi := &file_daemon_proto_msgTypes[28]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
@@ -1843,7 +1928,7 @@ func (x *SetLogLevelRequest) ProtoReflect() protoreflect.Message {
|
||||
|
||||
// Deprecated: Use SetLogLevelRequest.ProtoReflect.Descriptor instead.
|
||||
func (*SetLogLevelRequest) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{26}
|
||||
return file_daemon_proto_rawDescGZIP(), []int{28}
|
||||
}
|
||||
|
||||
func (x *SetLogLevelRequest) GetLevel() LogLevel {
|
||||
@@ -1862,7 +1947,7 @@ type SetLogLevelResponse struct {
|
||||
func (x *SetLogLevelResponse) Reset() {
|
||||
*x = SetLogLevelResponse{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_daemon_proto_msgTypes[27]
|
||||
mi := &file_daemon_proto_msgTypes[29]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
@@ -1875,7 +1960,7 @@ func (x *SetLogLevelResponse) String() string {
|
||||
func (*SetLogLevelResponse) ProtoMessage() {}
|
||||
|
||||
func (x *SetLogLevelResponse) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[27]
|
||||
mi := &file_daemon_proto_msgTypes[29]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
@@ -1888,7 +1973,7 @@ func (x *SetLogLevelResponse) ProtoReflect() protoreflect.Message {
|
||||
|
||||
// Deprecated: Use SetLogLevelResponse.ProtoReflect.Descriptor instead.
|
||||
func (*SetLogLevelResponse) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{27}
|
||||
return file_daemon_proto_rawDescGZIP(), []int{29}
|
||||
}
|
||||
|
||||
var File_daemon_proto protoreflect.FileDescriptor
|
||||
@@ -2138,67 +2223,77 @@ var file_daemon_proto_rawDesc = []byte{
|
||||
0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x22,
|
||||
0x29, 0x0a, 0x13, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65,
|
||||
0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x70, 0x61, 0x74, 0x68, 0x18, 0x01,
|
||||
0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x70, 0x61, 0x74, 0x68, 0x22, 0x3c, 0x0a, 0x12, 0x53, 0x65,
|
||||
0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x70, 0x61, 0x74, 0x68, 0x22, 0x14, 0x0a, 0x12, 0x47, 0x65,
|
||||
0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74,
|
||||
0x12, 0x26, 0x0a, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32,
|
||||
0x10, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65,
|
||||
0x6c, 0x52, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x22, 0x15, 0x0a, 0x13, 0x53, 0x65, 0x74, 0x4c,
|
||||
0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x2a,
|
||||
0x62, 0x0a, 0x08, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x0b, 0x0a, 0x07, 0x55,
|
||||
0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x09, 0x0a, 0x05, 0x50, 0x41, 0x4e, 0x49,
|
||||
0x43, 0x10, 0x01, 0x12, 0x09, 0x0a, 0x05, 0x46, 0x41, 0x54, 0x41, 0x4c, 0x10, 0x02, 0x12, 0x09,
|
||||
0x0a, 0x05, 0x45, 0x52, 0x52, 0x4f, 0x52, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x57, 0x41, 0x52,
|
||||
0x4e, 0x10, 0x04, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x4e, 0x46, 0x4f, 0x10, 0x05, 0x12, 0x09, 0x0a,
|
||||
0x05, 0x44, 0x45, 0x42, 0x55, 0x47, 0x10, 0x06, 0x12, 0x09, 0x0a, 0x05, 0x54, 0x52, 0x41, 0x43,
|
||||
0x45, 0x10, 0x07, 0x32, 0xee, 0x05, 0x0a, 0x0d, 0x44, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x53, 0x65,
|
||||
0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x36, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x14,
|
||||
0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71,
|
||||
0x75, 0x65, 0x73, 0x74, 0x1a, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f,
|
||||
0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a,
|
||||
0x0c, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1b, 0x2e,
|
||||
0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f,
|
||||
0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65,
|
||||
0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e,
|
||||
0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x2d, 0x0a, 0x02, 0x55, 0x70,
|
||||
0x12, 0x11, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x71, 0x75,
|
||||
0x65, 0x73, 0x74, 0x1a, 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52,
|
||||
0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x39, 0x0a, 0x06, 0x53, 0x74, 0x61,
|
||||
0x74, 0x75, 0x73, 0x12, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61,
|
||||
0x74, 0x75, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x64, 0x61, 0x65,
|
||||
0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e,
|
||||
0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, 0x04, 0x44, 0x6f, 0x77, 0x6e, 0x12, 0x13, 0x2e, 0x64,
|
||||
0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73,
|
||||
0x74, 0x1a, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52,
|
||||
0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x42, 0x0a, 0x09, 0x47, 0x65, 0x74,
|
||||
0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x18, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e,
|
||||
0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74,
|
||||
0x1a, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e,
|
||||
0x66, 0x69, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x45, 0x0a,
|
||||
0x0a, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x19, 0x2e, 0x64, 0x61,
|
||||
0x22, 0x3d, 0x0a, 0x13, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52,
|
||||
0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x26, 0x0a, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c,
|
||||
0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x10, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e,
|
||||
0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x22,
|
||||
0x3c, 0x0a, 0x12, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65,
|
||||
0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x26, 0x0a, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x18, 0x01,
|
||||
0x20, 0x01, 0x28, 0x0e, 0x32, 0x10, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f,
|
||||
0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x22, 0x15, 0x0a,
|
||||
0x13, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70,
|
||||
0x6f, 0x6e, 0x73, 0x65, 0x2a, 0x62, 0x0a, 0x08, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c,
|
||||
0x12, 0x0b, 0x0a, 0x07, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x09, 0x0a,
|
||||
0x05, 0x50, 0x41, 0x4e, 0x49, 0x43, 0x10, 0x01, 0x12, 0x09, 0x0a, 0x05, 0x46, 0x41, 0x54, 0x41,
|
||||
0x4c, 0x10, 0x02, 0x12, 0x09, 0x0a, 0x05, 0x45, 0x52, 0x52, 0x4f, 0x52, 0x10, 0x03, 0x12, 0x08,
|
||||
0x0a, 0x04, 0x57, 0x41, 0x52, 0x4e, 0x10, 0x04, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x4e, 0x46, 0x4f,
|
||||
0x10, 0x05, 0x12, 0x09, 0x0a, 0x05, 0x44, 0x45, 0x42, 0x55, 0x47, 0x10, 0x06, 0x12, 0x09, 0x0a,
|
||||
0x05, 0x54, 0x52, 0x41, 0x43, 0x45, 0x10, 0x07, 0x32, 0xb8, 0x06, 0x0a, 0x0d, 0x44, 0x61, 0x65,
|
||||
0x6d, 0x6f, 0x6e, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x36, 0x0a, 0x05, 0x4c, 0x6f,
|
||||
0x67, 0x69, 0x6e, 0x12, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67,
|
||||
0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d,
|
||||
0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65,
|
||||
0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67,
|
||||
0x69, 0x6e, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74,
|
||||
0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a,
|
||||
0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f,
|
||||
0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12,
|
||||
0x2d, 0x0a, 0x02, 0x55, 0x70, 0x12, 0x11, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55,
|
||||
0x70, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f,
|
||||
0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x39,
|
||||
0x0a, 0x06, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f,
|
||||
0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a,
|
||||
0x16, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52,
|
||||
0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, 0x04, 0x44, 0x6f, 0x77,
|
||||
0x6e, 0x12, 0x13, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52,
|
||||
0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e,
|
||||
0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x42,
|
||||
0x0a, 0x09, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x18, 0x2e, 0x64, 0x61,
|
||||
0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65,
|
||||
0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47,
|
||||
0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65,
|
||||
0x22, 0x00, 0x12, 0x45, 0x0a, 0x0a, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73,
|
||||
0x12, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f,
|
||||
0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e, 0x64, 0x61,
|
||||
0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52,
|
||||
0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e,
|
||||
0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e,
|
||||
0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f,
|
||||
0x75, 0x74, 0x65, 0x73, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65,
|
||||
0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73,
|
||||
0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63,
|
||||
0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22,
|
||||
0x00, 0x12, 0x4d, 0x0a, 0x0e, 0x44, 0x65, 0x73, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75,
|
||||
0x74, 0x65, 0x73, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c,
|
||||
0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74,
|
||||
0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74,
|
||||
0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00,
|
||||
0x12, 0x48, 0x0a, 0x0b, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x12,
|
||||
0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75,
|
||||
0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61,
|
||||
0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65,
|
||||
0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x53, 0x65,
|
||||
0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d,
|
||||
0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65,
|
||||
0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53,
|
||||
0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e,
|
||||
0x73, 0x65, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06,
|
||||
0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
|
||||
0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x53, 0x65, 0x6c,
|
||||
0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d,
|
||||
0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52,
|
||||
0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e,
|
||||
0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70,
|
||||
0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4d, 0x0a, 0x0e, 0x44, 0x65, 0x73, 0x65, 0x6c, 0x65,
|
||||
0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f,
|
||||
0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65,
|
||||
0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53,
|
||||
0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f,
|
||||
0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75,
|
||||
0x6e, 0x64, 0x6c, 0x65, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65,
|
||||
0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74,
|
||||
0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42,
|
||||
0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12,
|
||||
0x48, 0x0a, 0x0b, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a,
|
||||
0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65,
|
||||
0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65,
|
||||
0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52,
|
||||
0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x53, 0x65, 0x74,
|
||||
0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f,
|
||||
0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71,
|
||||
0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65,
|
||||
0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73,
|
||||
0x65, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70,
|
||||
0x72, 0x6f, 0x74, 0x6f, 0x33,
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -2214,7 +2309,7 @@ func file_daemon_proto_rawDescGZIP() []byte {
|
||||
}
|
||||
|
||||
var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 1)
|
||||
var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 28)
|
||||
var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 30)
|
||||
var file_daemon_proto_goTypes = []interface{}{
|
||||
(LogLevel)(0), // 0: daemon.LogLevel
|
||||
(*LoginRequest)(nil), // 1: daemon.LoginRequest
|
||||
@@ -2243,16 +2338,18 @@ var file_daemon_proto_goTypes = []interface{}{
|
||||
(*Route)(nil), // 24: daemon.Route
|
||||
(*DebugBundleRequest)(nil), // 25: daemon.DebugBundleRequest
|
||||
(*DebugBundleResponse)(nil), // 26: daemon.DebugBundleResponse
|
||||
(*SetLogLevelRequest)(nil), // 27: daemon.SetLogLevelRequest
|
||||
(*SetLogLevelResponse)(nil), // 28: daemon.SetLogLevelResponse
|
||||
(*timestamp.Timestamp)(nil), // 29: google.protobuf.Timestamp
|
||||
(*duration.Duration)(nil), // 30: google.protobuf.Duration
|
||||
(*GetLogLevelRequest)(nil), // 27: daemon.GetLogLevelRequest
|
||||
(*GetLogLevelResponse)(nil), // 28: daemon.GetLogLevelResponse
|
||||
(*SetLogLevelRequest)(nil), // 29: daemon.SetLogLevelRequest
|
||||
(*SetLogLevelResponse)(nil), // 30: daemon.SetLogLevelResponse
|
||||
(*timestamp.Timestamp)(nil), // 31: google.protobuf.Timestamp
|
||||
(*duration.Duration)(nil), // 32: google.protobuf.Duration
|
||||
}
|
||||
var file_daemon_proto_depIdxs = []int32{
|
||||
19, // 0: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus
|
||||
29, // 1: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp
|
||||
29, // 2: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp
|
||||
30, // 3: daemon.PeerState.latency:type_name -> google.protobuf.Duration
|
||||
31, // 1: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp
|
||||
31, // 2: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp
|
||||
32, // 3: daemon.PeerState.latency:type_name -> google.protobuf.Duration
|
||||
16, // 4: daemon.FullStatus.managementState:type_name -> daemon.ManagementState
|
||||
15, // 5: daemon.FullStatus.signalState:type_name -> daemon.SignalState
|
||||
14, // 6: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState
|
||||
@@ -2260,34 +2357,37 @@ var file_daemon_proto_depIdxs = []int32{
|
||||
17, // 8: daemon.FullStatus.relays:type_name -> daemon.RelayState
|
||||
18, // 9: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState
|
||||
24, // 10: daemon.ListRoutesResponse.routes:type_name -> daemon.Route
|
||||
0, // 11: daemon.SetLogLevelRequest.level:type_name -> daemon.LogLevel
|
||||
1, // 12: daemon.DaemonService.Login:input_type -> daemon.LoginRequest
|
||||
3, // 13: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest
|
||||
5, // 14: daemon.DaemonService.Up:input_type -> daemon.UpRequest
|
||||
7, // 15: daemon.DaemonService.Status:input_type -> daemon.StatusRequest
|
||||
9, // 16: daemon.DaemonService.Down:input_type -> daemon.DownRequest
|
||||
11, // 17: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest
|
||||
20, // 18: daemon.DaemonService.ListRoutes:input_type -> daemon.ListRoutesRequest
|
||||
22, // 19: daemon.DaemonService.SelectRoutes:input_type -> daemon.SelectRoutesRequest
|
||||
22, // 20: daemon.DaemonService.DeselectRoutes:input_type -> daemon.SelectRoutesRequest
|
||||
25, // 21: daemon.DaemonService.DebugBundle:input_type -> daemon.DebugBundleRequest
|
||||
27, // 22: daemon.DaemonService.SetLogLevel:input_type -> daemon.SetLogLevelRequest
|
||||
2, // 23: daemon.DaemonService.Login:output_type -> daemon.LoginResponse
|
||||
4, // 24: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse
|
||||
6, // 25: daemon.DaemonService.Up:output_type -> daemon.UpResponse
|
||||
8, // 26: daemon.DaemonService.Status:output_type -> daemon.StatusResponse
|
||||
10, // 27: daemon.DaemonService.Down:output_type -> daemon.DownResponse
|
||||
12, // 28: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse
|
||||
21, // 29: daemon.DaemonService.ListRoutes:output_type -> daemon.ListRoutesResponse
|
||||
23, // 30: daemon.DaemonService.SelectRoutes:output_type -> daemon.SelectRoutesResponse
|
||||
23, // 31: daemon.DaemonService.DeselectRoutes:output_type -> daemon.SelectRoutesResponse
|
||||
26, // 32: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse
|
||||
28, // 33: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse
|
||||
23, // [23:34] is the sub-list for method output_type
|
||||
12, // [12:23] is the sub-list for method input_type
|
||||
12, // [12:12] is the sub-list for extension type_name
|
||||
12, // [12:12] is the sub-list for extension extendee
|
||||
0, // [0:12] is the sub-list for field type_name
|
||||
0, // 11: daemon.GetLogLevelResponse.level:type_name -> daemon.LogLevel
|
||||
0, // 12: daemon.SetLogLevelRequest.level:type_name -> daemon.LogLevel
|
||||
1, // 13: daemon.DaemonService.Login:input_type -> daemon.LoginRequest
|
||||
3, // 14: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest
|
||||
5, // 15: daemon.DaemonService.Up:input_type -> daemon.UpRequest
|
||||
7, // 16: daemon.DaemonService.Status:input_type -> daemon.StatusRequest
|
||||
9, // 17: daemon.DaemonService.Down:input_type -> daemon.DownRequest
|
||||
11, // 18: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest
|
||||
20, // 19: daemon.DaemonService.ListRoutes:input_type -> daemon.ListRoutesRequest
|
||||
22, // 20: daemon.DaemonService.SelectRoutes:input_type -> daemon.SelectRoutesRequest
|
||||
22, // 21: daemon.DaemonService.DeselectRoutes:input_type -> daemon.SelectRoutesRequest
|
||||
25, // 22: daemon.DaemonService.DebugBundle:input_type -> daemon.DebugBundleRequest
|
||||
27, // 23: daemon.DaemonService.GetLogLevel:input_type -> daemon.GetLogLevelRequest
|
||||
29, // 24: daemon.DaemonService.SetLogLevel:input_type -> daemon.SetLogLevelRequest
|
||||
2, // 25: daemon.DaemonService.Login:output_type -> daemon.LoginResponse
|
||||
4, // 26: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse
|
||||
6, // 27: daemon.DaemonService.Up:output_type -> daemon.UpResponse
|
||||
8, // 28: daemon.DaemonService.Status:output_type -> daemon.StatusResponse
|
||||
10, // 29: daemon.DaemonService.Down:output_type -> daemon.DownResponse
|
||||
12, // 30: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse
|
||||
21, // 31: daemon.DaemonService.ListRoutes:output_type -> daemon.ListRoutesResponse
|
||||
23, // 32: daemon.DaemonService.SelectRoutes:output_type -> daemon.SelectRoutesResponse
|
||||
23, // 33: daemon.DaemonService.DeselectRoutes:output_type -> daemon.SelectRoutesResponse
|
||||
26, // 34: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse
|
||||
28, // 35: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse
|
||||
30, // 36: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse
|
||||
25, // [25:37] is the sub-list for method output_type
|
||||
13, // [13:25] is the sub-list for method input_type
|
||||
13, // [13:13] is the sub-list for extension type_name
|
||||
13, // [13:13] is the sub-list for extension extendee
|
||||
0, // [0:13] is the sub-list for field type_name
|
||||
}
|
||||
|
||||
func init() { file_daemon_proto_init() }
|
||||
@@ -2609,7 +2709,7 @@ func file_daemon_proto_init() {
|
||||
}
|
||||
}
|
||||
file_daemon_proto_msgTypes[26].Exporter = func(v interface{}, i int) interface{} {
|
||||
switch v := v.(*SetLogLevelRequest); i {
|
||||
switch v := v.(*GetLogLevelRequest); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
@@ -2621,6 +2721,30 @@ func file_daemon_proto_init() {
|
||||
}
|
||||
}
|
||||
file_daemon_proto_msgTypes[27].Exporter = func(v interface{}, i int) interface{} {
|
||||
switch v := v.(*GetLogLevelResponse); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_daemon_proto_msgTypes[28].Exporter = func(v interface{}, i int) interface{} {
|
||||
switch v := v.(*SetLogLevelRequest); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_daemon_proto_msgTypes[29].Exporter = func(v interface{}, i int) interface{} {
|
||||
switch v := v.(*SetLogLevelResponse); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
@@ -2640,7 +2764,7 @@ func file_daemon_proto_init() {
|
||||
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
||||
RawDescriptor: file_daemon_proto_rawDesc,
|
||||
NumEnums: 1,
|
||||
NumMessages: 28,
|
||||
NumMessages: 30,
|
||||
NumExtensions: 0,
|
||||
NumServices: 1,
|
||||
},
|
||||
|
||||
@@ -40,6 +40,9 @@ service DaemonService {
|
||||
// DebugBundle creates a debug bundle
|
||||
rpc DebugBundle(DebugBundleRequest) returns (DebugBundleResponse) {}
|
||||
|
||||
// GetLogLevel gets the log level of the daemon
|
||||
rpc GetLogLevel(GetLogLevelRequest) returns (GetLogLevelResponse) {}
|
||||
|
||||
// SetLogLevel sets the log level of the daemon
|
||||
rpc SetLogLevel(SetLogLevelRequest) returns (SetLogLevelResponse) {}
|
||||
};
|
||||
@@ -256,6 +259,13 @@ enum LogLevel {
|
||||
TRACE = 7;
|
||||
}
|
||||
|
||||
message GetLogLevelRequest {
|
||||
}
|
||||
|
||||
message GetLogLevelResponse {
|
||||
LogLevel level = 1;
|
||||
}
|
||||
|
||||
message SetLogLevelRequest {
|
||||
LogLevel level = 1;
|
||||
}
|
||||
|
||||
@@ -39,6 +39,8 @@ type DaemonServiceClient interface {
|
||||
DeselectRoutes(ctx context.Context, in *SelectRoutesRequest, opts ...grpc.CallOption) (*SelectRoutesResponse, error)
|
||||
// DebugBundle creates a debug bundle
|
||||
DebugBundle(ctx context.Context, in *DebugBundleRequest, opts ...grpc.CallOption) (*DebugBundleResponse, error)
|
||||
// GetLogLevel gets the log level of the daemon
|
||||
GetLogLevel(ctx context.Context, in *GetLogLevelRequest, opts ...grpc.CallOption) (*GetLogLevelResponse, error)
|
||||
// SetLogLevel sets the log level of the daemon
|
||||
SetLogLevel(ctx context.Context, in *SetLogLevelRequest, opts ...grpc.CallOption) (*SetLogLevelResponse, error)
|
||||
}
|
||||
@@ -141,6 +143,15 @@ func (c *daemonServiceClient) DebugBundle(ctx context.Context, in *DebugBundleRe
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *daemonServiceClient) GetLogLevel(ctx context.Context, in *GetLogLevelRequest, opts ...grpc.CallOption) (*GetLogLevelResponse, error) {
|
||||
out := new(GetLogLevelResponse)
|
||||
err := c.cc.Invoke(ctx, "/daemon.DaemonService/GetLogLevel", in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *daemonServiceClient) SetLogLevel(ctx context.Context, in *SetLogLevelRequest, opts ...grpc.CallOption) (*SetLogLevelResponse, error) {
|
||||
out := new(SetLogLevelResponse)
|
||||
err := c.cc.Invoke(ctx, "/daemon.DaemonService/SetLogLevel", in, out, opts...)
|
||||
@@ -175,6 +186,8 @@ type DaemonServiceServer interface {
|
||||
DeselectRoutes(context.Context, *SelectRoutesRequest) (*SelectRoutesResponse, error)
|
||||
// DebugBundle creates a debug bundle
|
||||
DebugBundle(context.Context, *DebugBundleRequest) (*DebugBundleResponse, error)
|
||||
// GetLogLevel gets the log level of the daemon
|
||||
GetLogLevel(context.Context, *GetLogLevelRequest) (*GetLogLevelResponse, error)
|
||||
// SetLogLevel sets the log level of the daemon
|
||||
SetLogLevel(context.Context, *SetLogLevelRequest) (*SetLogLevelResponse, error)
|
||||
mustEmbedUnimplementedDaemonServiceServer()
|
||||
@@ -214,6 +227,9 @@ func (UnimplementedDaemonServiceServer) DeselectRoutes(context.Context, *SelectR
|
||||
func (UnimplementedDaemonServiceServer) DebugBundle(context.Context, *DebugBundleRequest) (*DebugBundleResponse, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method DebugBundle not implemented")
|
||||
}
|
||||
func (UnimplementedDaemonServiceServer) GetLogLevel(context.Context, *GetLogLevelRequest) (*GetLogLevelResponse, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetLogLevel not implemented")
|
||||
}
|
||||
func (UnimplementedDaemonServiceServer) SetLogLevel(context.Context, *SetLogLevelRequest) (*SetLogLevelResponse, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method SetLogLevel not implemented")
|
||||
}
|
||||
@@ -410,6 +426,24 @@ func _DaemonService_DebugBundle_Handler(srv interface{}, ctx context.Context, de
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _DaemonService_GetLogLevel_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(GetLogLevelRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(DaemonServiceServer).GetLogLevel(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/daemon.DaemonService/GetLogLevel",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(DaemonServiceServer).GetLogLevel(ctx, req.(*GetLogLevelRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _DaemonService_SetLogLevel_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(SetLogLevelRequest)
|
||||
if err := dec(in); err != nil {
|
||||
@@ -475,6 +509,10 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{
|
||||
MethodName: "DebugBundle",
|
||||
Handler: _DaemonService_DebugBundle_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "GetLogLevel",
|
||||
Handler: _DaemonService_GetLogLevel_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "SetLogLevel",
|
||||
Handler: _DaemonService_SetLogLevel_Handler,
|
||||
|
||||
@@ -121,6 +121,12 @@ func (s *Server) anonymize(reader io.Reader, writer io.WriteCloser, errChan chan
|
||||
}
|
||||
}
|
||||
|
||||
// GetLogLevel gets the current logging level for the server.
|
||||
func (s *Server) GetLogLevel(_ context.Context, _ *proto.GetLogLevelRequest) (*proto.GetLogLevelResponse, error) {
|
||||
level := ParseLogLevel(log.GetLevel().String())
|
||||
return &proto.GetLogLevelResponse{Level: level}, nil
|
||||
}
|
||||
|
||||
// SetLogLevel sets the logging level for the server.
|
||||
func (s *Server) SetLogLevel(_ context.Context, req *proto.SetLogLevelRequest) (*proto.SetLogLevelResponse, error) {
|
||||
level, err := log.ParseLevel(req.Level.String())
|
||||
|
||||
28
client/server/log.go
Normal file
28
client/server/log.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
func ParseLogLevel(level string) proto.LogLevel {
|
||||
switch strings.ToLower(level) {
|
||||
case "panic":
|
||||
return proto.LogLevel_PANIC
|
||||
case "fatal":
|
||||
return proto.LogLevel_FATAL
|
||||
case "error":
|
||||
return proto.LogLevel_ERROR
|
||||
case "warn":
|
||||
return proto.LogLevel_WARN
|
||||
case "info":
|
||||
return proto.LogLevel_INFO
|
||||
case "debug":
|
||||
return proto.LogLevel_DEBUG
|
||||
case "trace":
|
||||
return proto.LogLevel_TRACE
|
||||
default:
|
||||
return proto.LogLevel_UNKNOWN
|
||||
}
|
||||
}
|
||||
@@ -36,7 +36,7 @@ const (
|
||||
maxRetryIntervalVar = "NB_CONN_MAX_RETRY_INTERVAL_TIME"
|
||||
maxRetryTimeVar = "NB_CONN_MAX_RETRY_TIME_TIME"
|
||||
retryMultiplierVar = "NB_CONN_RETRY_MULTIPLIER"
|
||||
defaultInitialRetryTime = 14 * 24 * time.Hour
|
||||
defaultInitialRetryTime = 30 * time.Minute
|
||||
defaultMaxRetryInterval = 60 * time.Minute
|
||||
defaultMaxRetryTime = 14 * 24 * time.Hour
|
||||
defaultRetryMultiplier = 1.7
|
||||
|
||||
@@ -106,10 +106,11 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
|
||||
return nil, "", err
|
||||
}
|
||||
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
||||
store, err := server.NewStoreFromJson(config.Datadir, nil)
|
||||
store, cleanUp, err := server.NewTestStoreFromJson(config.Datadir)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
t.Cleanup(cleanUp)
|
||||
|
||||
peersUpdateManager := server.NewPeersUpdateManager(nil)
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
|
||||
@@ -20,6 +20,9 @@ const OsVersionCtxKey = "OsVersion"
|
||||
// OsNameCtxKey context key for operating system name
|
||||
const OsNameCtxKey = "OsName"
|
||||
|
||||
// UiVersionCtxKey context key for user UI version
|
||||
const UiVersionCtxKey = "user-agent"
|
||||
|
||||
type NetworkAddress struct {
|
||||
NetIP netip.Prefix
|
||||
Mac string
|
||||
|
||||
@@ -28,10 +28,18 @@ func GetInfo(ctx context.Context) *Info {
|
||||
kernelVersion = osInfo[2]
|
||||
}
|
||||
|
||||
gio := &Info{Kernel: kernel, Platform: "unknown", OS: "android", OSVersion: osVersion(), GoOS: runtime.GOOS, CPUs: runtime.NumCPU(), KernelVersion: kernelVersion}
|
||||
gio.Hostname = extractDeviceName(ctx, "android")
|
||||
gio.WiretrusteeVersion = version.NetbirdVersion()
|
||||
gio.UIVersion = extractUserAgent(ctx)
|
||||
gio := &Info{
|
||||
GoOS: runtime.GOOS,
|
||||
Kernel: kernel,
|
||||
Platform: "unknown",
|
||||
OS: "android",
|
||||
OSVersion: osVersion(),
|
||||
Hostname: extractDeviceName(ctx, "android"),
|
||||
CPUs: runtime.NumCPU(),
|
||||
WiretrusteeVersion: version.NetbirdVersion(),
|
||||
UIVersion: extractUIVersion(ctx),
|
||||
KernelVersion: kernelVersion,
|
||||
}
|
||||
|
||||
return gio
|
||||
}
|
||||
@@ -45,6 +53,14 @@ func osVersion() string {
|
||||
return run("/system/bin/getprop", "ro.build.version.release")
|
||||
}
|
||||
|
||||
func extractUIVersion(ctx context.Context) string {
|
||||
v, ok := ctx.Value(UiVersionCtxKey).(string)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func run(name string, arg ...string) string {
|
||||
cmd := exec.Command(name, arg...)
|
||||
cmd.Stdin = strings.NewReader("some")
|
||||
|
||||
@@ -399,6 +399,7 @@ func (s *serviceClient) updateStatus() error {
|
||||
status, err := conn.Status(s.ctx, &proto.StatusRequest{})
|
||||
if err != nil {
|
||||
log.Errorf("get service status: %v", err)
|
||||
s.setDisconnectedStatus()
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -426,17 +427,7 @@ func (s *serviceClient) updateStatus() error {
|
||||
s.mRoutes.Enable()
|
||||
systrayIconState = true
|
||||
} else if status.Status != string(internal.StatusConnected) && s.mUp.Disabled() {
|
||||
s.connected = false
|
||||
if s.isUpdateIconActive {
|
||||
systray.SetIcon(s.icUpdateDisconnected)
|
||||
} else {
|
||||
systray.SetIcon(s.icDisconnected)
|
||||
}
|
||||
systray.SetTooltip("NetBird (Disconnected)")
|
||||
s.mStatus.SetTitle("Disconnected")
|
||||
s.mDown.Disable()
|
||||
s.mUp.Enable()
|
||||
s.mRoutes.Disable()
|
||||
s.setDisconnectedStatus()
|
||||
systrayIconState = false
|
||||
}
|
||||
|
||||
@@ -481,6 +472,20 @@ func (s *serviceClient) updateStatus() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *serviceClient) setDisconnectedStatus() {
|
||||
s.connected = false
|
||||
if s.isUpdateIconActive {
|
||||
systray.SetIcon(s.icUpdateDisconnected)
|
||||
} else {
|
||||
systray.SetIcon(s.icDisconnected)
|
||||
}
|
||||
systray.SetTooltip("NetBird (Disconnected)")
|
||||
s.mStatus.SetTitle("Disconnected")
|
||||
s.mDown.Disable()
|
||||
s.mUp.Enable()
|
||||
s.mRoutes.Disable()
|
||||
}
|
||||
|
||||
func (s *serviceClient) onTrayReady() {
|
||||
systray.SetIcon(s.icDisconnected)
|
||||
systray.SetTooltip("NetBird")
|
||||
|
||||
150
go.mod
150
go.mod
@@ -1,33 +1,31 @@
|
||||
module github.com/netbirdio/netbird
|
||||
|
||||
go 1.21
|
||||
|
||||
toolchain go1.21.0
|
||||
go 1.21.0
|
||||
|
||||
require (
|
||||
cunicu.li/go-rosenpass v0.4.0
|
||||
github.com/cenkalti/backoff/v4 v4.1.3
|
||||
github.com/cenkalti/backoff/v4 v4.3.0
|
||||
github.com/cloudflare/circl v1.3.3 // indirect
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible
|
||||
github.com/golang/protobuf v1.5.3
|
||||
github.com/golang/protobuf v1.5.4
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/gorilla/mux v1.8.0
|
||||
github.com/kardianos/service v1.2.1-0.20210728001519-a323c3813bc7
|
||||
github.com/onsi/ginkgo v1.16.5
|
||||
github.com/onsi/gomega v1.18.1
|
||||
github.com/onsi/gomega v1.27.6
|
||||
github.com/pion/ice/v3 v3.0.2
|
||||
github.com/rs/cors v1.8.0
|
||||
github.com/sirupsen/logrus v1.9.0
|
||||
github.com/sirupsen/logrus v1.9.3
|
||||
github.com/spf13/cobra v1.7.0
|
||||
github.com/spf13/pflag v1.0.5
|
||||
github.com/vishvananda/netlink v1.1.1-0.20211118161826-650dca95af54
|
||||
golang.org/x/crypto v0.21.0
|
||||
golang.org/x/sys v0.18.0
|
||||
github.com/vishvananda/netlink v1.2.1-beta.2
|
||||
golang.org/x/crypto v0.23.0
|
||||
golang.org/x/sys v0.20.0
|
||||
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1
|
||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
|
||||
golang.zx2c4.com/wireguard/windows v0.5.3
|
||||
google.golang.org/grpc v1.56.3
|
||||
google.golang.org/protobuf v1.31.0
|
||||
google.golang.org/grpc v1.64.0
|
||||
google.golang.org/protobuf v1.34.1
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.0.0
|
||||
)
|
||||
|
||||
@@ -35,7 +33,7 @@ require (
|
||||
fyne.io/fyne/v2 v2.1.4
|
||||
github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible
|
||||
github.com/c-robinson/iplib v1.0.3
|
||||
github.com/cilium/ebpf v0.11.0
|
||||
github.com/cilium/ebpf v0.15.0
|
||||
github.com/coreos/go-iptables v0.7.0
|
||||
github.com/creack/pty v1.1.18
|
||||
github.com/eko/gocache/v3 v3.1.1
|
||||
@@ -44,23 +42,24 @@ require (
|
||||
github.com/gliderlabs/ssh v0.3.4
|
||||
github.com/godbus/dbus/v5 v5.1.0
|
||||
github.com/golang/mock v1.6.0
|
||||
github.com/google/go-cmp v0.5.9
|
||||
github.com/google/go-cmp v0.6.0
|
||||
github.com/google/gopacket v1.1.19
|
||||
github.com/google/martian/v3 v3.0.0
|
||||
github.com/google/nftables v0.0.0-20220808154552-2eca00135732
|
||||
github.com/gopacket/gopacket v1.1.1
|
||||
github.com/gorilla/websocket v1.5.1
|
||||
github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357
|
||||
github.com/hashicorp/go-multierror v1.1.1
|
||||
github.com/hashicorp/go-secure-stdlib/base62 v0.1.2
|
||||
github.com/hashicorp/go-version v1.6.0
|
||||
github.com/libp2p/go-netroute v0.2.1
|
||||
github.com/magiconair/properties v1.8.5
|
||||
github.com/magiconair/properties v1.8.7
|
||||
github.com/mattn/go-sqlite3 v1.14.19
|
||||
github.com/mdlayher/socket v0.4.1
|
||||
github.com/miekg/dns v1.1.43
|
||||
github.com/mitchellh/hashstructure/v2 v2.0.2
|
||||
github.com/nadoo/ipset v0.5.0
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20240415094251-369eb33c9b01
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20240524104853-69c6d89826cd
|
||||
github.com/okta/okta-sdk-golang/v2 v2.18.0
|
||||
github.com/oschwald/maxminddb-golang v1.12.0
|
||||
github.com/patrickmn/go-cache v2.1.0+incompatible
|
||||
@@ -68,42 +67,60 @@ require (
|
||||
github.com/pion/stun/v2 v2.0.0
|
||||
github.com/pion/transport/v3 v3.0.1
|
||||
github.com/pion/turn/v3 v3.0.1
|
||||
github.com/prometheus/client_golang v1.14.0
|
||||
github.com/prometheus/client_golang v1.19.1
|
||||
github.com/quic-go/quic-go v0.45.0
|
||||
github.com/rs/xid v1.3.0
|
||||
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
|
||||
github.com/stretchr/testify v1.8.4
|
||||
github.com/stretchr/testify v1.9.0
|
||||
github.com/testcontainers/testcontainers-go v0.31.0
|
||||
github.com/testcontainers/testcontainers-go/modules/postgres v0.31.0
|
||||
github.com/things-go/go-socks5 v0.0.4
|
||||
github.com/yusufpapurcu/wmi v1.2.3
|
||||
github.com/yusufpapurcu/wmi v1.2.4
|
||||
github.com/zcalusic/sysinfo v1.0.2
|
||||
go.opentelemetry.io/otel v1.11.1
|
||||
go.opentelemetry.io/otel/exporters/prometheus v0.33.0
|
||||
go.opentelemetry.io/otel/metric v0.33.0
|
||||
go.opentelemetry.io/otel/sdk/metric v0.33.0
|
||||
go.opentelemetry.io/otel v1.26.0
|
||||
go.opentelemetry.io/otel/exporters/prometheus v0.48.0
|
||||
go.opentelemetry.io/otel/metric v1.26.0
|
||||
go.opentelemetry.io/otel/sdk/metric v1.26.0
|
||||
goauthentik.io/api/v3 v3.2023051.3
|
||||
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090
|
||||
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842
|
||||
golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028
|
||||
golang.org/x/net v0.23.0
|
||||
golang.org/x/oauth2 v0.8.0
|
||||
golang.org/x/sync v0.3.0
|
||||
golang.org/x/term v0.18.0
|
||||
google.golang.org/api v0.126.0
|
||||
golang.org/x/net v0.25.0
|
||||
golang.org/x/oauth2 v0.19.0
|
||||
golang.org/x/sync v0.7.0
|
||||
golang.org/x/term v0.20.0
|
||||
google.golang.org/api v0.177.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
gorm.io/driver/postgres v1.5.7
|
||||
gorm.io/driver/sqlite v1.5.3
|
||||
gorm.io/gorm v1.25.4
|
||||
gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde
|
||||
nhooyr.io/websocket v1.8.11
|
||||
)
|
||||
|
||||
require (
|
||||
cloud.google.com/go/compute v1.19.3 // indirect
|
||||
cloud.google.com/go/compute/metadata v0.2.3 // indirect
|
||||
github.com/BurntSushi/toml v1.2.1 // indirect
|
||||
cloud.google.com/go/auth v0.3.0 // indirect
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.2 // indirect
|
||||
cloud.google.com/go/compute/metadata v0.3.0 // indirect
|
||||
dario.cat/mergo v1.0.0 // indirect
|
||||
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect
|
||||
github.com/BurntSushi/toml v1.3.2 // indirect
|
||||
github.com/Microsoft/go-winio v0.6.2 // indirect
|
||||
github.com/Microsoft/hcsshim v0.12.3 // indirect
|
||||
github.com/XiaoMi/pegasus-go-client v0.0.0-20210427083443-f3b6b08bc4c2 // indirect
|
||||
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be // indirect
|
||||
github.com/beorn7/perks v1.0.1 // indirect
|
||||
github.com/bradfitz/gomemcache v0.0.0-20220106215444-fb4bf637b56d // indirect
|
||||
github.com/cespare/xxhash/v2 v2.2.0 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/containerd/containerd v1.7.16 // indirect
|
||||
github.com/containerd/log v0.1.0 // indirect
|
||||
github.com/cpuguy83/dockercfg v0.3.1 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/dgraph-io/ristretto v0.1.1 // indirect
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||
github.com/distribution/reference v0.6.0 // indirect
|
||||
github.com/docker/docker v26.1.3+incompatible // indirect
|
||||
github.com/docker/go-connections v0.5.0 // indirect
|
||||
github.com/docker/go-units v0.5.0 // indirect
|
||||
github.com/felixge/httpsnoop v1.0.4 // indirect
|
||||
github.com/fredbi/uri v0.0.0-20181227131451-3dcfdacbaaf3 // indirect
|
||||
github.com/getlantern/context v0.0.0-20190109183933-c447772a6520 // indirect
|
||||
github.com/getlantern/errors v0.0.0-20190325191628-abdb3e3e36f7 // indirect
|
||||
@@ -113,59 +130,84 @@ require (
|
||||
github.com/getlantern/ops v0.0.0-20190325191751-d70cb0d6f85f // indirect
|
||||
github.com/go-gl/gl v0.0.0-20210813123233-e4099ee2221f // indirect
|
||||
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20211024062804-40e447a793be // indirect
|
||||
github.com/go-logr/logr v1.2.3 // indirect
|
||||
github.com/go-logr/logr v1.4.1 // indirect
|
||||
github.com/go-logr/stdr v1.2.2 // indirect
|
||||
github.com/go-ole/go-ole v1.2.6 // indirect
|
||||
github.com/go-ole/go-ole v1.3.0 // indirect
|
||||
github.com/go-redis/redis/v8 v8.11.5 // indirect
|
||||
github.com/go-stack/stack v1.8.0 // indirect
|
||||
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect
|
||||
github.com/gogo/protobuf v1.3.2 // indirect
|
||||
github.com/goki/freetype v0.0.0-20181231101311-fa8a33aabaff // indirect
|
||||
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
|
||||
github.com/google/btree v1.0.1 // indirect
|
||||
github.com/google/s2a-go v0.1.4 // indirect
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.2.3 // indirect
|
||||
github.com/googleapis/gax-go/v2 v2.10.0 // indirect
|
||||
github.com/hashicorp/errwrap v1.0.0 // indirect
|
||||
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 // indirect
|
||||
github.com/google/s2a-go v0.1.7 // indirect
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect
|
||||
github.com/googleapis/gax-go/v2 v2.12.3 // indirect
|
||||
github.com/hashicorp/errwrap v1.1.0 // indirect
|
||||
github.com/hashicorp/go-uuid v1.0.2 // indirect
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
|
||||
github.com/jackc/pgx/v5 v5.5.5 // indirect
|
||||
github.com/jackc/puddle/v2 v2.2.1 // indirect
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
github.com/jinzhu/now v1.1.5 // indirect
|
||||
github.com/josharian/native v1.1.0 // indirect
|
||||
github.com/kelseyhightower/envconfig v1.4.0 // indirect
|
||||
github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect
|
||||
github.com/klauspost/compress v1.17.8 // indirect
|
||||
github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae // indirect
|
||||
github.com/mdlayher/genetlink v1.3.2 // indirect
|
||||
github.com/mdlayher/netlink v1.7.2 // indirect
|
||||
github.com/moby/docker-image-spec v1.3.1 // indirect
|
||||
github.com/moby/patternmatcher v0.6.0 // indirect
|
||||
github.com/moby/sys/sequential v0.5.0 // indirect
|
||||
github.com/moby/sys/user v0.1.0 // indirect
|
||||
github.com/moby/term v0.5.0 // indirect
|
||||
github.com/morikuni/aec v1.0.0 // indirect
|
||||
github.com/nxadm/tail v1.4.8 // indirect
|
||||
github.com/onsi/ginkgo/v2 v2.9.5 // indirect
|
||||
github.com/opencontainers/go-digest v1.0.0 // indirect
|
||||
github.com/opencontainers/image-spec v1.1.0 // indirect
|
||||
github.com/oxtoacart/bpool v0.0.0-20190530202638-03653db5a59c // indirect
|
||||
github.com/pegasus-kv/thrift v0.13.0 // indirect
|
||||
github.com/pion/dtls/v2 v2.2.10 // indirect
|
||||
github.com/pion/mdns v0.0.12 // indirect
|
||||
github.com/pion/randutil v0.1.0 // indirect
|
||||
github.com/pion/transport/v2 v2.2.4 // indirect
|
||||
github.com/pkg/errors v0.9.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/prometheus/client_model v0.3.0 // indirect
|
||||
github.com/prometheus/common v0.37.0 // indirect
|
||||
github.com/prometheus/procfs v0.8.0 // indirect
|
||||
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect
|
||||
github.com/prometheus/client_model v0.6.1 // indirect
|
||||
github.com/prometheus/common v0.53.0 // indirect
|
||||
github.com/prometheus/procfs v0.15.0 // indirect
|
||||
github.com/shirou/gopsutil/v3 v3.24.4 // indirect
|
||||
github.com/shoenig/go-m1cpu v0.1.6 // indirect
|
||||
github.com/spf13/cast v1.5.0 // indirect
|
||||
github.com/srwiley/oksvg v0.0.0-20200311192757-870daf9aa564 // indirect
|
||||
github.com/srwiley/rasterx v0.0.0-20200120212402-85cb7272f5e9 // indirect
|
||||
github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74 // indirect
|
||||
github.com/tklauser/go-sysconf v0.3.14 // indirect
|
||||
github.com/tklauser/numcpus v0.8.0 // indirect
|
||||
github.com/vishvananda/netns v0.0.4 // indirect
|
||||
github.com/yuin/goldmark v1.4.13 // indirect
|
||||
go.opencensus.io v0.24.0 // indirect
|
||||
go.opentelemetry.io/otel/sdk v1.11.1 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.11.1 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 // indirect
|
||||
go.opentelemetry.io/otel/sdk v1.26.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.26.0 // indirect
|
||||
go.uber.org/mock v0.4.0 // indirect
|
||||
golang.org/x/image v0.10.0 // indirect
|
||||
golang.org/x/text v0.14.0 // indirect
|
||||
golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 // indirect
|
||||
golang.org/x/mod v0.17.0 // indirect
|
||||
golang.org/x/text v0.15.0 // indirect
|
||||
golang.org/x/time v0.5.0 // indirect
|
||||
golang.org/x/tools v0.21.0 // indirect
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
|
||||
google.golang.org/appengine v1.6.7 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20230530153820-e85fd2cbaebc // indirect
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20240509183442-62759503f434 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20240515191416-fc5f0ca64291 // indirect
|
||||
gopkg.in/square/go-jose.v2 v2.6.0 // indirect
|
||||
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect
|
||||
gopkg.in/tomb.v2 v2.0.0-20161208151619-d5d1b5820637 // indirect
|
||||
gopkg.in/yaml.v2 v2.4.0 // indirect
|
||||
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 // indirect
|
||||
k8s.io/apimachinery v0.23.16 // indirect
|
||||
k8s.io/apimachinery v0.26.2 // indirect
|
||||
)
|
||||
|
||||
replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0
|
||||
|
||||
@@ -3,6 +3,7 @@ package iface
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -79,8 +80,19 @@ func TestWGIface_UpdateAddr(t *testing.T) {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
assert.Equal(t, addr, addrs[0].String())
|
||||
var found bool
|
||||
for _, a := range addrs {
|
||||
prefix, err := netip.ParsePrefix(a.String())
|
||||
assert.NoError(t, err)
|
||||
if prefix.Addr().Is4() {
|
||||
found = true
|
||||
assert.Equal(t, addr, prefix.String())
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
t.Fatal("v4 address not found")
|
||||
}
|
||||
}
|
||||
|
||||
func getIfaceAddrs(ifaceName string) ([]net.Addr, error) {
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build !ios
|
||||
// +build !ios
|
||||
|
||||
package iface
|
||||
|
||||
@@ -121,13 +120,19 @@ func (t *tunDevice) Wrapper() *DeviceWrapper {
|
||||
func (t *tunDevice) assignAddr() error {
|
||||
cmd := exec.Command("ifconfig", t.name, "inet", t.address.IP.String(), t.address.IP.String())
|
||||
if out, err := cmd.CombinedOutput(); err != nil {
|
||||
log.Infof(`adding address command "%v" failed with output %s and error: `, cmd.String(), out)
|
||||
log.Errorf("adding address command '%v' failed with output: %s", cmd.String(), out)
|
||||
return err
|
||||
}
|
||||
|
||||
// dummy ipv6 so routing works
|
||||
cmd = exec.Command("ifconfig", t.name, "inet6", "fe80::/64")
|
||||
if out, err := cmd.CombinedOutput(); err != nil {
|
||||
log.Debugf("adding address command '%v' failed with output: %s", cmd.String(), out)
|
||||
}
|
||||
|
||||
routeCmd := exec.Command("route", "add", "-net", t.address.Network.String(), "-interface", t.name)
|
||||
if out, err := routeCmd.CombinedOutput(); err != nil {
|
||||
log.Printf(`adding route command "%v" failed with output %s and error: `, routeCmd.String(), out)
|
||||
log.Errorf("adding route command '%v' failed with output: %s", routeCmd.String(), out)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -62,10 +62,11 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
s := grpc.NewServer()
|
||||
store, err := mgmt.NewStoreFromJson(config.Datadir, nil)
|
||||
store, cleanUp, err := mgmt.NewTestStoreFromJson(config.Datadir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Cleanup(cleanUp)
|
||||
|
||||
peersUpdateManager := mgmt.NewPeersUpdateManager(nil)
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
|
||||
@@ -7,10 +7,11 @@ import (
|
||||
"os"
|
||||
"path"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
var shortDown = "Rollback SQLite store to JSON file store. Please make a backup of the SQLite file before running this command."
|
||||
@@ -39,16 +40,16 @@ var downCmd = &cobra.Command{
|
||||
return fmt.Errorf("%s already exists, couldn't continue the operation", fileStorePath)
|
||||
}
|
||||
|
||||
sqlstore, err := server.NewSqliteStore(mgmtDataDir, nil)
|
||||
sqlStore, err := server.NewSqliteStore(mgmtDataDir, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed creating file store: %s: %v", mgmtDataDir, err)
|
||||
}
|
||||
|
||||
sqliteStoreAccounts := len(sqlstore.GetAllAccounts())
|
||||
sqliteStoreAccounts := len(sqlStore.GetAllAccounts())
|
||||
log.Infof("%d account will be migrated from sqlite store %s to file store %s",
|
||||
sqliteStoreAccounts, sqliteStorePath, fileStorePath)
|
||||
|
||||
store, err := server.NewFilestoreFromSqliteStore(sqlstore, mgmtDataDir, nil)
|
||||
store, err := server.NewFilestoreFromSqliteStore(sqlStore, mgmtDataDir, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed creating file store: %s: %v", mgmtDataDir, err)
|
||||
}
|
||||
|
||||
@@ -132,6 +132,7 @@ type AccountManager interface {
|
||||
GetValidatedPeers(account *Account) (map[string]struct{}, error)
|
||||
SyncAndMarkPeer(peerPubKey string, realIP net.IP) (*nbpeer.Peer, *NetworkMap, error)
|
||||
CancelPeerRoutines(peer *nbpeer.Peer) error
|
||||
FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
||||
}
|
||||
|
||||
type DefaultAccountManager struct {
|
||||
@@ -241,6 +242,11 @@ type Account struct {
|
||||
Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"`
|
||||
}
|
||||
|
||||
// Subclass used in gorm to only load settings and not whole account
|
||||
type AccountSettings struct {
|
||||
Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"`
|
||||
}
|
||||
|
||||
type UserPermissions struct {
|
||||
DashboardView string `json:"dashboard_view"`
|
||||
}
|
||||
@@ -1768,6 +1774,8 @@ func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.Authorizat
|
||||
//
|
||||
// Existing user + Existing account + Existing domain reclassified Domain as private -> Nothing changes (index domain)
|
||||
func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(claims jwtclaims.AuthorizationClaims) (*Account, error) {
|
||||
log.Tracef("getting account with authorization claims. User ID: \"%s\", Account ID: \"%s\", Domain: \"%s\", Domain Category: \"%s\"",
|
||||
claims.UserId, claims.AccountId, claims.Domain, claims.DomainCategory)
|
||||
if claims.UserId == "" {
|
||||
return nil, fmt.Errorf("user ID is empty")
|
||||
}
|
||||
@@ -1788,8 +1796,10 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(claims jwtcla
|
||||
}
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
unlock := am.Store.AcquireGlobalLock()
|
||||
defer unlock()
|
||||
log.Debugf("Acquired global lock in %s for user %s", time.Since(start), claims.UserId)
|
||||
|
||||
// We checked if the domain has a primary account already
|
||||
domainAccount, err := am.Store.GetAccountByPrivateDomain(claims.Domain)
|
||||
@@ -1840,6 +1850,9 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(claims jwtcla
|
||||
func (am *DefaultAccountManager) SyncAndMarkPeer(peerPubKey string, realIP net.IP) (*nbpeer.Peer, *NetworkMap, error) {
|
||||
accountID, err := am.Store.GetAccountIDByPeerPubKey(peerPubKey)
|
||||
if err != nil {
|
||||
if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound {
|
||||
return nil, nil, status.Errorf(status.Unauthenticated, "peer not registered")
|
||||
}
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
@@ -1853,7 +1866,7 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(peerPubKey string, realIP net.I
|
||||
|
||||
peer, netMap, err := am.SyncPeer(PeerSync{WireGuardPubKey: peerPubKey}, account)
|
||||
if err != nil {
|
||||
return nil, nil, mapError(err)
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
err = am.MarkPeerConnected(peerPubKey, true, realIP, account)
|
||||
@@ -1867,6 +1880,9 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(peerPubKey string, realIP net.I
|
||||
func (am *DefaultAccountManager) CancelPeerRoutines(peer *nbpeer.Peer) error {
|
||||
accountID, err := am.Store.GetAccountIDByPeerPubKey(peer.Key)
|
||||
if err != nil {
|
||||
if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound {
|
||||
return status.Errorf(status.Unauthenticated, "peer not registered")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -1951,6 +1967,10 @@ func (am *DefaultAccountManager) onPeersInvalidated(accountID string) {
|
||||
am.updateAccountPeers(updatedAccount)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) {
|
||||
return am.Store.GetPostureCheckByChecksDefinition(accountID, checks)
|
||||
}
|
||||
|
||||
// addAllGroup to account object if it doesn't exist
|
||||
func addAllGroup(account *Account) error {
|
||||
if len(account.Groups) == 0 {
|
||||
|
||||
@@ -48,8 +48,8 @@ func (MocIntegratedValidator) PreparePeer(accountID string, peer *nbpeer.Peer, p
|
||||
return peer
|
||||
}
|
||||
|
||||
func (MocIntegratedValidator) IsNotValidPeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool) {
|
||||
return false, false
|
||||
func (MocIntegratedValidator) IsNotValidPeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error) {
|
||||
return false, false, nil
|
||||
}
|
||||
|
||||
func (MocIntegratedValidator) PeerDeleted(_, _ string) error {
|
||||
@@ -1294,6 +1294,7 @@ func TestAccountManager_DeletePeer(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
userID := "account_creator"
|
||||
account, err := createAccount(manager, "test_account", userID, "netbird.cloud")
|
||||
if err != nil {
|
||||
@@ -1655,6 +1656,7 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) {
|
||||
func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
_, err = manager.GetAccountByUserOrAccountID(userID, "", "")
|
||||
require.NoError(t, err, "unable to create an account")
|
||||
|
||||
@@ -1707,6 +1709,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
|
||||
func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
account, err := manager.GetAccountByUserOrAccountID(userID, "", "")
|
||||
require.NoError(t, err, "unable to create an account")
|
||||
|
||||
@@ -1750,6 +1753,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
|
||||
func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
_, err = manager.GetAccountByUserOrAccountID(userID, "", "")
|
||||
require.NoError(t, err, "unable to create an account")
|
||||
|
||||
@@ -2267,21 +2271,29 @@ func TestAccount_UserGroupsRemoveFromPeers(t *testing.T) {
|
||||
|
||||
func createManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||
t.Helper()
|
||||
|
||||
store, err := createStore(t)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
return BuildManager(store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{})
|
||||
|
||||
manager, err := BuildManager(store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return manager, nil
|
||||
}
|
||||
|
||||
func createStore(t *testing.T) (Store, error) {
|
||||
t.Helper()
|
||||
dataDir := t.TempDir()
|
||||
store, err := NewStoreFromJson(dataDir, nil)
|
||||
store, cleanUp, err := NewTestStoreFromJson(dataDir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t.Cleanup(cleanUp)
|
||||
|
||||
return store, nil
|
||||
}
|
||||
|
||||
@@ -32,7 +32,7 @@ func TestGetDNSSettings(t *testing.T) {
|
||||
|
||||
account, err := initTestDNSAccount(t, am)
|
||||
if err != nil {
|
||||
t.Error("failed to init testing account")
|
||||
t.Fatal("failed to init testing account")
|
||||
}
|
||||
|
||||
dnsSettings, err := am.GetDNSSettings(account.Id, dnsAdminUserID)
|
||||
@@ -200,10 +200,11 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||
func createDNSStore(t *testing.T) (Store, error) {
|
||||
t.Helper()
|
||||
dataDir := t.TempDir()
|
||||
store, err := NewStoreFromJson(dataDir, nil)
|
||||
store, cleanUp, err := NewTestStoreFromJson(dataDir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t.Cleanup(cleanUp)
|
||||
|
||||
return store, nil
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
|
||||
@@ -57,18 +58,18 @@ func NewFileStore(dataDir string, metrics telemetry.AppMetrics) (*FileStore, err
|
||||
}
|
||||
|
||||
// NewFilestoreFromSqliteStore restores a store from Sqlite and stores to Filestore json in the file located in datadir
|
||||
func NewFilestoreFromSqliteStore(sqlitestore *SqliteStore, dataDir string, metrics telemetry.AppMetrics) (*FileStore, error) {
|
||||
func NewFilestoreFromSqliteStore(sqlStore *SqlStore, dataDir string, metrics telemetry.AppMetrics) (*FileStore, error) {
|
||||
store, err := NewFileStore(dataDir, metrics)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = store.SaveInstallationID(sqlitestore.GetInstallationID())
|
||||
err = store.SaveInstallationID(sqlStore.GetInstallationID())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, account := range sqlitestore.GetAllAccounts() {
|
||||
for _, account := range sqlStore.GetAllAccounts() {
|
||||
store.Accounts[account.Id] = account
|
||||
}
|
||||
|
||||
@@ -508,7 +509,7 @@ func (s *FileStore) GetAccountByUser(userID string) (*Account, error) {
|
||||
|
||||
accountID, ok := s.UserID2AccountID[userID]
|
||||
if !ok {
|
||||
return nil, status.Errorf(status.NotFound, "account not found")
|
||||
return nil, status.NewUserNotFoundError(userID)
|
||||
}
|
||||
|
||||
account, err := s.getAccount(accountID)
|
||||
@@ -539,7 +540,7 @@ func (s *FileStore) GetAccountByPeerID(peerID string) (*Account, error) {
|
||||
if _, ok := account.Peers[peerID]; !ok {
|
||||
delete(s.PeerID2AccountID, peerID)
|
||||
log.Warnf("removed stale peerID %s to accountID %s index", peerID, accountID)
|
||||
return nil, status.Errorf(status.NotFound, "provided peer doesn't exists %s", peerID)
|
||||
return nil, status.NewPeerNotFoundError(peerID)
|
||||
}
|
||||
|
||||
return account.Copy(), nil
|
||||
@@ -552,7 +553,7 @@ func (s *FileStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) {
|
||||
|
||||
accountID, ok := s.PeerKeyID2AccountID[peerKey]
|
||||
if !ok {
|
||||
return nil, status.Errorf(status.NotFound, "provided peer key doesn't exists %s", peerKey)
|
||||
return nil, status.NewPeerNotFoundError(peerKey)
|
||||
}
|
||||
|
||||
account, err := s.getAccount(accountID)
|
||||
@@ -572,7 +573,7 @@ func (s *FileStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) {
|
||||
if stale {
|
||||
delete(s.PeerKeyID2AccountID, peerKey)
|
||||
log.Warnf("removed stale peerKey %s to accountID %s index", peerKey, accountID)
|
||||
return nil, status.Errorf(status.NotFound, "provided peer doesn't exists %s", peerKey)
|
||||
return nil, status.NewPeerNotFoundError(peerKey)
|
||||
}
|
||||
|
||||
return account.Copy(), nil
|
||||
@@ -584,12 +585,71 @@ func (s *FileStore) GetAccountIDByPeerPubKey(peerKey string) (string, error) {
|
||||
|
||||
accountID, ok := s.PeerKeyID2AccountID[peerKey]
|
||||
if !ok {
|
||||
return "", status.Errorf(status.NotFound, "provided peer key doesn't exists %s", peerKey)
|
||||
return "", status.NewPeerNotFoundError(peerKey)
|
||||
}
|
||||
|
||||
return accountID, nil
|
||||
}
|
||||
|
||||
func (s *FileStore) GetAccountIDByUserID(userID string) (string, error) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
accountID, ok := s.UserID2AccountID[userID]
|
||||
if !ok {
|
||||
return "", status.NewUserNotFoundError(userID)
|
||||
}
|
||||
|
||||
return accountID, nil
|
||||
}
|
||||
|
||||
func (s *FileStore) GetAccountIDBySetupKey(setupKey string) (string, error) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKey)]
|
||||
if !ok {
|
||||
return "", status.Errorf(status.NotFound, "account not found: provided setup key doesn't exists")
|
||||
}
|
||||
|
||||
return accountID, nil
|
||||
}
|
||||
|
||||
func (s *FileStore) GetPeerByPeerPubKey(peerKey string) (*nbpeer.Peer, error) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
accountID, ok := s.PeerKeyID2AccountID[peerKey]
|
||||
if !ok {
|
||||
return nil, status.NewPeerNotFoundError(peerKey)
|
||||
}
|
||||
|
||||
account, err := s.getAccount(accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, peer := range account.Peers {
|
||||
if peer.Key == peerKey {
|
||||
return peer.Copy(), nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, status.NewPeerNotFoundError(peerKey)
|
||||
}
|
||||
|
||||
func (s *FileStore) GetAccountSettings(accountID string) (*Settings, error) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
account, err := s.getAccount(accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return account.Settings.Copy(), nil
|
||||
}
|
||||
|
||||
// GetInstallationID returns the installation ID from the store
|
||||
func (s *FileStore) GetInstallationID() string {
|
||||
return s.InstallationID
|
||||
@@ -667,6 +727,10 @@ func (s *FileStore) SaveUserLastLogin(accountID, userID string, lastLogin time.T
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *FileStore) GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) {
|
||||
return nil, status.Errorf(status.Internal, "GetPostureCheckByChecksDefinition is not implemented")
|
||||
}
|
||||
|
||||
// Close the FileStore persisting data to disk
|
||||
func (s *FileStore) Close() error {
|
||||
s.mux.Lock()
|
||||
|
||||
@@ -59,6 +59,7 @@ func TestStalePeerIndices(t *testing.T) {
|
||||
|
||||
func TestNewStore(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close()
|
||||
|
||||
if store.Accounts == nil || len(store.Accounts) != 0 {
|
||||
t.Errorf("expected to create a new empty Accounts map when creating a new FileStore")
|
||||
@@ -87,6 +88,7 @@ func TestNewStore(t *testing.T) {
|
||||
|
||||
func TestSaveAccount(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close()
|
||||
|
||||
account := newAccountWithId("account_id", "testuser", "")
|
||||
setupKey := GenerateDefaultSetupKey()
|
||||
@@ -135,6 +137,8 @@ func TestDeleteAccount(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer store.Close()
|
||||
|
||||
var account *Account
|
||||
for _, a := range store.Accounts {
|
||||
account = a
|
||||
@@ -179,6 +183,7 @@ func TestDeleteAccount(t *testing.T) {
|
||||
|
||||
func TestStore(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close()
|
||||
|
||||
account := newAccountWithId("account_id", "testuser", "")
|
||||
account.Peers["testpeer"] = &nbpeer.Peer{
|
||||
@@ -436,6 +441,7 @@ func TestFileStore_GetTokenIDByHashedToken(t *testing.T) {
|
||||
|
||||
func TestFileStore_DeleteHashedPAT2TokenIDIndex(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close()
|
||||
store.HashedPAT2TokenID["someHashedToken"] = "someTokenId"
|
||||
|
||||
err := store.DeleteHashedPAT2TokenIDIndex("someHashedToken")
|
||||
|
||||
@@ -245,6 +245,11 @@ func (am *DefaultAccountManager) DeleteGroup(accountId, userId, groupID string)
|
||||
return &GroupLinkError{"route", string(r.NetID)}
|
||||
}
|
||||
}
|
||||
for _, g := range r.PeerGroups {
|
||||
if g == groupID {
|
||||
return &GroupLinkError{"route", string(r.NetID)}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// check DNS links
|
||||
|
||||
@@ -70,6 +70,11 @@ func TestDefaultAccountManager_DeleteGroup(t *testing.T) {
|
||||
"grp-for-route",
|
||||
"route",
|
||||
},
|
||||
{
|
||||
"route with peer groups",
|
||||
"grp-for-route2",
|
||||
"route",
|
||||
},
|
||||
{
|
||||
"name server groups",
|
||||
"grp-for-name-server-grp",
|
||||
@@ -138,6 +143,14 @@ func initTestGroupAccount(am *DefaultAccountManager) (*Account, error) {
|
||||
Peers: make([]string, 0),
|
||||
}
|
||||
|
||||
groupForRoute2 := &nbgroup.Group{
|
||||
ID: "grp-for-route2",
|
||||
AccountID: "account-id",
|
||||
Name: "Group for route",
|
||||
Issued: nbgroup.GroupIssuedAPI,
|
||||
Peers: make([]string, 0),
|
||||
}
|
||||
|
||||
groupForNameServerGroups := &nbgroup.Group{
|
||||
ID: "grp-for-name-server-grp",
|
||||
AccountID: "account-id",
|
||||
@@ -183,6 +196,11 @@ func initTestGroupAccount(am *DefaultAccountManager) (*Account, error) {
|
||||
Groups: []string{groupForRoute.ID},
|
||||
}
|
||||
|
||||
routePeerGroupResource := &route.Route{
|
||||
ID: "example route with peer groups",
|
||||
PeerGroups: []string{groupForRoute2.ID},
|
||||
}
|
||||
|
||||
nameServerGroup := &nbdns.NameServerGroup{
|
||||
ID: "example name server group",
|
||||
Groups: []string{groupForNameServerGroups.ID},
|
||||
@@ -209,6 +227,7 @@ func initTestGroupAccount(am *DefaultAccountManager) (*Account, error) {
|
||||
}
|
||||
account := newAccountWithId(accountID, groupAdminUserID, domain)
|
||||
account.Routes[routeResource.ID] = routeResource
|
||||
account.Routes[routePeerGroupResource.ID] = routePeerGroupResource
|
||||
account.NameServerGroups[nameServerGroup.ID] = nameServerGroup
|
||||
account.Policies = append(account.Policies, policy)
|
||||
account.SetupKeys[setupKey.Id] = setupKey
|
||||
@@ -220,6 +239,7 @@ func initTestGroupAccount(am *DefaultAccountManager) (*Account, error) {
|
||||
}
|
||||
|
||||
_ = am.SaveGroup(accountID, groupAdminUserID, groupForRoute)
|
||||
_ = am.SaveGroup(accountID, groupAdminUserID, groupForRoute2)
|
||||
_ = am.SaveGroup(accountID, groupAdminUserID, groupForNameServerGroups)
|
||||
_ = am.SaveGroup(accountID, groupAdminUserID, groupForPolicies)
|
||||
_ = am.SaveGroup(accountID, groupAdminUserID, groupForSetupKeys)
|
||||
|
||||
@@ -136,7 +136,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
||||
|
||||
peer, netMap, err := s.accountManager.SyncAndMarkPeer(peerKey.String(), realIP)
|
||||
if err != nil {
|
||||
return err
|
||||
return mapError(err)
|
||||
}
|
||||
|
||||
err = s.sendInitialSync(peerKey, peer, netMap, srv)
|
||||
@@ -368,7 +368,7 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.Warnf("failed logging in peer %s", peerKey)
|
||||
log.Warnf("failed logging in peer %s: %s", peerKey, err)
|
||||
return nil, mapError(err)
|
||||
}
|
||||
|
||||
|
||||
@@ -3,12 +3,10 @@ package http
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"regexp"
|
||||
"slices"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/rs/xid"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
@@ -59,7 +57,7 @@ func (p *PostureChecksHandler) GetAllPostureChecks(w http.ResponseWriter, r *htt
|
||||
|
||||
postureChecks := []*api.PostureCheck{}
|
||||
for _, postureCheck := range accountPostureChecks {
|
||||
postureChecks = append(postureChecks, toPostureChecksResponse(postureCheck))
|
||||
postureChecks = append(postureChecks, postureCheck.ToAPIResponse())
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, postureChecks)
|
||||
@@ -130,7 +128,7 @@ func (p *PostureChecksHandler) GetPostureCheck(w http.ResponseWriter, r *http.Re
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, toPostureChecksResponse(postureChecks))
|
||||
util.WriteJSONObject(w, postureChecks.ToAPIResponse())
|
||||
}
|
||||
|
||||
// DeletePostureCheck handles posture check deletion request
|
||||
@@ -178,55 +176,26 @@ func (p *PostureChecksHandler) savePostureChecks(
|
||||
return
|
||||
}
|
||||
|
||||
if postureChecksID == "" {
|
||||
postureChecksID = xid.New().String()
|
||||
}
|
||||
|
||||
postureChecks := posture.Checks{
|
||||
ID: postureChecksID,
|
||||
Name: req.Name,
|
||||
Description: req.Description,
|
||||
}
|
||||
|
||||
if nbVersionCheck := req.Checks.NbVersionCheck; nbVersionCheck != nil {
|
||||
postureChecks.Checks.NBVersionCheck = &posture.NBVersionCheck{
|
||||
MinVersion: nbVersionCheck.MinVersion,
|
||||
}
|
||||
}
|
||||
|
||||
if osVersionCheck := req.Checks.OsVersionCheck; osVersionCheck != nil {
|
||||
postureChecks.Checks.OSVersionCheck = &posture.OSVersionCheck{
|
||||
Android: (*posture.MinVersionCheck)(osVersionCheck.Android),
|
||||
Darwin: (*posture.MinVersionCheck)(osVersionCheck.Darwin),
|
||||
Ios: (*posture.MinVersionCheck)(osVersionCheck.Ios),
|
||||
Linux: (*posture.MinKernelVersionCheck)(osVersionCheck.Linux),
|
||||
Windows: (*posture.MinKernelVersionCheck)(osVersionCheck.Windows),
|
||||
}
|
||||
}
|
||||
|
||||
if geoLocationCheck := req.Checks.GeoLocationCheck; geoLocationCheck != nil {
|
||||
if p.geolocationManager == nil {
|
||||
// TODO: update error message to include geo db self hosted doc link when ready
|
||||
util.WriteError(status.Errorf(status.PreconditionFailed, "Geo location database is not initialized"), w)
|
||||
return
|
||||
}
|
||||
postureChecks.Checks.GeoLocationCheck = toPostureGeoLocationCheck(geoLocationCheck)
|
||||
}
|
||||
|
||||
if peerNetworkRangeCheck := req.Checks.PeerNetworkRangeCheck; peerNetworkRangeCheck != nil {
|
||||
postureChecks.Checks.PeerNetworkRangeCheck, err = toPeerNetworkRangeCheck(peerNetworkRangeCheck)
|
||||
if err != nil {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid network prefix"), w)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if err := p.accountManager.SavePostureChecks(account.Id, user.Id, &postureChecks); err != nil {
|
||||
postureChecks, err := posture.NewChecksFromAPIPostureCheckUpdate(req, postureChecksID)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, toPostureChecksResponse(&postureChecks))
|
||||
if err := p.accountManager.SavePostureChecks(account.Id, user.Id, postureChecks); err != nil {
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, postureChecks.ToAPIResponse())
|
||||
}
|
||||
|
||||
func validatePostureChecksUpdate(req api.PostureCheckUpdate) error {
|
||||
@@ -294,105 +263,3 @@ func validatePostureChecksUpdate(req api.PostureCheckUpdate) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func toPostureChecksResponse(postureChecks *posture.Checks) *api.PostureCheck {
|
||||
var checks api.Checks
|
||||
|
||||
if postureChecks.Checks.NBVersionCheck != nil {
|
||||
checks.NbVersionCheck = &api.NBVersionCheck{
|
||||
MinVersion: postureChecks.Checks.NBVersionCheck.MinVersion,
|
||||
}
|
||||
}
|
||||
|
||||
if postureChecks.Checks.OSVersionCheck != nil {
|
||||
checks.OsVersionCheck = &api.OSVersionCheck{
|
||||
Android: (*api.MinVersionCheck)(postureChecks.Checks.OSVersionCheck.Android),
|
||||
Darwin: (*api.MinVersionCheck)(postureChecks.Checks.OSVersionCheck.Darwin),
|
||||
Ios: (*api.MinVersionCheck)(postureChecks.Checks.OSVersionCheck.Ios),
|
||||
Linux: (*api.MinKernelVersionCheck)(postureChecks.Checks.OSVersionCheck.Linux),
|
||||
Windows: (*api.MinKernelVersionCheck)(postureChecks.Checks.OSVersionCheck.Windows),
|
||||
}
|
||||
}
|
||||
|
||||
if postureChecks.Checks.GeoLocationCheck != nil {
|
||||
checks.GeoLocationCheck = toGeoLocationCheckResponse(postureChecks.Checks.GeoLocationCheck)
|
||||
}
|
||||
|
||||
if postureChecks.Checks.PeerNetworkRangeCheck != nil {
|
||||
checks.PeerNetworkRangeCheck = toPeerNetworkRangeCheckResponse(postureChecks.Checks.PeerNetworkRangeCheck)
|
||||
}
|
||||
|
||||
return &api.PostureCheck{
|
||||
Id: postureChecks.ID,
|
||||
Name: postureChecks.Name,
|
||||
Description: &postureChecks.Description,
|
||||
Checks: checks,
|
||||
}
|
||||
}
|
||||
|
||||
func toGeoLocationCheckResponse(geoLocationCheck *posture.GeoLocationCheck) *api.GeoLocationCheck {
|
||||
locations := make([]api.Location, 0, len(geoLocationCheck.Locations))
|
||||
for _, loc := range geoLocationCheck.Locations {
|
||||
l := loc // make G601 happy
|
||||
var cityName *string
|
||||
if loc.CityName != "" {
|
||||
cityName = &l.CityName
|
||||
}
|
||||
locations = append(locations, api.Location{
|
||||
CityName: cityName,
|
||||
CountryCode: loc.CountryCode,
|
||||
})
|
||||
}
|
||||
|
||||
return &api.GeoLocationCheck{
|
||||
Action: api.GeoLocationCheckAction(geoLocationCheck.Action),
|
||||
Locations: locations,
|
||||
}
|
||||
}
|
||||
|
||||
func toPostureGeoLocationCheck(apiGeoLocationCheck *api.GeoLocationCheck) *posture.GeoLocationCheck {
|
||||
locations := make([]posture.Location, 0, len(apiGeoLocationCheck.Locations))
|
||||
for _, loc := range apiGeoLocationCheck.Locations {
|
||||
cityName := ""
|
||||
if loc.CityName != nil {
|
||||
cityName = *loc.CityName
|
||||
}
|
||||
locations = append(locations, posture.Location{
|
||||
CountryCode: loc.CountryCode,
|
||||
CityName: cityName,
|
||||
})
|
||||
}
|
||||
|
||||
return &posture.GeoLocationCheck{
|
||||
Action: string(apiGeoLocationCheck.Action),
|
||||
Locations: locations,
|
||||
}
|
||||
}
|
||||
|
||||
func toPeerNetworkRangeCheckResponse(check *posture.PeerNetworkRangeCheck) *api.PeerNetworkRangeCheck {
|
||||
netPrefixes := make([]string, 0, len(check.Ranges))
|
||||
for _, netPrefix := range check.Ranges {
|
||||
netPrefixes = append(netPrefixes, netPrefix.String())
|
||||
}
|
||||
|
||||
return &api.PeerNetworkRangeCheck{
|
||||
Ranges: netPrefixes,
|
||||
Action: api.PeerNetworkRangeCheckAction(check.Action),
|
||||
}
|
||||
}
|
||||
|
||||
func toPeerNetworkRangeCheck(check *api.PeerNetworkRangeCheck) (*posture.PeerNetworkRangeCheck, error) {
|
||||
prefixes := make([]netip.Prefix, 0)
|
||||
for _, prefix := range check.Ranges {
|
||||
parsedPrefix, err := netip.ParsePrefix(prefix)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
prefixes = append(prefixes, parsedPrefix)
|
||||
}
|
||||
|
||||
return &posture.PeerNetworkRangeCheck{
|
||||
Ranges: prefixes,
|
||||
Action: string(check.Action),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -154,7 +154,7 @@ func (zc *ZitadelCredentials) requestJWTToken() (*http.Response, error) {
|
||||
data.Set("client_id", zc.clientConfig.ClientID)
|
||||
data.Set("client_secret", zc.clientConfig.ClientSecret)
|
||||
data.Set("grant_type", zc.clientConfig.GrantType)
|
||||
data.Set("scope", "urn:zitadel:iam:org:project:id:zitadel:aud")
|
||||
data.Set("scope", "openid urn:zitadel:iam:org:project:id:zitadel:aud")
|
||||
|
||||
payload := strings.NewReader(data.Encode())
|
||||
req, err := http.NewRequest(http.MethodPost, zc.clientConfig.TokenEndpoint, payload)
|
||||
|
||||
@@ -11,7 +11,7 @@ type IntegratedValidator interface {
|
||||
ValidateExtraSettings(newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error
|
||||
ValidatePeer(update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, error)
|
||||
PreparePeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer
|
||||
IsNotValidPeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool)
|
||||
IsNotValidPeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error)
|
||||
GetValidatedPeers(accountID string, groups map[string]*nbgroup.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error)
|
||||
PeerDeleted(accountID, peerID string) error
|
||||
SetPeerInvalidationListener(fn func(accountID string))
|
||||
|
||||
@@ -405,10 +405,12 @@ func startManagement(t *testing.T, config *Config) (*grpc.Server, string, error)
|
||||
return nil, "", err
|
||||
}
|
||||
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
||||
store, err := NewStoreFromJson(config.Datadir, nil)
|
||||
store, cleanUp, err := NewTestStoreFromJson(config.Datadir)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
t.Cleanup(cleanUp)
|
||||
|
||||
peersUpdateManager := NewPeersUpdateManager(nil)
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
accountManager, err := BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted",
|
||||
|
||||
@@ -469,8 +469,8 @@ func (MocIntegratedValidator) PreparePeer(accountID string, peer *nbpeer.Peer, p
|
||||
return peer
|
||||
}
|
||||
|
||||
func (MocIntegratedValidator) IsNotValidPeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool) {
|
||||
return false, false
|
||||
func (MocIntegratedValidator) IsNotValidPeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error) {
|
||||
return false, false, nil
|
||||
}
|
||||
|
||||
func (MocIntegratedValidator) PeerDeleted(_, _ string) error {
|
||||
@@ -532,10 +532,11 @@ func startServer(config *server.Config) (*grpc.Server, net.Listener) {
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
s := grpc.NewServer()
|
||||
|
||||
store, err := server.NewStoreFromJson(config.Datadir, nil)
|
||||
store, _, err := server.NewTestStoreFromJson(config.Datadir)
|
||||
if err != nil {
|
||||
log.Fatalf("failed creating a store: %s: %v", config.Datadir, err)
|
||||
}
|
||||
|
||||
peersUpdateManager := server.NewPeersUpdateManager(nil)
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted",
|
||||
|
||||
@@ -26,7 +26,7 @@ const (
|
||||
// defaultPushInterval default interval to push metrics
|
||||
defaultPushInterval = 24 * time.Hour
|
||||
// requestTimeout http request timeout
|
||||
requestTimeout = 30 * time.Second
|
||||
requestTimeout = 45 * time.Second
|
||||
)
|
||||
|
||||
type getTokenResponse struct {
|
||||
@@ -98,10 +98,7 @@ func (w *Worker) Run() {
|
||||
}
|
||||
|
||||
func (w *Worker) sendMetrics() error {
|
||||
ctx, cancel := context.WithTimeout(w.ctx, requestTimeout)
|
||||
defer cancel()
|
||||
|
||||
apiKey, err := getAPIKey(ctx)
|
||||
apiKey, err := getAPIKey(w.ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -115,7 +112,7 @@ func (w *Worker) sendMetrics() error {
|
||||
|
||||
httpClient := http.Client{}
|
||||
|
||||
exportJobReq, err := createPostRequest(ctx, payloadEndpoint+"/capture/", payloadString)
|
||||
exportJobReq, err := createPostRequest(w.ctx, payloadEndpoint+"/capture/", payloadString)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to create metrics post request %v", err)
|
||||
}
|
||||
@@ -328,6 +325,8 @@ func (w *Worker) generateProperties() properties {
|
||||
}
|
||||
|
||||
func getAPIKey(ctx context.Context) (string, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, requestTimeout)
|
||||
defer cancel()
|
||||
|
||||
httpClient := http.Client{}
|
||||
|
||||
@@ -375,6 +374,8 @@ func buildMetricsPayload(payload pushPayload) (string, error) {
|
||||
}
|
||||
|
||||
func createPostRequest(ctx context.Context, endpoint string, payloadStr string) (*http.Request, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, requestTimeout)
|
||||
defer cancel()
|
||||
reqURL := endpoint
|
||||
|
||||
payload := strings.NewReader(payloadStr)
|
||||
|
||||
@@ -95,6 +95,7 @@ type MockAccountManager struct {
|
||||
GetIdpManagerFunc func() idp.Manager
|
||||
UpdateIntegratedValidatorGroupsFunc func(accountID string, userID string, groups []string) error
|
||||
GroupValidationFunc func(accountId string, groups []string) (bool, error)
|
||||
FindExistingPostureCheckFunc func(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) SyncAndMarkPeer(peerPubKey string, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, error) {
|
||||
@@ -734,3 +735,11 @@ func (am *MockAccountManager) GroupValidation(accountId string, groups []string)
|
||||
}
|
||||
return false, status.Errorf(codes.Unimplemented, "method GroupValidation is not implemented")
|
||||
}
|
||||
|
||||
// FindExistingPostureCheck mocks FindExistingPostureCheck of the AccountManager interface
|
||||
func (am *MockAccountManager) FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) {
|
||||
if am.FindExistingPostureCheckFunc != nil {
|
||||
return am.FindExistingPostureCheckFunc(accountID, checks)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method FindExistingPostureCheck is not implemented")
|
||||
}
|
||||
|
||||
@@ -766,10 +766,11 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||
func createNSStore(t *testing.T) (Store, error) {
|
||||
t.Helper()
|
||||
dataDir := t.TempDir()
|
||||
store, err := NewStoreFromJson(dataDir, nil)
|
||||
store, cleanUp, err := NewTestStoreFromJson(dataDir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t.Cleanup(cleanUp)
|
||||
|
||||
return store, nil
|
||||
}
|
||||
|
||||
@@ -335,24 +335,29 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P
|
||||
}
|
||||
|
||||
upperKey := strings.ToUpper(setupKey)
|
||||
var account *Account
|
||||
var accountID string
|
||||
var err error
|
||||
addedByUser := false
|
||||
if len(userID) > 0 {
|
||||
addedByUser = true
|
||||
account, err = am.Store.GetAccountByUser(userID)
|
||||
accountID, err = am.Store.GetAccountIDByUserID(userID)
|
||||
} else {
|
||||
account, err = am.Store.GetAccountBySetupKey(setupKey)
|
||||
accountID, err = am.Store.GetAccountIDBySetupKey(setupKey)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, nil, status.Errorf(status.NotFound, "failed adding new peer: account not found")
|
||||
}
|
||||
|
||||
unlock := am.Store.AcquireAccountWriteLock(account.Id)
|
||||
defer unlock()
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer func() {
|
||||
if unlock != nil {
|
||||
unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
var account *Account
|
||||
// ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account)
|
||||
account, err = am.Store.GetAccount(account.Id)
|
||||
account, err = am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
@@ -485,6 +490,10 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Account is saved, we can release the lock
|
||||
unlock()
|
||||
unlock = nil
|
||||
|
||||
opEvent.TargetID = newPeer.ID
|
||||
opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain())
|
||||
if !addedByUser {
|
||||
@@ -507,7 +516,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P
|
||||
func (am *DefaultAccountManager) SyncPeer(sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, error) {
|
||||
peer, err := account.FindPeerByPubKey(sync.WireGuardPubKey)
|
||||
if err != nil {
|
||||
return nil, nil, status.Errorf(status.Unauthenticated, "peer is not registered")
|
||||
return nil, nil, status.NewPeerNotRegisteredError()
|
||||
}
|
||||
|
||||
err = checkIfPeerOwnerIsBlocked(peer, account)
|
||||
@@ -515,11 +524,15 @@ func (am *DefaultAccountManager) SyncPeer(sync PeerSync, account *Account) (*nbp
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if peerLoginExpired(peer, account) {
|
||||
if peerLoginExpired(peer, account.Settings) {
|
||||
return nil, nil, status.Errorf(status.PermissionDenied, "peer login has expired, please log in once more")
|
||||
}
|
||||
|
||||
peerNotValid, isStatusChanged := am.integratedPeerValidator.IsNotValidPeer(account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra)
|
||||
peerNotValid, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if peerNotValid {
|
||||
emptyMap := &NetworkMap{
|
||||
Network: account.Network.Copy(),
|
||||
@@ -541,7 +554,7 @@ func (am *DefaultAccountManager) SyncPeer(sync PeerSync, account *Account) (*nbp
|
||||
// LoginPeer logs in or registers a peer.
|
||||
// If peer doesn't exist the function checks whether a setup key or a user is present and registers a new peer if so.
|
||||
func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *NetworkMap, error) {
|
||||
account, err := am.Store.GetAccountByPeerPubKey(login.WireGuardPubKey)
|
||||
accountID, err := am.Store.GetAccountIDByPeerPubKey(login.WireGuardPubKey)
|
||||
if err != nil {
|
||||
if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound {
|
||||
// we couldn't find this peer by its public key which can mean that peer hasn't been registered yet.
|
||||
@@ -570,19 +583,59 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *Netw
|
||||
return nil, nil, status.Errorf(status.Internal, "failed while logging in peer")
|
||||
}
|
||||
|
||||
// we found the peer, and we follow a normal login flow
|
||||
unlock := am.Store.AcquireAccountWriteLock(account.Id)
|
||||
defer unlock()
|
||||
peer, err := am.Store.GetPeerByPeerPubKey(login.WireGuardPubKey)
|
||||
if err != nil {
|
||||
return nil, nil, status.NewPeerNotRegisteredError()
|
||||
}
|
||||
|
||||
accSettings, err := am.Store.GetAccountSettings(accountID)
|
||||
if err != nil {
|
||||
return nil, nil, status.Errorf(status.Internal, "failed to get account settings: %s", err)
|
||||
}
|
||||
|
||||
var isWriteLock bool
|
||||
|
||||
// duplicated logic from after the lock to have an early exit
|
||||
expired := peerLoginExpired(peer, accSettings)
|
||||
switch {
|
||||
case expired:
|
||||
if err := checkAuth(login.UserID, peer); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
isWriteLock = true
|
||||
log.Debugf("peer login expired, acquiring write lock")
|
||||
|
||||
case peer.UpdateMetaIfNew(login.Meta):
|
||||
isWriteLock = true
|
||||
log.Debugf("peer changed meta, acquiring write lock")
|
||||
|
||||
default:
|
||||
isWriteLock = false
|
||||
log.Debugf("peer meta is the same, acquiring read lock")
|
||||
}
|
||||
|
||||
var unlock func()
|
||||
|
||||
if isWriteLock {
|
||||
unlock = am.Store.AcquireAccountWriteLock(accountID)
|
||||
} else {
|
||||
unlock = am.Store.AcquireAccountReadLock(accountID)
|
||||
}
|
||||
defer func() {
|
||||
if unlock != nil {
|
||||
unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
// fetch the account from the store once more after acquiring lock to avoid concurrent updates inconsistencies
|
||||
account, err = am.Store.GetAccount(account.Id)
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
peer, err := account.FindPeerByPubKey(login.WireGuardPubKey)
|
||||
peer, err = account.FindPeerByPubKey(login.WireGuardPubKey)
|
||||
if err != nil {
|
||||
return nil, nil, status.Errorf(status.Unauthenticated, "peer is not registered")
|
||||
return nil, nil, status.NewPeerNotRegisteredError()
|
||||
}
|
||||
|
||||
err = checkIfPeerOwnerIsBlocked(peer, account)
|
||||
@@ -593,7 +646,7 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *Netw
|
||||
// this flag prevents unnecessary calls to the persistent store.
|
||||
shouldStoreAccount := false
|
||||
updateRemotePeers := false
|
||||
if peerLoginExpired(peer, account) {
|
||||
if peerLoginExpired(peer, account.Settings) {
|
||||
err = checkAuth(login.UserID, peer)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
@@ -614,7 +667,10 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *Netw
|
||||
am.StoreEvent(login.UserID, peer.ID, account.Id, activity.UserLoggedInPeer, peer.EventMeta(am.GetDNSDomain()))
|
||||
}
|
||||
|
||||
isRequiresApproval, isStatusChanged := am.integratedPeerValidator.IsNotValidPeer(account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra)
|
||||
isRequiresApproval, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
peer, updated := updatePeerMeta(peer, login.Meta, account)
|
||||
if updated {
|
||||
shouldStoreAccount = true
|
||||
@@ -626,11 +682,17 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *Netw
|
||||
}
|
||||
|
||||
if shouldStoreAccount {
|
||||
if !isWriteLock {
|
||||
log.Errorf("account %s should be stored but is not write locked", accountID)
|
||||
return nil, nil, status.Errorf(status.Internal, "account should be stored but is not write locked")
|
||||
}
|
||||
err = am.Store.SaveAccount(account)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
unlock()
|
||||
unlock = nil
|
||||
|
||||
if updateRemotePeers || isStatusChanged {
|
||||
am.updateAccountPeers(account)
|
||||
@@ -676,9 +738,9 @@ func checkAuth(loginUserID string, peer *nbpeer.Peer) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func peerLoginExpired(peer *nbpeer.Peer, account *Account) bool {
|
||||
expired, expiresIn := peer.LoginExpired(account.Settings.PeerLoginExpiration)
|
||||
expired = account.Settings.PeerLoginExpirationEnabled && expired
|
||||
func peerLoginExpired(peer *nbpeer.Peer, settings *Settings) bool {
|
||||
expired, expiresIn := peer.LoginExpired(settings.PeerLoginExpiration)
|
||||
expired = settings.PeerLoginExpirationEnabled && expired
|
||||
if expired || peer.Status.LoginExpired {
|
||||
log.Debugf("peer's %s login expired %v ago", peer.ID, expiresIn)
|
||||
return true
|
||||
|
||||
@@ -216,7 +216,9 @@ func (p *Peer) FQDN(dnsDomain string) string {
|
||||
|
||||
// EventMeta returns activity event meta related to the peer
|
||||
func (p *Peer) EventMeta(dnsDomain string) map[string]any {
|
||||
return map[string]any{"name": p.Name, "fqdn": p.FQDN(dnsDomain), "ip": p.IP, "created_at": p.CreatedAt}
|
||||
return map[string]any{"name": p.Name, "fqdn": p.FQDN(dnsDomain), "ip": p.IP, "created_at": p.CreatedAt,
|
||||
"location_city_name": p.Location.CityName, "location_country_code": p.Location.CountryCode,
|
||||
"location_geo_name_id": p.Location.GeoNameID, "location_connection_ip": p.Location.ConnectionIP}
|
||||
}
|
||||
|
||||
// Copy PeerStatus
|
||||
|
||||
@@ -148,7 +148,7 @@ type Policy struct {
|
||||
Enabled bool
|
||||
|
||||
// Rules of the policy
|
||||
Rules []*PolicyRule `gorm:"foreignKey:PolicyID;references:id"`
|
||||
Rules []*PolicyRule `gorm:"foreignKey:PolicyID;references:id;constraint:OnDelete:CASCADE;"`
|
||||
|
||||
// SourcePostureChecks are ID references to Posture checks for policy source groups
|
||||
SourcePostureChecks []string `gorm:"serializer:json"`
|
||||
|
||||
@@ -5,8 +5,11 @@ import (
|
||||
"net/netip"
|
||||
|
||||
"github.com/hashicorp/go-version"
|
||||
"github.com/rs/xid"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -136,6 +139,96 @@ func (pc *Checks) GetChecks() []Check {
|
||||
return checks
|
||||
}
|
||||
|
||||
func NewChecksFromAPIPostureCheck(source api.PostureCheck) (*Checks, error) {
|
||||
description := ""
|
||||
if source.Description != nil {
|
||||
description = *source.Description
|
||||
}
|
||||
|
||||
return buildPostureCheck(source.Id, source.Name, description, source.Checks)
|
||||
}
|
||||
|
||||
func NewChecksFromAPIPostureCheckUpdate(source api.PostureCheckUpdate, postureChecksID string) (*Checks, error) {
|
||||
return buildPostureCheck(postureChecksID, source.Name, source.Description, *source.Checks)
|
||||
}
|
||||
|
||||
func buildPostureCheck(postureChecksID string, name string, description string, checks api.Checks) (*Checks, error) {
|
||||
if postureChecksID == "" {
|
||||
postureChecksID = xid.New().String()
|
||||
}
|
||||
|
||||
postureChecks := Checks{
|
||||
ID: postureChecksID,
|
||||
Name: name,
|
||||
Description: description,
|
||||
}
|
||||
|
||||
if nbVersionCheck := checks.NbVersionCheck; nbVersionCheck != nil {
|
||||
postureChecks.Checks.NBVersionCheck = &NBVersionCheck{
|
||||
MinVersion: nbVersionCheck.MinVersion,
|
||||
}
|
||||
}
|
||||
|
||||
if osVersionCheck := checks.OsVersionCheck; osVersionCheck != nil {
|
||||
postureChecks.Checks.OSVersionCheck = &OSVersionCheck{
|
||||
Android: (*MinVersionCheck)(osVersionCheck.Android),
|
||||
Darwin: (*MinVersionCheck)(osVersionCheck.Darwin),
|
||||
Ios: (*MinVersionCheck)(osVersionCheck.Ios),
|
||||
Linux: (*MinKernelVersionCheck)(osVersionCheck.Linux),
|
||||
Windows: (*MinKernelVersionCheck)(osVersionCheck.Windows),
|
||||
}
|
||||
}
|
||||
|
||||
if geoLocationCheck := checks.GeoLocationCheck; geoLocationCheck != nil {
|
||||
postureChecks.Checks.GeoLocationCheck = toPostureGeoLocationCheck(geoLocationCheck)
|
||||
}
|
||||
|
||||
var err error
|
||||
if peerNetworkRangeCheck := checks.PeerNetworkRangeCheck; peerNetworkRangeCheck != nil {
|
||||
postureChecks.Checks.PeerNetworkRangeCheck, err = toPeerNetworkRangeCheck(peerNetworkRangeCheck)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(status.InvalidArgument, "invalid network prefix")
|
||||
}
|
||||
}
|
||||
|
||||
return &postureChecks, nil
|
||||
}
|
||||
|
||||
func (pc *Checks) ToAPIResponse() *api.PostureCheck {
|
||||
var checks api.Checks
|
||||
|
||||
if pc.Checks.NBVersionCheck != nil {
|
||||
checks.NbVersionCheck = &api.NBVersionCheck{
|
||||
MinVersion: pc.Checks.NBVersionCheck.MinVersion,
|
||||
}
|
||||
}
|
||||
|
||||
if pc.Checks.OSVersionCheck != nil {
|
||||
checks.OsVersionCheck = &api.OSVersionCheck{
|
||||
Android: (*api.MinVersionCheck)(pc.Checks.OSVersionCheck.Android),
|
||||
Darwin: (*api.MinVersionCheck)(pc.Checks.OSVersionCheck.Darwin),
|
||||
Ios: (*api.MinVersionCheck)(pc.Checks.OSVersionCheck.Ios),
|
||||
Linux: (*api.MinKernelVersionCheck)(pc.Checks.OSVersionCheck.Linux),
|
||||
Windows: (*api.MinKernelVersionCheck)(pc.Checks.OSVersionCheck.Windows),
|
||||
}
|
||||
}
|
||||
|
||||
if pc.Checks.GeoLocationCheck != nil {
|
||||
checks.GeoLocationCheck = toGeoLocationCheckResponse(pc.Checks.GeoLocationCheck)
|
||||
}
|
||||
|
||||
if pc.Checks.PeerNetworkRangeCheck != nil {
|
||||
checks.PeerNetworkRangeCheck = toPeerNetworkRangeCheckResponse(pc.Checks.PeerNetworkRangeCheck)
|
||||
}
|
||||
|
||||
return &api.PostureCheck{
|
||||
Id: pc.ID,
|
||||
Name: pc.Name,
|
||||
Description: &pc.Description,
|
||||
Checks: checks,
|
||||
}
|
||||
}
|
||||
|
||||
func (pc *Checks) Validate() error {
|
||||
if check := pc.Checks.NBVersionCheck; check != nil {
|
||||
if !isVersionValid(check.MinVersion) {
|
||||
@@ -192,3 +285,70 @@ func isVersionValid(ver string) bool {
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func toGeoLocationCheckResponse(geoLocationCheck *GeoLocationCheck) *api.GeoLocationCheck {
|
||||
locations := make([]api.Location, 0, len(geoLocationCheck.Locations))
|
||||
for _, loc := range geoLocationCheck.Locations {
|
||||
l := loc // make G601 happy
|
||||
var cityName *string
|
||||
if loc.CityName != "" {
|
||||
cityName = &l.CityName
|
||||
}
|
||||
locations = append(locations, api.Location{
|
||||
CityName: cityName,
|
||||
CountryCode: loc.CountryCode,
|
||||
})
|
||||
}
|
||||
|
||||
return &api.GeoLocationCheck{
|
||||
Action: api.GeoLocationCheckAction(geoLocationCheck.Action),
|
||||
Locations: locations,
|
||||
}
|
||||
}
|
||||
|
||||
func toPostureGeoLocationCheck(apiGeoLocationCheck *api.GeoLocationCheck) *GeoLocationCheck {
|
||||
locations := make([]Location, 0, len(apiGeoLocationCheck.Locations))
|
||||
for _, loc := range apiGeoLocationCheck.Locations {
|
||||
cityName := ""
|
||||
if loc.CityName != nil {
|
||||
cityName = *loc.CityName
|
||||
}
|
||||
locations = append(locations, Location{
|
||||
CountryCode: loc.CountryCode,
|
||||
CityName: cityName,
|
||||
})
|
||||
}
|
||||
|
||||
return &GeoLocationCheck{
|
||||
Action: string(apiGeoLocationCheck.Action),
|
||||
Locations: locations,
|
||||
}
|
||||
}
|
||||
|
||||
func toPeerNetworkRangeCheckResponse(check *PeerNetworkRangeCheck) *api.PeerNetworkRangeCheck {
|
||||
netPrefixes := make([]string, 0, len(check.Ranges))
|
||||
for _, netPrefix := range check.Ranges {
|
||||
netPrefixes = append(netPrefixes, netPrefix.String())
|
||||
}
|
||||
|
||||
return &api.PeerNetworkRangeCheck{
|
||||
Ranges: netPrefixes,
|
||||
Action: api.PeerNetworkRangeCheckAction(check.Action),
|
||||
}
|
||||
}
|
||||
|
||||
func toPeerNetworkRangeCheck(check *api.PeerNetworkRangeCheck) (*PeerNetworkRangeCheck, error) {
|
||||
prefixes := make([]netip.Prefix, 0)
|
||||
for _, prefix := range check.Ranges {
|
||||
parsedPrefix, err := netip.ParsePrefix(prefix)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
prefixes = append(prefixes, parsedPrefix)
|
||||
}
|
||||
|
||||
return &PeerNetworkRangeCheck{
|
||||
Ranges: prefixes,
|
||||
Action: string(check.Action),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -3,8 +3,9 @@ package server
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
)
|
||||
|
||||
const (
|
||||
|
||||
@@ -1021,10 +1021,11 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||
func createRouterStore(t *testing.T) (Store, error) {
|
||||
t.Helper()
|
||||
dataDir := t.TempDir()
|
||||
store, err := NewStoreFromJson(dataDir, nil)
|
||||
store, cleanUp, err := NewTestStoreFromJson(dataDir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t.Cleanup(cleanUp)
|
||||
|
||||
return store, nil
|
||||
}
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
@@ -12,6 +11,7 @@ import (
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
@@ -20,7 +20,6 @@ import (
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
"github.com/netbirdio/netbird/management/server/migration"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
@@ -28,14 +27,14 @@ import (
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
// SqliteStore represents an account storage backed by a Sqlite DB persisted to disk
|
||||
type SqliteStore struct {
|
||||
// SqlStore represents an account storage backed by a Sql DB persisted to disk
|
||||
type SqlStore struct {
|
||||
db *gorm.DB
|
||||
storeFile string
|
||||
accountLocks sync.Map
|
||||
globalAccountLock sync.Mutex
|
||||
metrics telemetry.AppMetrics
|
||||
installationPK int
|
||||
storeEngine StoreEngine
|
||||
}
|
||||
|
||||
type installation struct {
|
||||
@@ -45,24 +44,8 @@ type installation struct {
|
||||
|
||||
type migrationFunc func(*gorm.DB) error
|
||||
|
||||
// NewSqliteStore restores a store from the file located in the datadir
|
||||
func NewSqliteStore(dataDir string, metrics telemetry.AppMetrics) (*SqliteStore, error) {
|
||||
storeStr := "store.db?cache=shared"
|
||||
if runtime.GOOS == "windows" {
|
||||
// Vo avoid `The process cannot access the file because it is being used by another process` on Windows
|
||||
storeStr = "store.db"
|
||||
}
|
||||
|
||||
file := filepath.Join(dataDir, storeStr)
|
||||
db, err := gorm.Open(sqlite.Open(file), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
CreateBatchSize: 400,
|
||||
PrepareStmt: true,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// NewSqlStore creates a new SqlStore instance.
|
||||
func NewSqlStore(db *gorm.DB, storeEngine StoreEngine, metrics telemetry.AppMetrics) (*SqlStore, error) {
|
||||
sql, err := db.DB()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -82,33 +65,11 @@ func NewSqliteStore(dataDir string, metrics telemetry.AppMetrics) (*SqliteStore,
|
||||
return nil, fmt.Errorf("auto migrate: %w", err)
|
||||
}
|
||||
|
||||
return &SqliteStore{db: db, storeFile: file, metrics: metrics, installationPK: 1}, nil
|
||||
}
|
||||
|
||||
// NewSqliteStoreFromFileStore restores a store from FileStore and stores SQLite DB in the file located in datadir
|
||||
func NewSqliteStoreFromFileStore(filestore *FileStore, dataDir string, metrics telemetry.AppMetrics) (*SqliteStore, error) {
|
||||
store, err := NewSqliteStore(dataDir, metrics)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = store.SaveInstallationID(filestore.InstallationID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, account := range filestore.GetAllAccounts() {
|
||||
err := store.SaveAccount(account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return store, nil
|
||||
return &SqlStore{db: db, storeEngine: storeEngine, metrics: metrics, installationPK: 1}, nil
|
||||
}
|
||||
|
||||
// AcquireGlobalLock acquires global lock across all the accounts and returns a function that releases the lock
|
||||
func (s *SqliteStore) AcquireGlobalLock() (unlock func()) {
|
||||
func (s *SqlStore) AcquireGlobalLock() (unlock func()) {
|
||||
log.Tracef("acquiring global lock")
|
||||
start := time.Now()
|
||||
s.globalAccountLock.Lock()
|
||||
@@ -127,7 +88,7 @@ func (s *SqliteStore) AcquireGlobalLock() (unlock func()) {
|
||||
return unlock
|
||||
}
|
||||
|
||||
func (s *SqliteStore) AcquireAccountWriteLock(accountID string) (unlock func()) {
|
||||
func (s *SqlStore) AcquireAccountWriteLock(accountID string) (unlock func()) {
|
||||
log.Tracef("acquiring write lock for account %s", accountID)
|
||||
|
||||
start := time.Now()
|
||||
@@ -143,7 +104,7 @@ func (s *SqliteStore) AcquireAccountWriteLock(accountID string) (unlock func())
|
||||
return unlock
|
||||
}
|
||||
|
||||
func (s *SqliteStore) AcquireAccountReadLock(accountID string) (unlock func()) {
|
||||
func (s *SqlStore) AcquireAccountReadLock(accountID string) (unlock func()) {
|
||||
log.Tracef("acquiring read lock for account %s", accountID)
|
||||
|
||||
start := time.Now()
|
||||
@@ -159,7 +120,7 @@ func (s *SqliteStore) AcquireAccountReadLock(accountID string) (unlock func()) {
|
||||
return unlock
|
||||
}
|
||||
|
||||
func (s *SqliteStore) SaveAccount(account *Account) error {
|
||||
func (s *SqlStore) SaveAccount(account *Account) error {
|
||||
start := time.Now()
|
||||
|
||||
for _, key := range account.SetupKeys {
|
||||
@@ -225,12 +186,12 @@ func (s *SqliteStore) SaveAccount(account *Account) error {
|
||||
if s.metrics != nil {
|
||||
s.metrics.StoreMetrics().CountPersistenceDuration(took)
|
||||
}
|
||||
log.Debugf("took %d ms to persist an account to the SQLite", took.Milliseconds())
|
||||
log.Debugf("took %d ms to persist an account to the store", took.Milliseconds())
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *SqliteStore) DeleteAccount(account *Account) error {
|
||||
func (s *SqlStore) DeleteAccount(account *Account) error {
|
||||
start := time.Now()
|
||||
|
||||
err := s.db.Transaction(func(tx *gorm.DB) error {
|
||||
@@ -256,19 +217,19 @@ func (s *SqliteStore) DeleteAccount(account *Account) error {
|
||||
if s.metrics != nil {
|
||||
s.metrics.StoreMetrics().CountPersistenceDuration(took)
|
||||
}
|
||||
log.Debugf("took %d ms to delete an account to the SQLite", took.Milliseconds())
|
||||
log.Debugf("took %d ms to delete an account to the store", took.Milliseconds())
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *SqliteStore) SaveInstallationID(ID string) error {
|
||||
func (s *SqlStore) SaveInstallationID(ID string) error {
|
||||
installation := installation{InstallationIDValue: ID}
|
||||
installation.ID = uint(s.installationPK)
|
||||
|
||||
return s.db.Clauses(clause.OnConflict{UpdateAll: true}).Create(&installation).Error
|
||||
}
|
||||
|
||||
func (s *SqliteStore) GetInstallationID() string {
|
||||
func (s *SqlStore) GetInstallationID() string {
|
||||
var installation installation
|
||||
|
||||
if result := s.db.First(&installation, "id = ?", s.installationPK); result.Error != nil {
|
||||
@@ -278,7 +239,7 @@ func (s *SqliteStore) GetInstallationID() string {
|
||||
return installation.InstallationIDValue
|
||||
}
|
||||
|
||||
func (s *SqliteStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error {
|
||||
func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error {
|
||||
var peerCopy nbpeer.Peer
|
||||
peerCopy.Status = &peerStatus
|
||||
result := s.db.Model(&nbpeer.Peer{}).
|
||||
@@ -296,7 +257,7 @@ func (s *SqliteStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqliteStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.Peer) error {
|
||||
func (s *SqlStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.Peer) error {
|
||||
// To maintain data integrity, we create a copy of the peer's location to prevent unintended updates to other fields.
|
||||
var peerCopy nbpeer.Peer
|
||||
// Since the location field has been migrated to JSON serialization,
|
||||
@@ -318,17 +279,17 @@ func (s *SqliteStore) SavePeerLocation(accountID string, peerWithLocation *nbpee
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteHashedPAT2TokenIDIndex is noop in Sqlite
|
||||
func (s *SqliteStore) DeleteHashedPAT2TokenIDIndex(hashedToken string) error {
|
||||
// DeleteHashedPAT2TokenIDIndex is noop in SqlStore
|
||||
func (s *SqlStore) DeleteHashedPAT2TokenIDIndex(hashedToken string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteTokenID2UserIDIndex is noop in Sqlite
|
||||
func (s *SqliteStore) DeleteTokenID2UserIDIndex(tokenID string) error {
|
||||
// DeleteTokenID2UserIDIndex is noop in SqlStore
|
||||
func (s *SqlStore) DeleteTokenID2UserIDIndex(tokenID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqliteStore) GetAccountByPrivateDomain(domain string) (*Account, error) {
|
||||
func (s *SqlStore) GetAccountByPrivateDomain(domain string) (*Account, error) {
|
||||
var account Account
|
||||
|
||||
result := s.db.First(&account, "domain = ? and is_domain_primary_account = ? and domain_category = ?",
|
||||
@@ -345,7 +306,7 @@ func (s *SqliteStore) GetAccountByPrivateDomain(domain string) (*Account, error)
|
||||
return s.GetAccount(account.Id)
|
||||
}
|
||||
|
||||
func (s *SqliteStore) GetAccountBySetupKey(setupKey string) (*Account, error) {
|
||||
func (s *SqlStore) GetAccountBySetupKey(setupKey string) (*Account, error) {
|
||||
var key SetupKey
|
||||
result := s.db.Select("account_id").First(&key, "key = ?", strings.ToUpper(setupKey))
|
||||
if result.Error != nil {
|
||||
@@ -363,7 +324,7 @@ func (s *SqliteStore) GetAccountBySetupKey(setupKey string) (*Account, error) {
|
||||
return s.GetAccount(key.AccountID)
|
||||
}
|
||||
|
||||
func (s *SqliteStore) GetTokenIDByHashedToken(hashedToken string) (string, error) {
|
||||
func (s *SqlStore) GetTokenIDByHashedToken(hashedToken string) (string, error) {
|
||||
var token PersonalAccessToken
|
||||
result := s.db.First(&token, "hashed_token = ?", hashedToken)
|
||||
if result.Error != nil {
|
||||
@@ -377,7 +338,7 @@ func (s *SqliteStore) GetTokenIDByHashedToken(hashedToken string) (string, error
|
||||
return token.ID, nil
|
||||
}
|
||||
|
||||
func (s *SqliteStore) GetUserByTokenID(tokenID string) (*User, error) {
|
||||
func (s *SqlStore) GetUserByTokenID(tokenID string) (*User, error) {
|
||||
var token PersonalAccessToken
|
||||
result := s.db.First(&token, "id = ?", tokenID)
|
||||
if result.Error != nil {
|
||||
@@ -406,7 +367,7 @@ func (s *SqliteStore) GetUserByTokenID(tokenID string) (*User, error) {
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (s *SqliteStore) GetAllAccounts() (all []*Account) {
|
||||
func (s *SqlStore) GetAllAccounts() (all []*Account) {
|
||||
var accounts []Account
|
||||
result := s.db.Find(&accounts)
|
||||
if result.Error != nil {
|
||||
@@ -422,7 +383,7 @@ func (s *SqliteStore) GetAllAccounts() (all []*Account) {
|
||||
return all
|
||||
}
|
||||
|
||||
func (s *SqliteStore) GetAccount(accountID string) (*Account, error) {
|
||||
func (s *SqlStore) GetAccount(accountID string) (*Account, error) {
|
||||
|
||||
var account Account
|
||||
result := s.db.Model(&account).
|
||||
@@ -430,7 +391,7 @@ func (s *SqliteStore) GetAccount(accountID string) (*Account, error) {
|
||||
Preload(clause.Associations).
|
||||
First(&account, "id = ?", accountID)
|
||||
if result.Error != nil {
|
||||
log.Errorf("error when getting account from the store: %s", result.Error)
|
||||
log.Errorf("error when getting account %s from the store: %s", accountID, result.Error)
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "account not found")
|
||||
}
|
||||
@@ -490,14 +451,13 @@ func (s *SqliteStore) GetAccount(accountID string) (*Account, error) {
|
||||
return &account, nil
|
||||
}
|
||||
|
||||
func (s *SqliteStore) GetAccountByUser(userID string) (*Account, error) {
|
||||
func (s *SqlStore) GetAccountByUser(userID string) (*Account, error) {
|
||||
var user User
|
||||
result := s.db.Select("account_id").First(&user, "id = ?", userID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||
}
|
||||
log.Errorf("error when getting user from the store: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "issue getting account from store")
|
||||
}
|
||||
|
||||
@@ -508,7 +468,7 @@ func (s *SqliteStore) GetAccountByUser(userID string) (*Account, error) {
|
||||
return s.GetAccount(user.AccountID)
|
||||
}
|
||||
|
||||
func (s *SqliteStore) GetAccountByPeerID(peerID string) (*Account, error) {
|
||||
func (s *SqlStore) GetAccountByPeerID(peerID string) (*Account, error) {
|
||||
var peer nbpeer.Peer
|
||||
result := s.db.Select("account_id").First(&peer, "id = ?", peerID)
|
||||
if result.Error != nil {
|
||||
@@ -526,7 +486,7 @@ func (s *SqliteStore) GetAccountByPeerID(peerID string) (*Account, error) {
|
||||
return s.GetAccount(peer.AccountID)
|
||||
}
|
||||
|
||||
func (s *SqliteStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) {
|
||||
func (s *SqlStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) {
|
||||
var peer nbpeer.Peer
|
||||
|
||||
result := s.db.Select("account_id").First(&peer, "key = ?", peerKey)
|
||||
@@ -545,7 +505,7 @@ func (s *SqliteStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) {
|
||||
return s.GetAccount(peer.AccountID)
|
||||
}
|
||||
|
||||
func (s *SqliteStore) GetAccountIDByPeerPubKey(peerKey string) (string, error) {
|
||||
func (s *SqlStore) GetAccountIDByPeerPubKey(peerKey string) (string, error) {
|
||||
var peer nbpeer.Peer
|
||||
var accountID string
|
||||
result := s.db.Model(&peer).Select("account_id").Where("key = ?", peerKey).First(&accountID)
|
||||
@@ -560,8 +520,63 @@ func (s *SqliteStore) GetAccountIDByPeerPubKey(peerKey string) (string, error) {
|
||||
return accountID, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) {
|
||||
var user User
|
||||
var accountID string
|
||||
result := s.db.Model(&user).Select("account_id").Where("id = ?", userID).First(&accountID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||
}
|
||||
return "", status.Errorf(status.Internal, "issue getting account from store")
|
||||
}
|
||||
|
||||
return accountID, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountIDBySetupKey(setupKey string) (string, error) {
|
||||
var key SetupKey
|
||||
var accountID string
|
||||
result := s.db.Model(&key).Select("account_id").Where("key = ?", strings.ToUpper(setupKey)).First(&accountID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||
}
|
||||
log.Errorf("error when getting setup key from the store: %s", result.Error)
|
||||
return "", status.Errorf(status.Internal, "issue getting setup key from store")
|
||||
}
|
||||
|
||||
return accountID, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetPeerByPeerPubKey(peerKey string) (*nbpeer.Peer, error) {
|
||||
var peer nbpeer.Peer
|
||||
result := s.db.First(&peer, "key = ?", peerKey)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "peer not found")
|
||||
}
|
||||
log.Errorf("error when getting peer from the store: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "issue getting peer from store")
|
||||
}
|
||||
|
||||
return &peer, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountSettings(accountID string) (*Settings, error) {
|
||||
var accountSettings AccountSettings
|
||||
if err := s.db.Model(&Account{}).Where("id = ?", accountID).First(&accountSettings).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "settings not found")
|
||||
}
|
||||
log.Errorf("error when getting settings from the store: %s", err)
|
||||
return nil, status.Errorf(status.Internal, "issue getting settings from store")
|
||||
}
|
||||
return accountSettings.Settings, nil
|
||||
}
|
||||
|
||||
// SaveUserLastLogin stores the last login time for a user in DB.
|
||||
func (s *SqliteStore) SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error {
|
||||
func (s *SqlStore) SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error {
|
||||
var user User
|
||||
|
||||
result := s.db.First(&user, "account_id = ? and id = ?", accountID, userID)
|
||||
@@ -569,7 +584,6 @@ func (s *SqliteStore) SaveUserLastLogin(accountID, userID string, lastLogin time
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return status.Errorf(status.NotFound, "user %s not found", userID)
|
||||
}
|
||||
log.Errorf("error when getting user from the store: %s", result.Error)
|
||||
return status.Errorf(status.Internal, "issue getting user from store")
|
||||
}
|
||||
|
||||
@@ -578,8 +592,23 @@ func (s *SqliteStore) SaveUserLastLogin(accountID, userID string, lastLogin time
|
||||
return s.db.Save(user).Error
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) {
|
||||
definitionJSON, err := json.Marshal(checks)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var postureCheck posture.Checks
|
||||
err = s.db.Where("account_id = ? AND checks = ?", accountID, string(definitionJSON)).First(&postureCheck).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &postureCheck, nil
|
||||
}
|
||||
|
||||
// Close closes the underlying DB connection
|
||||
func (s *SqliteStore) Close() error {
|
||||
func (s *SqlStore) Close() error {
|
||||
sql, err := s.db.DB()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get db: %w", err)
|
||||
@@ -587,40 +616,85 @@ func (s *SqliteStore) Close() error {
|
||||
return sql.Close()
|
||||
}
|
||||
|
||||
// GetStoreEngine returns SqliteStoreEngine
|
||||
func (s *SqliteStore) GetStoreEngine() StoreEngine {
|
||||
return SqliteStoreEngine
|
||||
// GetStoreEngine returns underlying store engine
|
||||
func (s *SqlStore) GetStoreEngine() StoreEngine {
|
||||
return s.storeEngine
|
||||
}
|
||||
|
||||
// migrate migrates the SQLite database to the latest schema
|
||||
func migrate(db *gorm.DB) error {
|
||||
migrations := getMigrations()
|
||||
// NewSqliteStore creates a new SQLite store.
|
||||
func NewSqliteStore(dataDir string, metrics telemetry.AppMetrics) (*SqlStore, error) {
|
||||
storeStr := "store.db?cache=shared"
|
||||
if runtime.GOOS == "windows" {
|
||||
// Vo avoid `The process cannot access the file because it is being used by another process` on Windows
|
||||
storeStr = "store.db"
|
||||
}
|
||||
|
||||
for _, m := range migrations {
|
||||
if err := m(db); err != nil {
|
||||
return err
|
||||
file := filepath.Join(dataDir, storeStr)
|
||||
db, err := gorm.Open(sqlite.Open(file), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
CreateBatchSize: 400,
|
||||
PrepareStmt: true,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewSqlStore(db, SqliteStoreEngine, metrics)
|
||||
}
|
||||
|
||||
// NewPostgresqlStore creates a new Postgres store.
|
||||
func NewPostgresqlStore(dsn string, metrics telemetry.AppMetrics) (*SqlStore, error) {
|
||||
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
PrepareStmt: true,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewSqlStore(db, PostgresStoreEngine, metrics)
|
||||
}
|
||||
|
||||
// NewSqliteStoreFromFileStore restores a store from FileStore and stores SQLite DB in the file located in datadir.
|
||||
func NewSqliteStoreFromFileStore(fileStore *FileStore, dataDir string, metrics telemetry.AppMetrics) (*SqlStore, error) {
|
||||
store, err := NewSqliteStore(dataDir, metrics)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = store.SaveInstallationID(fileStore.InstallationID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, account := range fileStore.GetAllAccounts() {
|
||||
err := store.SaveAccount(account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
return store, nil
|
||||
}
|
||||
|
||||
func getMigrations() []migrationFunc {
|
||||
return []migrationFunc{
|
||||
func(db *gorm.DB) error {
|
||||
return migration.MigrateFieldFromGobToJSON[Account, net.IPNet](db, "network_net")
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migration.MigrateFieldFromGobToJSON[route.Route, netip.Prefix](db, "network")
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migration.MigrateFieldFromGobToJSON[route.Route, []string](db, "peer_groups")
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](db, "location_connection_ip", "")
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](db, "ip", "idx_peers_account_id_ip")
|
||||
},
|
||||
// NewPostgresqlStoreFromFileStore restores a store from FileStore and stores Postgres DB.
|
||||
func NewPostgresqlStoreFromFileStore(fileStore *FileStore, dsn string, metrics telemetry.AppMetrics) (*SqlStore, error) {
|
||||
store, err := NewPostgresqlStore(dsn, metrics)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = store.SaveInstallationID(fileStore.InstallationID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, account := range fileStore.GetAllAccounts() {
|
||||
err := store.SaveAccount(account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return store, nil
|
||||
}
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
@@ -12,6 +13,7 @@ import (
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
"github.com/netbirdio/netbird/management/server/testutil"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -569,7 +571,7 @@ func TestMigrate(t *testing.T) {
|
||||
require.NoError(t, err, "Migration should not fail on migrated db")
|
||||
}
|
||||
|
||||
func newSqliteStore(t *testing.T) *SqliteStore {
|
||||
func newSqliteStore(t *testing.T) *SqlStore {
|
||||
t.Helper()
|
||||
|
||||
store, err := NewSqliteStore(t.TempDir(), nil)
|
||||
@@ -579,7 +581,7 @@ func newSqliteStore(t *testing.T) *SqliteStore {
|
||||
return store
|
||||
}
|
||||
|
||||
func newSqliteStoreFromFile(t *testing.T, filename string) *SqliteStore {
|
||||
func newSqliteStoreFromFile(t *testing.T, filename string) *SqlStore {
|
||||
t.Helper()
|
||||
|
||||
storeDir := t.TempDir()
|
||||
@@ -613,3 +615,298 @@ func newAccount(store Store, id int) error {
|
||||
|
||||
return store.SaveAccount(account)
|
||||
}
|
||||
|
||||
func newPostgresqlStore(t *testing.T) *SqlStore {
|
||||
t.Helper()
|
||||
|
||||
cleanUp, err := testutil.CreatePGDB()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Cleanup(cleanUp)
|
||||
|
||||
postgresDsn, ok := os.LookupEnv(postgresDsnEnv)
|
||||
if !ok {
|
||||
t.Fatalf("could not initialize postgresql store: %s is not set", postgresDsnEnv)
|
||||
}
|
||||
|
||||
store, err := NewPostgresqlStore(postgresDsn, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("could not initialize postgresql store: %s", err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, store)
|
||||
|
||||
return store
|
||||
}
|
||||
|
||||
func newPostgresqlStoreFromFile(t *testing.T, filename string) *SqlStore {
|
||||
t.Helper()
|
||||
|
||||
storeDir := t.TempDir()
|
||||
err := util.CopyFileContents(filename, filepath.Join(storeDir, "store.json"))
|
||||
require.NoError(t, err)
|
||||
|
||||
fStore, err := NewFileStore(storeDir, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
cleanUp, err := testutil.CreatePGDB()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Cleanup(cleanUp)
|
||||
|
||||
postgresDsn, ok := os.LookupEnv(postgresDsnEnv)
|
||||
if !ok {
|
||||
t.Fatalf("could not initialize postgresql store: %s is not set", postgresDsnEnv)
|
||||
}
|
||||
|
||||
store, err := NewPostgresqlStoreFromFileStore(fStore, postgresDsn, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, store)
|
||||
|
||||
return store
|
||||
}
|
||||
|
||||
func TestPostgresql_NewStore(t *testing.T) {
|
||||
if runtime.GOOS != "linux" {
|
||||
t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS)
|
||||
}
|
||||
|
||||
store := newPostgresqlStore(t)
|
||||
|
||||
if len(store.GetAllAccounts()) != 0 {
|
||||
t.Errorf("expected to create a new empty Accounts map when creating a new FileStore")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostgresql_SaveAccount(t *testing.T) {
|
||||
if runtime.GOOS != "linux" {
|
||||
t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS)
|
||||
}
|
||||
|
||||
store := newPostgresqlStore(t)
|
||||
|
||||
account := newAccountWithId("account_id", "testuser", "")
|
||||
setupKey := GenerateDefaultSetupKey()
|
||||
account.SetupKeys[setupKey.Key] = setupKey
|
||||
account.Peers["testpeer"] = &nbpeer.Peer{
|
||||
Key: "peerkey",
|
||||
SetupKey: "peerkeysetupkey",
|
||||
IP: net.IP{127, 0, 0, 1},
|
||||
Meta: nbpeer.PeerSystemMeta{},
|
||||
Name: "peer name",
|
||||
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
|
||||
}
|
||||
|
||||
err := store.SaveAccount(account)
|
||||
require.NoError(t, err)
|
||||
|
||||
account2 := newAccountWithId("account_id2", "testuser2", "")
|
||||
setupKey = GenerateDefaultSetupKey()
|
||||
account2.SetupKeys[setupKey.Key] = setupKey
|
||||
account2.Peers["testpeer2"] = &nbpeer.Peer{
|
||||
Key: "peerkey2",
|
||||
SetupKey: "peerkeysetupkey2",
|
||||
IP: net.IP{127, 0, 0, 2},
|
||||
Meta: nbpeer.PeerSystemMeta{},
|
||||
Name: "peer name 2",
|
||||
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
|
||||
}
|
||||
|
||||
err = store.SaveAccount(account2)
|
||||
require.NoError(t, err)
|
||||
|
||||
if len(store.GetAllAccounts()) != 2 {
|
||||
t.Errorf("expecting 2 Accounts to be stored after SaveAccount()")
|
||||
}
|
||||
|
||||
a, err := store.GetAccount(account.Id)
|
||||
if a == nil {
|
||||
t.Errorf("expecting Account to be stored after SaveAccount(): %v", err)
|
||||
}
|
||||
|
||||
if a != nil && len(a.Policies) != 1 {
|
||||
t.Errorf("expecting Account to have one policy stored after SaveAccount(), got %d", len(a.Policies))
|
||||
}
|
||||
|
||||
if a != nil && len(a.Policies[0].Rules) != 1 {
|
||||
t.Errorf("expecting Account to have one policy rule stored after SaveAccount(), got %d", len(a.Policies[0].Rules))
|
||||
return
|
||||
}
|
||||
|
||||
if a, err := store.GetAccountByPeerPubKey("peerkey"); a == nil {
|
||||
t.Errorf("expecting PeerKeyID2AccountID index updated after SaveAccount(): %v", err)
|
||||
}
|
||||
|
||||
if a, err := store.GetAccountByUser("testuser"); a == nil {
|
||||
t.Errorf("expecting UserID2AccountID index updated after SaveAccount(): %v", err)
|
||||
}
|
||||
|
||||
if a, err := store.GetAccountByPeerID("testpeer"); a == nil {
|
||||
t.Errorf("expecting PeerID2AccountID index updated after SaveAccount(): %v", err)
|
||||
}
|
||||
|
||||
if a, err := store.GetAccountBySetupKey(setupKey.Key); a == nil {
|
||||
t.Errorf("expecting SetupKeyID2AccountID index updated after SaveAccount(): %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostgresql_DeleteAccount(t *testing.T) {
|
||||
if runtime.GOOS != "linux" {
|
||||
t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS)
|
||||
}
|
||||
|
||||
store := newPostgresqlStore(t)
|
||||
|
||||
testUserID := "testuser"
|
||||
user := NewAdminUser(testUserID)
|
||||
user.PATs = map[string]*PersonalAccessToken{"testtoken": {
|
||||
ID: "testtoken",
|
||||
Name: "test token",
|
||||
}}
|
||||
|
||||
account := newAccountWithId("account_id", testUserID, "")
|
||||
setupKey := GenerateDefaultSetupKey()
|
||||
account.SetupKeys[setupKey.Key] = setupKey
|
||||
account.Peers["testpeer"] = &nbpeer.Peer{
|
||||
Key: "peerkey",
|
||||
SetupKey: "peerkeysetupkey",
|
||||
IP: net.IP{127, 0, 0, 1},
|
||||
Meta: nbpeer.PeerSystemMeta{},
|
||||
Name: "peer name",
|
||||
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
|
||||
}
|
||||
account.Users[testUserID] = user
|
||||
|
||||
err := store.SaveAccount(account)
|
||||
require.NoError(t, err)
|
||||
|
||||
if len(store.GetAllAccounts()) != 1 {
|
||||
t.Errorf("expecting 1 Accounts to be stored after SaveAccount()")
|
||||
}
|
||||
|
||||
err = store.DeleteAccount(account)
|
||||
require.NoError(t, err)
|
||||
|
||||
if len(store.GetAllAccounts()) != 0 {
|
||||
t.Errorf("expecting 0 Accounts to be stored after DeleteAccount()")
|
||||
}
|
||||
|
||||
_, err = store.GetAccountByPeerPubKey("peerkey")
|
||||
require.Error(t, err, "expecting error after removing DeleteAccount when getting account by peer public key")
|
||||
|
||||
_, err = store.GetAccountByUser("testuser")
|
||||
require.Error(t, err, "expecting error after removing DeleteAccount when getting account by user")
|
||||
|
||||
_, err = store.GetAccountByPeerID("testpeer")
|
||||
require.Error(t, err, "expecting error after removing DeleteAccount when getting account by peer id")
|
||||
|
||||
_, err = store.GetAccountBySetupKey(setupKey.Key)
|
||||
require.Error(t, err, "expecting error after removing DeleteAccount when getting account by setup key")
|
||||
|
||||
_, err = store.GetAccount(account.Id)
|
||||
require.Error(t, err, "expecting error after removing DeleteAccount when getting account by id")
|
||||
|
||||
for _, policy := range account.Policies {
|
||||
var rules []*PolicyRule
|
||||
err = store.db.Model(&PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error
|
||||
require.NoError(t, err, "expecting no error after removing DeleteAccount when searching for policy rules")
|
||||
require.Len(t, rules, 0, "expecting no policy rules to be found after removing DeleteAccount")
|
||||
|
||||
}
|
||||
|
||||
for _, accountUser := range account.Users {
|
||||
var pats []*PersonalAccessToken
|
||||
err = store.db.Model(&PersonalAccessToken{}).Find(&pats, "user_id = ?", accountUser.Id).Error
|
||||
require.NoError(t, err, "expecting no error after removing DeleteAccount when searching for personal access token")
|
||||
require.Len(t, pats, 0, "expecting no personal access token to be found after removing DeleteAccount")
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestPostgresql_SavePeerStatus(t *testing.T) {
|
||||
if runtime.GOOS != "linux" {
|
||||
t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS)
|
||||
}
|
||||
|
||||
store := newPostgresqlStoreFromFile(t, "testdata/store.json")
|
||||
|
||||
account, err := store.GetAccount("bf1c8084-ba50-4ce7-9439-34653001fc3b")
|
||||
require.NoError(t, err)
|
||||
|
||||
// save status of non-existing peer
|
||||
newStatus := nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}
|
||||
err = store.SavePeerStatus(account.Id, "non-existing-peer", newStatus)
|
||||
assert.Error(t, err)
|
||||
|
||||
// save new status of existing peer
|
||||
account.Peers["testpeer"] = &nbpeer.Peer{
|
||||
Key: "peerkey",
|
||||
ID: "testpeer",
|
||||
SetupKey: "peerkeysetupkey",
|
||||
IP: net.IP{127, 0, 0, 1},
|
||||
Meta: nbpeer.PeerSystemMeta{},
|
||||
Name: "peer name",
|
||||
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now().UTC()},
|
||||
}
|
||||
|
||||
err = store.SaveAccount(account)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = store.SavePeerStatus(account.Id, "testpeer", newStatus)
|
||||
require.NoError(t, err)
|
||||
|
||||
account, err = store.GetAccount(account.Id)
|
||||
require.NoError(t, err)
|
||||
|
||||
actual := account.Peers["testpeer"].Status
|
||||
assert.Equal(t, newStatus.Connected, actual.Connected)
|
||||
}
|
||||
|
||||
func TestPostgresql_TestGetAccountByPrivateDomain(t *testing.T) {
|
||||
if runtime.GOOS != "linux" {
|
||||
t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS)
|
||||
}
|
||||
|
||||
store := newPostgresqlStoreFromFile(t, "testdata/store.json")
|
||||
|
||||
existingDomain := "test.com"
|
||||
|
||||
account, err := store.GetAccountByPrivateDomain(existingDomain)
|
||||
require.NoError(t, err, "should found account")
|
||||
require.Equal(t, existingDomain, account.Domain, "domains should match")
|
||||
|
||||
_, err = store.GetAccountByPrivateDomain("missing-domain.com")
|
||||
require.Error(t, err, "should return error on domain lookup")
|
||||
}
|
||||
|
||||
func TestPostgresql_GetTokenIDByHashedToken(t *testing.T) {
|
||||
if runtime.GOOS != "linux" {
|
||||
t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS)
|
||||
}
|
||||
|
||||
store := newPostgresqlStoreFromFile(t, "testdata/store.json")
|
||||
|
||||
hashed := "SoMeHaShEdToKeN"
|
||||
id := "9dj38s35-63fb-11ec-90d6-0242ac120003"
|
||||
|
||||
token, err := store.GetTokenIDByHashedToken(hashed)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, id, token)
|
||||
}
|
||||
|
||||
func TestPostgresql_GetUserByTokenID(t *testing.T) {
|
||||
if runtime.GOOS != "linux" {
|
||||
t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS)
|
||||
}
|
||||
|
||||
store := newPostgresqlStoreFromFile(t, "testdata/store.json")
|
||||
|
||||
id := "9dj38s35-63fb-11ec-90d6-0242ac120003"
|
||||
|
||||
user, err := store.GetUserByTokenID(id)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, id, user.PATs[id].ID)
|
||||
}
|
||||
@@ -75,3 +75,23 @@ func FromError(err error) (s *Error, ok bool) {
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// NewPeerNotFoundError creates a new Error with NotFound type for a missing peer
|
||||
func NewPeerNotFoundError(peerKey string) error {
|
||||
return Errorf(NotFound, "peer not found: %s", peerKey)
|
||||
}
|
||||
|
||||
// NewAccountNotFoundError creates a new Error with NotFound type for a missing account
|
||||
func NewAccountNotFoundError(accountKey string) error {
|
||||
return Errorf(NotFound, "account not found: %s", accountKey)
|
||||
}
|
||||
|
||||
// NewUserNotFoundError creates a new Error with NotFound type for a missing user
|
||||
func NewUserNotFoundError(userKey string) error {
|
||||
return Errorf(NotFound, "user not found: %s", userKey)
|
||||
}
|
||||
|
||||
// NewPeerNotRegisteredError creates a new Error with NotFound type for a missing peer
|
||||
func NewPeerNotRegisteredError() error {
|
||||
return Errorf(Unauthenticated, "peer is not registered")
|
||||
}
|
||||
|
||||
@@ -2,15 +2,22 @@ package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/migration"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/management/server/testutil"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
type Store interface {
|
||||
@@ -20,11 +27,14 @@ type Store interface {
|
||||
GetAccountByUser(userID string) (*Account, error)
|
||||
GetAccountByPeerPubKey(peerKey string) (*Account, error)
|
||||
GetAccountIDByPeerPubKey(peerKey string) (string, error)
|
||||
GetAccountIDByUserID(peerKey string) (string, error)
|
||||
GetAccountIDBySetupKey(peerKey string) (string, error)
|
||||
GetAccountByPeerID(peerID string) (*Account, error)
|
||||
GetAccountBySetupKey(setupKey string) (*Account, error) // todo use key hash later
|
||||
GetAccountByPrivateDomain(domain string) (*Account, error)
|
||||
GetTokenIDByHashedToken(secret string) (string, error)
|
||||
GetUserByTokenID(tokenID string) (*User, error)
|
||||
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
||||
SaveAccount(account *Account) error
|
||||
DeleteHashedPAT2TokenIDIndex(hashedToken string) error
|
||||
DeleteTokenID2UserIDIndex(tokenID string) error
|
||||
@@ -44,13 +54,18 @@ type Store interface {
|
||||
// GetStoreEngine should return StoreEngine of the current store implementation.
|
||||
// This is also a method of metrics.DataSource interface.
|
||||
GetStoreEngine() StoreEngine
|
||||
GetPeerByPeerPubKey(peerKey string) (*nbpeer.Peer, error)
|
||||
GetAccountSettings(accountID string) (*Settings, error)
|
||||
}
|
||||
|
||||
type StoreEngine string
|
||||
|
||||
const (
|
||||
FileStoreEngine StoreEngine = "jsonfile"
|
||||
SqliteStoreEngine StoreEngine = "sqlite"
|
||||
FileStoreEngine StoreEngine = "jsonfile"
|
||||
SqliteStoreEngine StoreEngine = "sqlite"
|
||||
PostgresStoreEngine StoreEngine = "postgres"
|
||||
|
||||
postgresDsnEnv = "NETBIRD_STORE_ENGINE_POSTGRES_DSN"
|
||||
)
|
||||
|
||||
func getStoreEngineFromEnv() StoreEngine {
|
||||
@@ -61,8 +76,7 @@ func getStoreEngineFromEnv() StoreEngine {
|
||||
}
|
||||
|
||||
value := StoreEngine(strings.ToLower(kind))
|
||||
|
||||
if value == FileStoreEngine || value == SqliteStoreEngine {
|
||||
if value == FileStoreEngine || value == SqliteStoreEngine || value == PostgresStoreEngine {
|
||||
return value
|
||||
}
|
||||
|
||||
@@ -94,18 +108,60 @@ func NewStore(kind StoreEngine, dataDir string, metrics telemetry.AppMetrics) (S
|
||||
case SqliteStoreEngine:
|
||||
log.Info("using SQLite store engine")
|
||||
return NewSqliteStore(dataDir, metrics)
|
||||
case PostgresStoreEngine:
|
||||
log.Info("using Postgres store engine")
|
||||
dsn, ok := os.LookupEnv(postgresDsnEnv)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("%s is not set", postgresDsnEnv)
|
||||
}
|
||||
return NewPostgresqlStore(dsn, metrics)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported kind of store %s", kind)
|
||||
}
|
||||
}
|
||||
|
||||
// NewStoreFromJson is only used in tests
|
||||
func NewStoreFromJson(dataDir string, metrics telemetry.AppMetrics) (Store, error) {
|
||||
// migrate migrates the SQLite database to the latest schema
|
||||
func migrate(db *gorm.DB) error {
|
||||
migrations := getMigrations()
|
||||
|
||||
for _, m := range migrations {
|
||||
if err := m(db); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getMigrations() []migrationFunc {
|
||||
return []migrationFunc{
|
||||
func(db *gorm.DB) error {
|
||||
return migration.MigrateFieldFromGobToJSON[Account, net.IPNet](db, "network_net")
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migration.MigrateFieldFromGobToJSON[route.Route, netip.Prefix](db, "network")
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migration.MigrateFieldFromGobToJSON[route.Route, []string](db, "peer_groups")
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](db, "location_connection_ip", "")
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](db, "ip", "idx_peers_account_id_ip")
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// NewTestStoreFromJson is only used in tests
|
||||
func NewTestStoreFromJson(dataDir string) (Store, func(), error) {
|
||||
fstore, err := NewFileStore(dataDir, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
cleanUp := func() {}
|
||||
|
||||
// if store engine is not set in the config we first try to evaluate NETBIRD_STORE_ENGINE
|
||||
kind := getStoreEngineFromEnv()
|
||||
if kind == "" {
|
||||
@@ -115,10 +171,34 @@ func NewStoreFromJson(dataDir string, metrics telemetry.AppMetrics) (Store, erro
|
||||
|
||||
switch kind {
|
||||
case FileStoreEngine:
|
||||
return fstore, nil
|
||||
return fstore, cleanUp, nil
|
||||
case SqliteStoreEngine:
|
||||
return NewSqliteStoreFromFileStore(fstore, dataDir, metrics)
|
||||
store, err := NewSqliteStoreFromFileStore(fstore, dataDir, nil)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return store, cleanUp, nil
|
||||
case PostgresStoreEngine:
|
||||
cleanUp, err = testutil.CreatePGDB()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
dsn, ok := os.LookupEnv(postgresDsnEnv)
|
||||
if !ok {
|
||||
return nil, nil, fmt.Errorf("%s is not set", postgresDsnEnv)
|
||||
}
|
||||
|
||||
store, err := NewPostgresqlStoreFromFileStore(fstore, dsn, nil)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return store, cleanUp, nil
|
||||
default:
|
||||
return NewSqliteStoreFromFileStore(fstore, dataDir, metrics)
|
||||
store, err := NewSqliteStoreFromFileStore(fstore, dataDir, nil)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return store, cleanUp, nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,50 +5,49 @@ import (
|
||||
"time"
|
||||
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
"go.opentelemetry.io/otel/metric/instrument"
|
||||
"go.opentelemetry.io/otel/metric/instrument/asyncint64"
|
||||
"go.opentelemetry.io/otel/metric/instrument/syncint64"
|
||||
)
|
||||
|
||||
// GRPCMetrics are gRPC server metrics
|
||||
type GRPCMetrics struct {
|
||||
meter metric.Meter
|
||||
syncRequestsCounter syncint64.Counter
|
||||
loginRequestsCounter syncint64.Counter
|
||||
getKeyRequestsCounter syncint64.Counter
|
||||
activeStreamsGauge asyncint64.Gauge
|
||||
syncRequestDuration syncint64.Histogram
|
||||
loginRequestDuration syncint64.Histogram
|
||||
channelQueueLength syncint64.Histogram
|
||||
syncRequestsCounter metric.Int64Counter
|
||||
loginRequestsCounter metric.Int64Counter
|
||||
getKeyRequestsCounter metric.Int64Counter
|
||||
activeStreamsGauge metric.Int64ObservableGauge
|
||||
syncRequestDuration metric.Int64Histogram
|
||||
loginRequestDuration metric.Int64Histogram
|
||||
channelQueueLength metric.Int64Histogram
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
// NewGRPCMetrics creates new GRPCMetrics struct and registers common metrics of the gRPC server
|
||||
func NewGRPCMetrics(ctx context.Context, meter metric.Meter) (*GRPCMetrics, error) {
|
||||
syncRequestsCounter, err := meter.SyncInt64().Counter("management.grpc.sync.request.counter", instrument.WithUnit("1"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
loginRequestsCounter, err := meter.SyncInt64().Counter("management.grpc.login.request.counter", instrument.WithUnit("1"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
getKeyRequestsCounter, err := meter.SyncInt64().Counter("management.grpc.key.request.counter", instrument.WithUnit("1"))
|
||||
syncRequestsCounter, err := meter.Int64Counter("management.grpc.sync.request.counter", metric.WithUnit("1"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
activeStreamsGauge, err := meter.AsyncInt64().Gauge("management.grpc.connected.streams", instrument.WithUnit("1"))
|
||||
loginRequestsCounter, err := meter.Int64Counter("management.grpc.login.request.counter", metric.WithUnit("1"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
syncRequestDuration, err := meter.SyncInt64().Histogram("management.grpc.sync.request.duration.ms", instrument.WithUnit("milliseconds"))
|
||||
getKeyRequestsCounter, err := meter.Int64Counter("management.grpc.key.request.counter", metric.WithUnit("1"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
loginRequestDuration, err := meter.SyncInt64().Histogram("management.grpc.login.request.duration.ms", instrument.WithUnit("milliseconds"))
|
||||
activeStreamsGauge, err := meter.Int64ObservableGauge("management.grpc.connected.streams", metric.WithUnit("1"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
syncRequestDuration, err := meter.Int64Histogram("management.grpc.sync.request.duration.ms", metric.WithUnit("milliseconds"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
loginRequestDuration, err := meter.Int64Histogram("management.grpc.login.request.duration.ms", metric.WithUnit("milliseconds"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -56,10 +55,10 @@ func NewGRPCMetrics(ctx context.Context, meter metric.Meter) (*GRPCMetrics, erro
|
||||
// We use histogram here as we have multiple channel at the same time and we want to see a slice at any given time
|
||||
// Then we should be able to extract min, manx, mean and the percentiles.
|
||||
// TODO(yury): This needs custom bucketing as we are interested in the values from 0 to server.channelBufferSize (100)
|
||||
channelQueue, err := meter.SyncInt64().Histogram(
|
||||
channelQueue, err := meter.Int64Histogram(
|
||||
"management.grpc.updatechannel.queue",
|
||||
instrument.WithDescription("Number of update messages in the channel queue"),
|
||||
instrument.WithUnit("length"),
|
||||
metric.WithDescription("Number of update messages in the channel queue"),
|
||||
metric.WithUnit("length"),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -105,14 +104,14 @@ func (grpcMetrics *GRPCMetrics) CountSyncRequestDuration(duration time.Duration)
|
||||
|
||||
// RegisterConnectedStreams registers a function that collects number of active streams and feeds it to the metrics gauge.
|
||||
func (grpcMetrics *GRPCMetrics) RegisterConnectedStreams(producer func() int64) error {
|
||||
return grpcMetrics.meter.RegisterCallback(
|
||||
[]instrument.Asynchronous{
|
||||
grpcMetrics.activeStreamsGauge,
|
||||
},
|
||||
func(ctx context.Context) {
|
||||
grpcMetrics.activeStreamsGauge.Observe(ctx, producer())
|
||||
_, err := grpcMetrics.meter.RegisterCallback(
|
||||
func(ctx context.Context, observer metric.Observer) error {
|
||||
observer.ObserveInt64(grpcMetrics.activeStreamsGauge, producer())
|
||||
return nil
|
||||
},
|
||||
grpcMetrics.activeStreamsGauge,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
// UpdateChannelQueueLength update the histogram that keep distribution of the update messages channel queue
|
||||
|
||||
@@ -6,13 +6,11 @@ import (
|
||||
"hash/fnv"
|
||||
"net/http"
|
||||
"strings"
|
||||
time "time"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
"go.opentelemetry.io/otel/metric/instrument"
|
||||
"go.opentelemetry.io/otel/metric/instrument/syncint64"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -56,51 +54,44 @@ type HTTPMiddleware struct {
|
||||
meter metric.Meter
|
||||
ctx context.Context
|
||||
// all HTTP requests by endpoint & method
|
||||
httpRequestCounters map[string]syncint64.Counter
|
||||
httpRequestCounters map[string]metric.Int64Counter
|
||||
// all HTTP responses by endpoint & method & status code
|
||||
httpResponseCounters map[string]syncint64.Counter
|
||||
httpResponseCounters map[string]metric.Int64Counter
|
||||
// all HTTP requests
|
||||
totalHTTPRequestsCounter syncint64.Counter
|
||||
totalHTTPRequestsCounter metric.Int64Counter
|
||||
// all HTTP responses
|
||||
totalHTTPResponseCounter syncint64.Counter
|
||||
totalHTTPResponseCounter metric.Int64Counter
|
||||
// all HTTP responses by status code
|
||||
totalHTTPResponseCodeCounters map[int]syncint64.Counter
|
||||
totalHTTPResponseCodeCounters map[int]metric.Int64Counter
|
||||
// all HTTP requests durations by endpoint and method
|
||||
httpRequestDurations map[string]syncint64.Histogram
|
||||
httpRequestDurations map[string]metric.Int64Histogram
|
||||
// all HTTP requests durations
|
||||
totalHTTPRequestDuration syncint64.Histogram
|
||||
totalHTTPRequestDuration metric.Int64Histogram
|
||||
}
|
||||
|
||||
// NewMetricsMiddleware creates a new HTTPMiddleware
|
||||
func NewMetricsMiddleware(ctx context.Context, meter metric.Meter) (*HTTPMiddleware, error) {
|
||||
|
||||
totalHTTPRequestsCounter, err := meter.SyncInt64().Counter(
|
||||
fmt.Sprintf("%s_total", httpRequestCounterPrefix),
|
||||
instrument.WithUnit("1"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
totalHTTPResponseCounter, err := meter.SyncInt64().Counter(
|
||||
fmt.Sprintf("%s_total", httpResponseCounterPrefix),
|
||||
instrument.WithUnit("1"))
|
||||
totalHTTPRequestsCounter, err := meter.Int64Counter(fmt.Sprintf("%s_total", httpRequestCounterPrefix), metric.WithUnit("1"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
totalHTTPRequestDuration, err := meter.SyncInt64().Histogram(
|
||||
fmt.Sprintf("%s_total", httpRequestDurationPrefix),
|
||||
instrument.WithUnit("milliseconds"))
|
||||
totalHTTPResponseCounter, err := meter.Int64Counter(fmt.Sprintf("%s_total", httpResponseCounterPrefix), metric.WithUnit("1"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
totalHTTPRequestDuration, err := meter.Int64Histogram(fmt.Sprintf("%s_total", httpRequestDurationPrefix), metric.WithUnit("milliseconds"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &HTTPMiddleware{
|
||||
ctx: ctx,
|
||||
httpRequestCounters: map[string]syncint64.Counter{},
|
||||
httpResponseCounters: map[string]syncint64.Counter{},
|
||||
httpRequestDurations: map[string]syncint64.Histogram{},
|
||||
totalHTTPResponseCodeCounters: map[int]syncint64.Counter{},
|
||||
httpRequestCounters: map[string]metric.Int64Counter{},
|
||||
httpResponseCounters: map[string]metric.Int64Counter{},
|
||||
httpRequestDurations: map[string]metric.Int64Histogram{},
|
||||
totalHTTPResponseCodeCounters: map[int]metric.Int64Counter{},
|
||||
meter: meter,
|
||||
totalHTTPRequestsCounter: totalHTTPRequestsCounter,
|
||||
totalHTTPResponseCounter: totalHTTPResponseCounter,
|
||||
@@ -113,28 +104,30 @@ func NewMetricsMiddleware(ctx context.Context, meter metric.Meter) (*HTTPMiddlew
|
||||
// Creates one request counter and multiple response counters (one per http response status code).
|
||||
func (m *HTTPMiddleware) AddHTTPRequestResponseCounter(endpoint string, method string) error {
|
||||
meterKey := getRequestCounterKey(endpoint, method)
|
||||
httpReqCounter, err := m.meter.SyncInt64().Counter(meterKey, instrument.WithUnit("1"))
|
||||
httpReqCounter, err := m.meter.Int64Counter(meterKey, metric.WithUnit("1"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.httpRequestCounters[meterKey] = httpReqCounter
|
||||
|
||||
durationKey := getRequestDurationKey(endpoint, method)
|
||||
requestDuration, err := m.meter.SyncInt64().Histogram(durationKey, instrument.WithUnit("milliseconds"))
|
||||
requestDuration, err := m.meter.Int64Histogram(durationKey, metric.WithUnit("milliseconds"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.httpRequestDurations[durationKey] = requestDuration
|
||||
|
||||
respCodes := []int{200, 204, 400, 401, 403, 404, 500, 502, 503}
|
||||
for _, code := range respCodes {
|
||||
meterKey = getResponseCounterKey(endpoint, method, code)
|
||||
httpRespCounter, err := m.meter.SyncInt64().Counter(meterKey, instrument.WithUnit("1"))
|
||||
httpRespCounter, err := m.meter.Int64Counter(meterKey, metric.WithUnit("1"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.httpResponseCounters[meterKey] = httpRespCounter
|
||||
|
||||
meterKey = fmt.Sprintf("%s_%d_total", httpResponseCounterPrefix, code)
|
||||
totalHTTPResponseCodeCounter, err := m.meter.SyncInt64().Counter(meterKey, instrument.WithUnit("1"))
|
||||
totalHTTPResponseCodeCounter, err := m.meter.Int64Counter(meterKey, metric.WithUnit("1"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -144,19 +137,26 @@ func (m *HTTPMiddleware) AddHTTPRequestResponseCounter(endpoint string, method s
|
||||
return nil
|
||||
}
|
||||
|
||||
func replaceEndpointChars(endpoint string) string {
|
||||
endpoint = strings.ReplaceAll(endpoint, "/", "_")
|
||||
endpoint = strings.ReplaceAll(endpoint, "{", "")
|
||||
endpoint = strings.ReplaceAll(endpoint, "}", "")
|
||||
return endpoint
|
||||
}
|
||||
|
||||
func getRequestCounterKey(endpoint, method string) string {
|
||||
return fmt.Sprintf("%s%s_%s", httpRequestCounterPrefix,
|
||||
strings.ReplaceAll(endpoint, "/", "_"), method)
|
||||
endpoint = replaceEndpointChars(endpoint)
|
||||
return fmt.Sprintf("%s%s_%s", httpRequestCounterPrefix, endpoint, method)
|
||||
}
|
||||
|
||||
func getRequestDurationKey(endpoint, method string) string {
|
||||
return fmt.Sprintf("%s%s_%s", httpRequestDurationPrefix,
|
||||
strings.ReplaceAll(endpoint, "/", "_"), method)
|
||||
endpoint = replaceEndpointChars(endpoint)
|
||||
return fmt.Sprintf("%s%s_%s", httpRequestDurationPrefix, endpoint, method)
|
||||
}
|
||||
|
||||
func getResponseCounterKey(endpoint, method string, status int) string {
|
||||
return fmt.Sprintf("%s%s_%s_%d", httpResponseCounterPrefix,
|
||||
strings.ReplaceAll(endpoint, "/", "_"), method, status)
|
||||
endpoint = replaceEndpointChars(endpoint)
|
||||
return fmt.Sprintf("%s%s_%s_%d", httpResponseCounterPrefix, endpoint, method, status)
|
||||
}
|
||||
|
||||
// Handler logs every request and response and adds the, to metrics.
|
||||
@@ -201,9 +201,11 @@ func (m *HTTPMiddleware) Handler(h http.Handler) http.Handler {
|
||||
log.Debugf("request %s %s took %d ms and finished with status %d", r.Method, r.URL.Path, reqTook.Milliseconds(), w.Status())
|
||||
|
||||
if w.Status() == 200 && (r.Method == http.MethodPut || r.Method == http.MethodPost || r.Method == http.MethodDelete) {
|
||||
m.totalHTTPRequestDuration.Record(m.ctx, reqTook.Milliseconds(), attribute.String("type", "write"))
|
||||
opts := metric.WithAttributeSet(attribute.NewSet(attribute.String("type", "write")))
|
||||
m.totalHTTPRequestDuration.Record(m.ctx, reqTook.Milliseconds(), opts)
|
||||
} else {
|
||||
m.totalHTTPRequestDuration.Record(m.ctx, reqTook.Milliseconds(), attribute.String("type", "read"))
|
||||
opts := metric.WithAttributeSet(attribute.NewSet(attribute.String("type", "read")))
|
||||
m.totalHTTPRequestDuration.Record(m.ctx, reqTook.Milliseconds(), opts)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -4,64 +4,62 @@ import (
|
||||
"context"
|
||||
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
"go.opentelemetry.io/otel/metric/instrument"
|
||||
"go.opentelemetry.io/otel/metric/instrument/syncint64"
|
||||
)
|
||||
|
||||
// IDPMetrics is common IdP metrics
|
||||
type IDPMetrics struct {
|
||||
metaUpdateCounter syncint64.Counter
|
||||
getUserByEmailCounter syncint64.Counter
|
||||
getAllAccountsCounter syncint64.Counter
|
||||
createUserCounter syncint64.Counter
|
||||
deleteUserCounter syncint64.Counter
|
||||
getAccountCounter syncint64.Counter
|
||||
getUserByIDCounter syncint64.Counter
|
||||
authenticateRequestCounter syncint64.Counter
|
||||
requestErrorCounter syncint64.Counter
|
||||
requestStatusErrorCounter syncint64.Counter
|
||||
metaUpdateCounter metric.Int64Counter
|
||||
getUserByEmailCounter metric.Int64Counter
|
||||
getAllAccountsCounter metric.Int64Counter
|
||||
createUserCounter metric.Int64Counter
|
||||
deleteUserCounter metric.Int64Counter
|
||||
getAccountCounter metric.Int64Counter
|
||||
getUserByIDCounter metric.Int64Counter
|
||||
authenticateRequestCounter metric.Int64Counter
|
||||
requestErrorCounter metric.Int64Counter
|
||||
requestStatusErrorCounter metric.Int64Counter
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
// NewIDPMetrics creates new IDPMetrics struct and registers common
|
||||
func NewIDPMetrics(ctx context.Context, meter metric.Meter) (*IDPMetrics, error) {
|
||||
metaUpdateCounter, err := meter.SyncInt64().Counter("management.idp.update.user.meta.counter", instrument.WithUnit("1"))
|
||||
metaUpdateCounter, err := meter.Int64Counter("management.idp.update.user.meta.counter", metric.WithUnit("1"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
getUserByEmailCounter, err := meter.SyncInt64().Counter("management.idp.get.user.by.email.counter", instrument.WithUnit("1"))
|
||||
getUserByEmailCounter, err := meter.Int64Counter("management.idp.get.user.by.email.counter", metric.WithUnit("1"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
getAllAccountsCounter, err := meter.SyncInt64().Counter("management.idp.get.accounts.counter", instrument.WithUnit("1"))
|
||||
getAllAccountsCounter, err := meter.Int64Counter("management.idp.get.accounts.counter", metric.WithUnit("1"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
createUserCounter, err := meter.SyncInt64().Counter("management.idp.create.user.counter", instrument.WithUnit("1"))
|
||||
createUserCounter, err := meter.Int64Counter("management.idp.create.user.counter", metric.WithUnit("1"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
deleteUserCounter, err := meter.SyncInt64().Counter("management.idp.delete.user.counter", instrument.WithUnit("1"))
|
||||
deleteUserCounter, err := meter.Int64Counter("management.idp.delete.user.counter", metric.WithUnit("1"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
getAccountCounter, err := meter.SyncInt64().Counter("management.idp.get.account.counter", instrument.WithUnit("1"))
|
||||
getAccountCounter, err := meter.Int64Counter("management.idp.get.account.counter", metric.WithUnit("1"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
getUserByIDCounter, err := meter.SyncInt64().Counter("management.idp.get.user.by.id.counter", instrument.WithUnit("1"))
|
||||
getUserByIDCounter, err := meter.Int64Counter("management.idp.get.user.by.id.counter", metric.WithUnit("1"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
authenticateRequestCounter, err := meter.SyncInt64().Counter("management.idp.authenticate.request.counter", instrument.WithUnit("1"))
|
||||
authenticateRequestCounter, err := meter.Int64Counter("management.idp.authenticate.request.counter", metric.WithUnit("1"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
requestErrorCounter, err := meter.SyncInt64().Counter("management.idp.request.error.counter", instrument.WithUnit("1"))
|
||||
requestErrorCounter, err := meter.Int64Counter("management.idp.request.error.counter", metric.WithUnit("1"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
requestStatusErrorCounter, err := meter.SyncInt64().Counter("management.idp.request.status.error.counter", instrument.WithUnit("1"))
|
||||
requestStatusErrorCounter, err := meter.Int64Counter("management.idp.request.status.error.counter", metric.WithUnit("1"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -5,39 +5,37 @@ import (
|
||||
"time"
|
||||
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
"go.opentelemetry.io/otel/metric/instrument"
|
||||
"go.opentelemetry.io/otel/metric/instrument/syncint64"
|
||||
)
|
||||
|
||||
// StoreMetrics represents all metrics related to the Store
|
||||
type StoreMetrics struct {
|
||||
globalLockAcquisitionDurationMicro syncint64.Histogram
|
||||
globalLockAcquisitionDurationMs syncint64.Histogram
|
||||
persistenceDurationMicro syncint64.Histogram
|
||||
persistenceDurationMs syncint64.Histogram
|
||||
globalLockAcquisitionDurationMicro metric.Int64Histogram
|
||||
globalLockAcquisitionDurationMs metric.Int64Histogram
|
||||
persistenceDurationMicro metric.Int64Histogram
|
||||
persistenceDurationMs metric.Int64Histogram
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
// NewStoreMetrics creates an instance of StoreMetrics
|
||||
func NewStoreMetrics(ctx context.Context, meter metric.Meter) (*StoreMetrics, error) {
|
||||
globalLockAcquisitionDurationMicro, err := meter.SyncInt64().Histogram("management.store.global.lock.acquisition.duration.micro",
|
||||
instrument.WithUnit("microseconds"))
|
||||
globalLockAcquisitionDurationMicro, err := meter.Int64Histogram("management.store.global.lock.acquisition.duration.micro",
|
||||
metric.WithUnit("microseconds"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
globalLockAcquisitionDurationMs, err := meter.SyncInt64().Histogram("management.store.global.lock.acquisition.duration.ms")
|
||||
globalLockAcquisitionDurationMs, err := meter.Int64Histogram("management.store.global.lock.acquisition.duration.ms")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
persistenceDurationMicro, err := meter.SyncInt64().Histogram("management.store.persistence.duration.micro",
|
||||
instrument.WithUnit("microseconds"))
|
||||
persistenceDurationMicro, err := meter.Int64Histogram("management.store.persistence.duration.micro",
|
||||
metric.WithUnit("microseconds"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
persistenceDurationMs, err := meter.SyncInt64().Histogram("management.store.persistence.duration.ms")
|
||||
persistenceDurationMs, err := meter.Int64Histogram("management.store.persistence.duration.ms")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -6,60 +6,59 @@ import (
|
||||
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
"go.opentelemetry.io/otel/metric/instrument/syncint64"
|
||||
)
|
||||
|
||||
// UpdateChannelMetrics represents all metrics related to the UpdateChannel
|
||||
type UpdateChannelMetrics struct {
|
||||
createChannelDurationMicro syncint64.Histogram
|
||||
closeChannelDurationMicro syncint64.Histogram
|
||||
closeChannelsDurationMicro syncint64.Histogram
|
||||
closeChannels syncint64.Histogram
|
||||
sendUpdateDurationMicro syncint64.Histogram
|
||||
getAllConnectedPeersDurationMicro syncint64.Histogram
|
||||
getAllConnectedPeers syncint64.Histogram
|
||||
hasChannelDurationMicro syncint64.Histogram
|
||||
createChannelDurationMicro metric.Int64Histogram
|
||||
closeChannelDurationMicro metric.Int64Histogram
|
||||
closeChannelsDurationMicro metric.Int64Histogram
|
||||
closeChannels metric.Int64Histogram
|
||||
sendUpdateDurationMicro metric.Int64Histogram
|
||||
getAllConnectedPeersDurationMicro metric.Int64Histogram
|
||||
getAllConnectedPeers metric.Int64Histogram
|
||||
hasChannelDurationMicro metric.Int64Histogram
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
// NewUpdateChannelMetrics creates an instance of UpdateChannel
|
||||
func NewUpdateChannelMetrics(ctx context.Context, meter metric.Meter) (*UpdateChannelMetrics, error) {
|
||||
createChannelDurationMicro, err := meter.SyncInt64().Histogram("management.updatechannel.create.duration.micro")
|
||||
createChannelDurationMicro, err := meter.Int64Histogram("management.updatechannel.create.duration.micro")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
closeChannelDurationMicro, err := meter.SyncInt64().Histogram("management.updatechannel.close.one.duration.micro")
|
||||
closeChannelDurationMicro, err := meter.Int64Histogram("management.updatechannel.close.one.duration.micro")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
closeChannelsDurationMicro, err := meter.SyncInt64().Histogram("management.updatechannel.close.multiple.duration.micro")
|
||||
closeChannelsDurationMicro, err := meter.Int64Histogram("management.updatechannel.close.multiple.duration.micro")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
closeChannels, err := meter.SyncInt64().Histogram("management.updatechannel.close.multiple.channels")
|
||||
closeChannels, err := meter.Int64Histogram("management.updatechannel.close.multiple.channels")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sendUpdateDurationMicro, err := meter.SyncInt64().Histogram("management.updatechannel.send.duration.micro")
|
||||
sendUpdateDurationMicro, err := meter.Int64Histogram("management.updatechannel.send.duration.micro")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
getAllConnectedPeersDurationMicro, err := meter.SyncInt64().Histogram("management.updatechannel.get.all.duration.micro")
|
||||
getAllConnectedPeersDurationMicro, err := meter.Int64Histogram("management.updatechannel.get.all.duration.micro")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
getAllConnectedPeers, err := meter.SyncInt64().Histogram("management.updatechannel.get.all.peers")
|
||||
getAllConnectedPeers, err := meter.Int64Histogram("management.updatechannel.get.all.peers")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
hasChannelDurationMicro, err := meter.SyncInt64().Histogram("management.updatechannel.haschannel.duration.micro")
|
||||
hasChannelDurationMicro, err := meter.Int64Histogram("management.updatechannel.haschannel.duration.micro")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -80,7 +79,8 @@ func NewUpdateChannelMetrics(ctx context.Context, meter metric.Meter) (*UpdateCh
|
||||
// CountCreateChannelDuration counts the duration of the CreateChannel method,
|
||||
// closed indicates if existing channel was closed before creation of a new one
|
||||
func (metrics *UpdateChannelMetrics) CountCreateChannelDuration(duration time.Duration, closed bool) {
|
||||
metrics.createChannelDurationMicro.Record(metrics.ctx, duration.Microseconds(), attribute.Bool("closed", closed))
|
||||
opts := metric.WithAttributeSet(attribute.NewSet(attribute.Bool("closed", closed)))
|
||||
metrics.createChannelDurationMicro.Record(metrics.ctx, duration.Microseconds(), opts)
|
||||
}
|
||||
|
||||
// CountCloseChannelDuration counts the duration of the CloseChannel method
|
||||
@@ -97,8 +97,8 @@ func (metrics *UpdateChannelMetrics) CountCloseChannelsDuration(duration time.Du
|
||||
// CountSendUpdateDuration counts the duration of the SendUpdate method
|
||||
// found indicates if peer had channel, dropped indicates if the message was dropped due channel buffer overload
|
||||
func (metrics *UpdateChannelMetrics) CountSendUpdateDuration(duration time.Duration, found, dropped bool) {
|
||||
attrs := []attribute.KeyValue{attribute.Bool("found", found), attribute.Bool("dropped", dropped)}
|
||||
metrics.sendUpdateDurationMicro.Record(metrics.ctx, duration.Microseconds(), attrs...)
|
||||
opts := metric.WithAttributeSet(attribute.NewSet(attribute.Bool("found", found), attribute.Bool("dropped", dropped)))
|
||||
metrics.sendUpdateDurationMicro.Record(metrics.ctx, duration.Microseconds(), opts)
|
||||
}
|
||||
|
||||
// CountGetAllConnectedPeersDuration counts the duration of the GetAllConnectedPeers method and the number of peers have been returned
|
||||
|
||||
45
management/server/testutil/store.go
Normal file
45
management/server/testutil/store.go
Normal file
@@ -0,0 +1,45 @@
|
||||
//go:build !ios
|
||||
// +build !ios
|
||||
|
||||
package testutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/testcontainers/testcontainers-go"
|
||||
"github.com/testcontainers/testcontainers-go/modules/postgres"
|
||||
"github.com/testcontainers/testcontainers-go/wait"
|
||||
)
|
||||
|
||||
func CreatePGDB() (func(), error) {
|
||||
ctx := context.Background()
|
||||
c, err := postgres.RunContainer(ctx,
|
||||
testcontainers.WithImage("postgres:alpine"),
|
||||
postgres.WithDatabase("test"),
|
||||
postgres.WithUsername("postgres"),
|
||||
postgres.WithPassword("postgres"),
|
||||
testcontainers.WithWaitStrategy(
|
||||
wait.ForLog("database system is ready to accept connections").
|
||||
WithOccurrence(2).WithStartupTimeout(15*time.Second)),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cleanup := func() {
|
||||
timeout := 10 * time.Second
|
||||
err = c.Stop(ctx, &timeout)
|
||||
if err != nil {
|
||||
log.Warnf("failed to stop container: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
talksConn, err := c.ConnectionString(ctx)
|
||||
if err != nil {
|
||||
return cleanup, err
|
||||
}
|
||||
return cleanup, os.Setenv("NETBIRD_STORE_ENGINE_POSTGRES_DSN", talksConn)
|
||||
}
|
||||
6
management/server/testutil/store_ios.go
Normal file
6
management/server/testutil/store_ios.go
Normal file
@@ -0,0 +1,6 @@
|
||||
//go:build ios
|
||||
// +build ios
|
||||
|
||||
package testutil
|
||||
|
||||
func CreatePGDB() (func(), error) { return func() {}, nil }
|
||||
@@ -910,8 +910,10 @@ func (am *DefaultAccountManager) SaveOrAddUser(accountID, initiatorUserID string
|
||||
|
||||
// GetOrCreateAccountByUser returns an existing account for a given user id or creates a new one if doesn't exist
|
||||
func (am *DefaultAccountManager) GetOrCreateAccountByUser(userID, domain string) (*Account, error) {
|
||||
start := time.Now()
|
||||
unlock := am.Store.AcquireGlobalLock()
|
||||
defer unlock()
|
||||
log.Debugf("Acquired global lock in %s for user %s", time.Since(start), userID)
|
||||
|
||||
lowerDomain := strings.ToLower(domain)
|
||||
|
||||
|
||||
@@ -39,6 +39,7 @@ const (
|
||||
|
||||
func TestUser_CreatePAT_ForSameUser(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close()
|
||||
account := newAccountWithId(mockAccountID, mockUserID, "")
|
||||
|
||||
err := store.SaveAccount(account)
|
||||
@@ -76,6 +77,7 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) {
|
||||
|
||||
func TestUser_CreatePAT_ForDifferentUser(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close()
|
||||
account := newAccountWithId(mockAccountID, mockUserID, "")
|
||||
account.Users[mockTargetUserId] = &User{
|
||||
Id: mockTargetUserId,
|
||||
@@ -97,6 +99,7 @@ func TestUser_CreatePAT_ForDifferentUser(t *testing.T) {
|
||||
|
||||
func TestUser_CreatePAT_ForServiceUser(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close()
|
||||
account := newAccountWithId(mockAccountID, mockUserID, "")
|
||||
account.Users[mockTargetUserId] = &User{
|
||||
Id: mockTargetUserId,
|
||||
@@ -122,6 +125,7 @@ func TestUser_CreatePAT_ForServiceUser(t *testing.T) {
|
||||
|
||||
func TestUser_CreatePAT_WithWrongExpiration(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close()
|
||||
account := newAccountWithId(mockAccountID, mockUserID, "")
|
||||
|
||||
err := store.SaveAccount(account)
|
||||
@@ -140,6 +144,7 @@ func TestUser_CreatePAT_WithWrongExpiration(t *testing.T) {
|
||||
|
||||
func TestUser_CreatePAT_WithEmptyName(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close()
|
||||
account := newAccountWithId(mockAccountID, mockUserID, "")
|
||||
|
||||
err := store.SaveAccount(account)
|
||||
@@ -158,6 +163,7 @@ func TestUser_CreatePAT_WithEmptyName(t *testing.T) {
|
||||
|
||||
func TestUser_DeletePAT(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close()
|
||||
account := newAccountWithId(mockAccountID, mockUserID, "")
|
||||
account.Users[mockUserID] = &User{
|
||||
Id: mockUserID,
|
||||
@@ -190,6 +196,7 @@ func TestUser_DeletePAT(t *testing.T) {
|
||||
|
||||
func TestUser_GetPAT(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close()
|
||||
account := newAccountWithId(mockAccountID, mockUserID, "")
|
||||
account.Users[mockUserID] = &User{
|
||||
Id: mockUserID,
|
||||
@@ -221,6 +228,7 @@ func TestUser_GetPAT(t *testing.T) {
|
||||
|
||||
func TestUser_GetAllPATs(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close()
|
||||
account := newAccountWithId(mockAccountID, mockUserID, "")
|
||||
account.Users[mockUserID] = &User{
|
||||
Id: mockUserID,
|
||||
@@ -322,6 +330,7 @@ func validateStruct(s interface{}) (err error) {
|
||||
|
||||
func TestUser_CreateServiceUser(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close()
|
||||
account := newAccountWithId(mockAccountID, mockUserID, "")
|
||||
|
||||
err := store.SaveAccount(account)
|
||||
@@ -359,6 +368,7 @@ func TestUser_CreateServiceUser(t *testing.T) {
|
||||
|
||||
func TestUser_CreateUser_ServiceUser(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close()
|
||||
account := newAccountWithId(mockAccountID, mockUserID, "")
|
||||
|
||||
err := store.SaveAccount(account)
|
||||
@@ -397,6 +407,7 @@ func TestUser_CreateUser_ServiceUser(t *testing.T) {
|
||||
|
||||
func TestUser_CreateUser_RegularUser(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close()
|
||||
account := newAccountWithId(mockAccountID, mockUserID, "")
|
||||
|
||||
err := store.SaveAccount(account)
|
||||
@@ -421,6 +432,7 @@ func TestUser_CreateUser_RegularUser(t *testing.T) {
|
||||
|
||||
func TestUser_InviteNewUser(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close()
|
||||
account := newAccountWithId(mockAccountID, mockUserID, "")
|
||||
|
||||
err := store.SaveAccount(account)
|
||||
@@ -549,6 +561,7 @@ func TestUser_DeleteUser_ServiceUser(t *testing.T) {
|
||||
|
||||
func TestUser_DeleteUser_SelfDelete(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close()
|
||||
account := newAccountWithId(mockAccountID, mockUserID, "")
|
||||
|
||||
err := store.SaveAccount(account)
|
||||
@@ -569,6 +582,7 @@ func TestUser_DeleteUser_SelfDelete(t *testing.T) {
|
||||
|
||||
func TestUser_DeleteUser_regularUser(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close()
|
||||
account := newAccountWithId(mockAccountID, mockUserID, "")
|
||||
|
||||
targetId := "user2"
|
||||
@@ -650,6 +664,7 @@ func TestUser_DeleteUser_regularUser(t *testing.T) {
|
||||
|
||||
func TestDefaultAccountManager_GetUser(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close()
|
||||
account := newAccountWithId(mockAccountID, mockUserID, "")
|
||||
|
||||
err := store.SaveAccount(account)
|
||||
@@ -678,6 +693,7 @@ func TestDefaultAccountManager_GetUser(t *testing.T) {
|
||||
|
||||
func TestDefaultAccountManager_ListUsers(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close()
|
||||
account := newAccountWithId(mockAccountID, mockUserID, "")
|
||||
account.Users["normal_user1"] = NewRegularUser("normal_user1")
|
||||
account.Users["normal_user2"] = NewRegularUser("normal_user2")
|
||||
@@ -790,6 +806,7 @@ func TestDefaultAccountManager_ListUsers_DashboardPermissions(t *testing.T) {
|
||||
|
||||
func TestDefaultAccountManager_ExternalCache(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close()
|
||||
account := newAccountWithId(mockAccountID, mockUserID, "")
|
||||
externalUser := &User{
|
||||
Id: "externalUser",
|
||||
@@ -853,6 +870,7 @@ func TestUser_IsAdmin(t *testing.T) {
|
||||
|
||||
func TestUser_GetUsersFromAccount_ForAdmin(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close()
|
||||
account := newAccountWithId(mockAccountID, mockUserID, "")
|
||||
account.Users[mockServiceUserID] = &User{
|
||||
Id: mockServiceUserID,
|
||||
@@ -880,6 +898,8 @@ func TestUser_GetUsersFromAccount_ForAdmin(t *testing.T) {
|
||||
|
||||
func TestUser_GetUsersFromAccount_ForUser(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close()
|
||||
|
||||
account := newAccountWithId(mockAccountID, mockUserID, "")
|
||||
account.Users[mockServiceUserID] = &User{
|
||||
Id: mockServiceUserID,
|
||||
|
||||
419
relay/client/client.go
Normal file
419
relay/client/client.go
Normal file
@@ -0,0 +1,419 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
ws "github.com/netbirdio/netbird/relay/client/dialer/wsnhooyr"
|
||||
"github.com/netbirdio/netbird/relay/messages"
|
||||
)
|
||||
|
||||
const (
|
||||
bufferSize = 8820
|
||||
serverResponseTimeout = 8 * time.Second
|
||||
)
|
||||
|
||||
// Msg carry the payload from the server to the client. With this sturct, the net.Conn can free the buffer.
|
||||
type Msg struct {
|
||||
Payload []byte
|
||||
|
||||
bufPool *sync.Pool
|
||||
bufPtr *[]byte
|
||||
}
|
||||
|
||||
func (m *Msg) Free() {
|
||||
m.bufPool.Put(m.bufPtr)
|
||||
}
|
||||
|
||||
type connContainer struct {
|
||||
conn *Conn
|
||||
messages chan Msg
|
||||
msgChanLock sync.Mutex
|
||||
closed bool // flag to check if channel is closed
|
||||
}
|
||||
|
||||
func newConnContainer(conn *Conn, messages chan Msg) *connContainer {
|
||||
return &connContainer{
|
||||
conn: conn,
|
||||
messages: messages,
|
||||
}
|
||||
}
|
||||
|
||||
func (cc *connContainer) writeMsg(msg Msg) {
|
||||
cc.msgChanLock.Lock()
|
||||
defer cc.msgChanLock.Unlock()
|
||||
if cc.closed {
|
||||
return
|
||||
}
|
||||
cc.messages <- msg
|
||||
}
|
||||
|
||||
func (cc *connContainer) close() {
|
||||
cc.msgChanLock.Lock()
|
||||
defer cc.msgChanLock.Unlock()
|
||||
if cc.closed {
|
||||
return
|
||||
}
|
||||
close(cc.messages)
|
||||
cc.closed = true
|
||||
}
|
||||
|
||||
// Client is a client for the relay server. It is responsible for establishing a connection to the relay server and
|
||||
// managing connections to other peers. All exported functions are safe to call concurrently. After close the connection,
|
||||
// the client can be reused by calling Connect again. When the client is closed, all connections are closed too.
|
||||
// While the Connect is in progress, the OpenConn function will block until the connection is established.
|
||||
type Client struct {
|
||||
log *log.Entry
|
||||
parentCtx context.Context
|
||||
ctxCancel context.CancelFunc
|
||||
serverAddress string
|
||||
hashedID []byte
|
||||
|
||||
bufPool *sync.Pool
|
||||
|
||||
relayConn net.Conn
|
||||
conns map[string]*connContainer
|
||||
serviceIsRunning bool
|
||||
mu sync.Mutex
|
||||
readLoopMutex sync.Mutex
|
||||
wgReadLoop sync.WaitGroup
|
||||
|
||||
remoteAddr net.Addr
|
||||
|
||||
onDisconnectListener func()
|
||||
listenerMutex sync.Mutex
|
||||
}
|
||||
|
||||
// NewClient creates a new client for the relay server. The client is not connected to the server until the Connect
|
||||
func NewClient(ctx context.Context, serverAddress, peerID string) *Client {
|
||||
hashedID, hashedStringId := messages.HashID(peerID)
|
||||
return &Client{
|
||||
log: log.WithField("client_id", hashedStringId),
|
||||
parentCtx: ctx,
|
||||
ctxCancel: func() {},
|
||||
serverAddress: serverAddress,
|
||||
hashedID: hashedID,
|
||||
bufPool: &sync.Pool{
|
||||
New: func() any {
|
||||
buf := make([]byte, bufferSize)
|
||||
return &buf
|
||||
},
|
||||
},
|
||||
conns: make(map[string]*connContainer),
|
||||
}
|
||||
}
|
||||
|
||||
// Connect establishes a connection to the relay server. It blocks until the connection is established or an error occurs.
|
||||
func (c *Client) Connect() error {
|
||||
c.readLoopMutex.Lock()
|
||||
defer c.readLoopMutex.Unlock()
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.serviceIsRunning {
|
||||
return nil
|
||||
}
|
||||
|
||||
err := c.connect()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.serviceIsRunning = true
|
||||
|
||||
var ctx context.Context
|
||||
ctx, c.ctxCancel = context.WithCancel(c.parentCtx)
|
||||
context.AfterFunc(ctx, func() {
|
||||
cErr := c.close(false)
|
||||
if cErr != nil {
|
||||
log.Errorf("failed to close relay connection: %s", cErr)
|
||||
}
|
||||
})
|
||||
c.wgReadLoop.Add(1)
|
||||
go c.readLoop(c.relayConn)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// OpenConn create a new net.Conn for the destination peer ID. In case if the connection is in progress
|
||||
// to the relay server, the function will block until the connection is established or timed out. Otherwise,
|
||||
// it will return immediately.
|
||||
// todo: what should happen if call with the same peerID with multiple times?
|
||||
func (c *Client) OpenConn(dstPeerID string) (net.Conn, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if !c.serviceIsRunning {
|
||||
return nil, fmt.Errorf("relay connection is not established")
|
||||
}
|
||||
|
||||
hashedID, hashedStringID := messages.HashID(dstPeerID)
|
||||
log.Infof("open connection to peer: %s", hashedStringID)
|
||||
msgChannel := make(chan Msg, 2)
|
||||
conn := NewConn(c, hashedID, hashedStringID, msgChannel)
|
||||
|
||||
c.conns[hashedStringID] = newConnContainer(conn, msgChannel)
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// RelayRemoteAddress returns the IP address of the relay server. It could change after the close and reopen the connection.
|
||||
func (c *Client) RelayRemoteAddress() (net.Addr, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.remoteAddr == nil {
|
||||
return nil, fmt.Errorf("relay connection is not established")
|
||||
}
|
||||
return c.remoteAddr, nil
|
||||
}
|
||||
|
||||
// SetOnDisconnectListener sets a function that will be called when the connection to the relay server is closed.
|
||||
func (c *Client) SetOnDisconnectListener(fn func()) {
|
||||
c.listenerMutex.Lock()
|
||||
defer c.listenerMutex.Unlock()
|
||||
c.onDisconnectListener = fn
|
||||
}
|
||||
|
||||
func (c *Client) HasConns() bool {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return len(c.conns) > 0
|
||||
}
|
||||
|
||||
// Close closes the connection to the relay server and all connections to other peers.
|
||||
func (c *Client) Close() error {
|
||||
return c.close(false)
|
||||
}
|
||||
|
||||
func (c *Client) close(byServer bool) error {
|
||||
c.readLoopMutex.Lock()
|
||||
defer c.readLoopMutex.Unlock()
|
||||
|
||||
c.mu.Lock()
|
||||
var err error
|
||||
if !c.serviceIsRunning {
|
||||
c.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
c.serviceIsRunning = false
|
||||
c.closeAllConns()
|
||||
if !byServer {
|
||||
c.writeCloseMsg()
|
||||
err = c.relayConn.Close()
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
c.wgReadLoop.Wait()
|
||||
c.log.Infof("relay connection closed with: %s", c.serverAddress)
|
||||
c.ctxCancel()
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Client) connect() error {
|
||||
conn, err := ws.Dial(c.serverAddress)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.relayConn = conn
|
||||
|
||||
err = c.handShake()
|
||||
if err != nil {
|
||||
cErr := conn.Close()
|
||||
if cErr != nil {
|
||||
log.Errorf("failed to close connection: %s", cErr)
|
||||
}
|
||||
c.relayConn = nil
|
||||
return err
|
||||
}
|
||||
|
||||
c.remoteAddr = conn.RemoteAddr()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) handShake() error {
|
||||
defer func() {
|
||||
err := c.relayConn.SetReadDeadline(time.Time{})
|
||||
if err != nil {
|
||||
log.Errorf("failed to reset read deadline: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
msg, err := messages.MarshalHelloMsg(c.hashedID)
|
||||
if err != nil {
|
||||
log.Errorf("failed to marshal hello message: %s", err)
|
||||
return err
|
||||
}
|
||||
_, err = c.relayConn.Write(msg)
|
||||
if err != nil {
|
||||
log.Errorf("failed to send hello message: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
err = c.relayConn.SetReadDeadline(time.Now().Add(serverResponseTimeout))
|
||||
if err != nil {
|
||||
log.Errorf("failed to set read deadline: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
buf := make([]byte, 1500) // todo: optimise buffer size
|
||||
n, err := c.relayConn.Read(buf)
|
||||
if err != nil {
|
||||
log.Errorf("failed to read hello response: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
msgType, err := messages.DetermineServerMsgType(buf[:n])
|
||||
if err != nil {
|
||||
log.Errorf("failed to determine message type: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
if msgType != messages.MsgTypeHelloResponse {
|
||||
log.Errorf("unexpected message type: %s", msgType)
|
||||
return fmt.Errorf("unexpected message type")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) readLoop(relayConn net.Conn) {
|
||||
var (
|
||||
errExit error
|
||||
n int
|
||||
closedByServer bool
|
||||
)
|
||||
for {
|
||||
bufPtr := c.bufPool.Get().(*[]byte)
|
||||
buf := *bufPtr
|
||||
n, errExit = relayConn.Read(buf)
|
||||
if errExit != nil {
|
||||
c.mu.Lock()
|
||||
if c.serviceIsRunning {
|
||||
c.log.Debugf("failed to read message from relay server: %s", errExit)
|
||||
}
|
||||
c.mu.Unlock()
|
||||
goto Exit
|
||||
}
|
||||
|
||||
msgType, err := messages.DetermineServerMsgType(buf[:n])
|
||||
if err != nil {
|
||||
c.log.Errorf("failed to determine message type: %s", err)
|
||||
continue
|
||||
}
|
||||
|
||||
switch msgType {
|
||||
case messages.MsgTypeTransport:
|
||||
peerID, payload, err := messages.UnmarshalTransportMsg(buf[:n])
|
||||
if err != nil {
|
||||
c.log.Errorf("failed to parse transport message: %v", err)
|
||||
continue
|
||||
}
|
||||
stringID := messages.HashIDToString(peerID)
|
||||
|
||||
c.mu.Lock()
|
||||
if !c.serviceIsRunning {
|
||||
c.mu.Unlock()
|
||||
goto Exit
|
||||
}
|
||||
container, ok := c.conns[stringID]
|
||||
c.mu.Unlock()
|
||||
if !ok {
|
||||
c.log.Errorf("peer not found: %s", stringID)
|
||||
continue
|
||||
}
|
||||
|
||||
container.writeMsg(Msg{
|
||||
bufPool: c.bufPool,
|
||||
bufPtr: bufPtr,
|
||||
Payload: payload})
|
||||
case messages.MsgClose:
|
||||
closedByServer = true
|
||||
log.Debugf("relay connection close by server")
|
||||
goto Exit
|
||||
}
|
||||
}
|
||||
|
||||
Exit:
|
||||
c.notifyDisconnected()
|
||||
c.wgReadLoop.Done()
|
||||
_ = c.close(closedByServer)
|
||||
}
|
||||
|
||||
// todo check by reference too, the id is not enought because the id come from the outer conn
|
||||
func (c *Client) writeTo(id string, dstID []byte, payload []byte) (int, error) {
|
||||
c.mu.Lock()
|
||||
// conn, ok := c.conns[id]
|
||||
_, ok := c.conns[id]
|
||||
c.mu.Unlock()
|
||||
if !ok {
|
||||
return 0, io.EOF
|
||||
}
|
||||
/*
|
||||
if conn != clientRef {
|
||||
return 0, io.EOF
|
||||
}
|
||||
*/
|
||||
msg := messages.MarshalTransportMsg(dstID, payload)
|
||||
n, err := c.relayConn.Write(msg)
|
||||
if err != nil {
|
||||
log.Errorf("failed to write transport message: %s", err)
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (c *Client) closeAllConns() {
|
||||
for _, container := range c.conns {
|
||||
container.close()
|
||||
}
|
||||
c.conns = make(map[string]*connContainer)
|
||||
}
|
||||
|
||||
// todo check by reference too, the id is not enought because the id come from the outer conn
|
||||
func (c *Client) closeConn(id string) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
container, ok := c.conns[id]
|
||||
if !ok {
|
||||
return fmt.Errorf("connection already closed")
|
||||
}
|
||||
container.close()
|
||||
delete(c.conns, id)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) onDisconnect() {
|
||||
c.listenerMutex.Lock()
|
||||
defer c.listenerMutex.Unlock()
|
||||
|
||||
if c.onDisconnectListener == nil {
|
||||
return
|
||||
}
|
||||
c.onDisconnectListener()
|
||||
}
|
||||
|
||||
func (c *Client) notifyDisconnected() {
|
||||
c.listenerMutex.Lock()
|
||||
defer c.listenerMutex.Unlock()
|
||||
|
||||
if c.onDisconnectListener == nil {
|
||||
return
|
||||
}
|
||||
go c.onDisconnectListener()
|
||||
}
|
||||
|
||||
func (c *Client) writeCloseMsg() {
|
||||
msg := messages.MarshalCloseMsg()
|
||||
_, err := c.relayConn.Write(msg)
|
||||
if err != nil {
|
||||
c.log.Errorf("failed to send close message: %s", err)
|
||||
}
|
||||
}
|
||||
523
relay/client/client_test.go
Normal file
523
relay/client/client_test.go
Normal file
@@ -0,0 +1,523 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/util"
|
||||
|
||||
"github.com/netbirdio/netbird/relay/server"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
_ = util.InitLog("trace", "console")
|
||||
code := m.Run()
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
func TestClient(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
addr := "localhost:1234"
|
||||
srv := server.NewServer()
|
||||
go func() {
|
||||
err := srv.Listen(addr)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
err := srv.Close()
|
||||
if err != nil {
|
||||
t.Errorf("failed to close server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
clientAlice := NewClient(ctx, addr, "alice")
|
||||
err := clientAlice.Connect()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect to server: %s", err)
|
||||
}
|
||||
defer clientAlice.Close()
|
||||
|
||||
clientPlaceHolder := NewClient(ctx, addr, "clientPlaceHolder")
|
||||
err = clientPlaceHolder.Connect()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect to server: %s", err)
|
||||
}
|
||||
defer clientPlaceHolder.Close()
|
||||
|
||||
clientBob := NewClient(ctx, addr, "bob")
|
||||
err = clientBob.Connect()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect to server: %s", err)
|
||||
}
|
||||
defer clientBob.Close()
|
||||
|
||||
connAliceToBob, err := clientAlice.OpenConn("bob")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind channel: %s", err)
|
||||
}
|
||||
|
||||
connBobToAlice, err := clientBob.OpenConn("alice")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind channel: %s", err)
|
||||
}
|
||||
|
||||
payload := "hello bob, I am alice"
|
||||
_, err = connAliceToBob.Write([]byte(payload))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to write to channel: %s", err)
|
||||
}
|
||||
log.Debugf("alice sent message to bob")
|
||||
|
||||
buf := make([]byte, 65535)
|
||||
n, err := connBobToAlice.Read(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read from channel: %s", err)
|
||||
}
|
||||
log.Debugf("on new message from alice to bob")
|
||||
|
||||
if payload != string(buf[:n]) {
|
||||
t.Fatalf("expected %s, got %s", payload, string(buf[:n]))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistration(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
addr := "localhost:1234"
|
||||
srv := server.NewServer()
|
||||
go func() {
|
||||
err := srv.Listen(addr)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
clientAlice := NewClient(ctx, addr, "alice")
|
||||
err := clientAlice.Connect()
|
||||
if err != nil {
|
||||
_ = srv.Close()
|
||||
t.Fatalf("failed to connect to server: %s", err)
|
||||
}
|
||||
err = clientAlice.Close()
|
||||
if err != nil {
|
||||
t.Errorf("failed to close conn: %s", err)
|
||||
}
|
||||
err = srv.Close()
|
||||
if err != nil {
|
||||
t.Errorf("failed to close server: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistrationTimeout(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
fakeUDPListener, err := net.ListenUDP("udp", &net.UDPAddr{
|
||||
Port: 1234,
|
||||
IP: net.ParseIP("0.0.0.0"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind UDP server: %s", err)
|
||||
}
|
||||
defer func(fakeUDPListener *net.UDPConn) {
|
||||
_ = fakeUDPListener.Close()
|
||||
}(fakeUDPListener)
|
||||
|
||||
fakeTCPListener, err := net.ListenTCP("tcp", &net.TCPAddr{
|
||||
Port: 1234,
|
||||
IP: net.ParseIP("0.0.0.0"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind TCP server: %s", err)
|
||||
}
|
||||
defer func(fakeTCPListener *net.TCPListener) {
|
||||
_ = fakeTCPListener.Close()
|
||||
}(fakeTCPListener)
|
||||
|
||||
clientAlice := NewClient(ctx, "127.0.0.1:1234", "alice")
|
||||
err = clientAlice.Connect()
|
||||
if err == nil {
|
||||
t.Errorf("failed to connect to server: %s", err)
|
||||
}
|
||||
log.Debugf("%s", err)
|
||||
err = clientAlice.Close()
|
||||
if err != nil {
|
||||
t.Errorf("failed to close conn: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEcho(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
idAlice := "alice"
|
||||
idBob := "bob"
|
||||
addr := "localhost:1234"
|
||||
srv := server.NewServer()
|
||||
go func() {
|
||||
err := srv.Listen(addr)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
err := srv.Close()
|
||||
if err != nil {
|
||||
t.Errorf("failed to close server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
clientAlice := NewClient(ctx, addr, idAlice)
|
||||
err := clientAlice.Connect()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect to server: %s", err)
|
||||
}
|
||||
defer func() {
|
||||
err := clientAlice.Close()
|
||||
if err != nil {
|
||||
t.Errorf("failed to close Alice client: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
clientBob := NewClient(ctx, addr, idBob)
|
||||
err = clientBob.Connect()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect to server: %s", err)
|
||||
}
|
||||
defer func() {
|
||||
err := clientBob.Close()
|
||||
if err != nil {
|
||||
t.Errorf("failed to close Bob client: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
connAliceToBob, err := clientAlice.OpenConn(idBob)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind channel: %s", err)
|
||||
}
|
||||
|
||||
connBobToAlice, err := clientBob.OpenConn(idAlice)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind channel: %s", err)
|
||||
}
|
||||
|
||||
payload := "hello bob, I am alice"
|
||||
_, err = connAliceToBob.Write([]byte(payload))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to write to channel: %s", err)
|
||||
}
|
||||
|
||||
buf := make([]byte, 65535)
|
||||
n, err := connBobToAlice.Read(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read from channel: %s", err)
|
||||
}
|
||||
|
||||
_, err = connBobToAlice.Write(buf[:n])
|
||||
if err != nil {
|
||||
t.Fatalf("failed to write to channel: %s", err)
|
||||
}
|
||||
|
||||
n, err = connAliceToBob.Read(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read from channel: %s", err)
|
||||
}
|
||||
|
||||
if payload != string(buf[:n]) {
|
||||
t.Fatalf("expected %s, got %s", payload, string(buf[:n]))
|
||||
}
|
||||
}
|
||||
|
||||
func TestBindToUnavailabePeer(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
addr := "localhost:1234"
|
||||
srv := server.NewServer()
|
||||
go func() {
|
||||
err := srv.Listen(addr)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
log.Infof("closing server")
|
||||
err := srv.Close()
|
||||
if err != nil {
|
||||
t.Errorf("failed to close server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
clientAlice := NewClient(ctx, addr, "alice")
|
||||
err := clientAlice.Connect()
|
||||
if err != nil {
|
||||
t.Errorf("failed to connect to server: %s", err)
|
||||
}
|
||||
_, err = clientAlice.OpenConn("bob")
|
||||
if err != nil {
|
||||
t.Errorf("failed to bind channel: %s", err)
|
||||
}
|
||||
|
||||
log.Infof("closing client")
|
||||
err = clientAlice.Close()
|
||||
if err != nil {
|
||||
t.Errorf("failed to close client: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBindReconnect(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
addr := "localhost:1234"
|
||||
srv := server.NewServer()
|
||||
go func() {
|
||||
err := srv.Listen(addr)
|
||||
if err != nil {
|
||||
t.Errorf("failed to bind server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
log.Infof("closing server")
|
||||
err := srv.Close()
|
||||
if err != nil {
|
||||
t.Errorf("failed to close server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
clientAlice := NewClient(ctx, addr, "alice")
|
||||
err := clientAlice.Connect()
|
||||
if err != nil {
|
||||
t.Errorf("failed to connect to server: %s", err)
|
||||
}
|
||||
|
||||
_, err = clientAlice.OpenConn("bob")
|
||||
if err != nil {
|
||||
t.Errorf("failed to bind channel: %s", err)
|
||||
}
|
||||
|
||||
clientBob := NewClient(ctx, addr, "bob")
|
||||
err = clientBob.Connect()
|
||||
if err != nil {
|
||||
t.Errorf("failed to connect to server: %s", err)
|
||||
}
|
||||
|
||||
chBob, err := clientBob.OpenConn("alice")
|
||||
if err != nil {
|
||||
t.Errorf("failed to bind channel: %s", err)
|
||||
}
|
||||
|
||||
log.Infof("closing client Alice")
|
||||
err = clientAlice.Close()
|
||||
if err != nil {
|
||||
t.Errorf("failed to close client: %s", err)
|
||||
}
|
||||
|
||||
clientAlice = NewClient(ctx, addr, "alice")
|
||||
err = clientAlice.Connect()
|
||||
if err != nil {
|
||||
t.Errorf("failed to connect to server: %s", err)
|
||||
}
|
||||
|
||||
chAlice, err := clientAlice.OpenConn("bob")
|
||||
if err != nil {
|
||||
t.Errorf("failed to bind channel: %s", err)
|
||||
}
|
||||
|
||||
testString := "hello alice, I am bob"
|
||||
_, err = chBob.Write([]byte(testString))
|
||||
if err != nil {
|
||||
t.Errorf("failed to write to channel: %s", err)
|
||||
}
|
||||
|
||||
buf := make([]byte, 65535)
|
||||
n, err := chAlice.Read(buf)
|
||||
if err != nil {
|
||||
t.Errorf("failed to read from channel: %s", err)
|
||||
}
|
||||
|
||||
if testString != string(buf[:n]) {
|
||||
t.Errorf("expected %s, got %s", testString, string(buf[:n]))
|
||||
}
|
||||
|
||||
log.Infof("closing client")
|
||||
err = clientAlice.Close()
|
||||
if err != nil {
|
||||
t.Errorf("failed to close client: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloseConn(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
addr := "localhost:1234"
|
||||
srv := server.NewServer()
|
||||
go func() {
|
||||
err := srv.Listen(addr)
|
||||
if err != nil {
|
||||
t.Errorf("failed to bind server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
log.Infof("closing server")
|
||||
err := srv.Close()
|
||||
if err != nil {
|
||||
t.Errorf("failed to close server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
clientAlice := NewClient(ctx, addr, "alice")
|
||||
err := clientAlice.Connect()
|
||||
if err != nil {
|
||||
t.Errorf("failed to connect to server: %s", err)
|
||||
}
|
||||
|
||||
conn, err := clientAlice.OpenConn("bob")
|
||||
if err != nil {
|
||||
t.Errorf("failed to bind channel: %s", err)
|
||||
}
|
||||
|
||||
log.Infof("closing connection")
|
||||
err = conn.Close()
|
||||
if err != nil {
|
||||
t.Errorf("failed to close connection: %s", err)
|
||||
}
|
||||
|
||||
_, err = conn.Read(make([]byte, 1))
|
||||
if err == nil {
|
||||
t.Errorf("unexpected reading from closed connection")
|
||||
}
|
||||
|
||||
_, err = conn.Write([]byte("hello"))
|
||||
if err == nil {
|
||||
t.Errorf("unexpected writing from closed connection")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloseRelayConn(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
addr := "localhost:1234"
|
||||
srv := server.NewServer()
|
||||
go func() {
|
||||
err := srv.Listen(addr)
|
||||
if err != nil {
|
||||
t.Errorf("failed to bind server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
err := srv.Close()
|
||||
if err != nil {
|
||||
log.Errorf("failed to close server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
clientAlice := NewClient(ctx, addr, "alice")
|
||||
err := clientAlice.Connect()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect to server: %s", err)
|
||||
}
|
||||
|
||||
conn, err := clientAlice.OpenConn("bob")
|
||||
if err != nil {
|
||||
t.Errorf("failed to bind channel: %s", err)
|
||||
}
|
||||
|
||||
_ = clientAlice.relayConn.Close()
|
||||
|
||||
_, err = conn.Read(make([]byte, 1))
|
||||
if err == nil {
|
||||
t.Errorf("unexpected reading from closed connection")
|
||||
}
|
||||
|
||||
_, err = clientAlice.OpenConn("bob")
|
||||
if err == nil {
|
||||
t.Errorf("unexpected opening connection to closed server")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloseByServer(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
addr1 := "localhost:1234"
|
||||
srv1 := server.NewServer()
|
||||
go func() {
|
||||
err := srv1.Listen(addr1)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
idAlice := "alice"
|
||||
log.Debugf("connect by alice")
|
||||
relayClient := NewClient(ctx, addr1, idAlice)
|
||||
err := relayClient.Connect()
|
||||
if err != nil {
|
||||
log.Fatalf("failed to connect to server: %s", err)
|
||||
}
|
||||
|
||||
disconnected := make(chan struct{})
|
||||
relayClient.SetOnDisconnectListener(func() {
|
||||
log.Infof("client disconnected")
|
||||
close(disconnected)
|
||||
})
|
||||
|
||||
err = srv1.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to close server: %s", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-disconnected:
|
||||
case <-time.After(3 * time.Second):
|
||||
log.Fatalf("timeout waiting for client to disconnect")
|
||||
}
|
||||
|
||||
_, err = relayClient.OpenConn("bob")
|
||||
if err == nil {
|
||||
t.Errorf("unexpected opening connection to closed server")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloseByClient(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
addr1 := "localhost:1234"
|
||||
srv := server.NewServer()
|
||||
go func() {
|
||||
err := srv.Listen(addr1)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
idAlice := "alice"
|
||||
log.Debugf("connect by alice")
|
||||
relayClient := NewClient(ctx, addr1, idAlice)
|
||||
err := relayClient.Connect()
|
||||
if err != nil {
|
||||
log.Fatalf("failed to connect to server: %s", err)
|
||||
}
|
||||
|
||||
err = relayClient.Close()
|
||||
if err != nil {
|
||||
t.Errorf("failed to close client: %s", err)
|
||||
}
|
||||
|
||||
_, err = relayClient.OpenConn("bob")
|
||||
if err == nil {
|
||||
t.Errorf("unexpected opening connection to closed server")
|
||||
}
|
||||
|
||||
err = srv.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to close server: %s", err)
|
||||
}
|
||||
}
|
||||
67
relay/client/conn.go
Normal file
67
relay/client/conn.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Conn struct {
|
||||
client *Client
|
||||
dstID []byte
|
||||
dstStringID string
|
||||
messageChan chan Msg
|
||||
}
|
||||
|
||||
func NewConn(client *Client, dstID []byte, dstStringID string, messageChan chan Msg) *Conn {
|
||||
c := &Conn{
|
||||
client: client,
|
||||
dstID: dstID,
|
||||
dstStringID: dstStringID,
|
||||
messageChan: messageChan,
|
||||
}
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *Conn) Write(p []byte) (n int, err error) {
|
||||
return c.client.writeTo(c.dstStringID, c.dstID, p)
|
||||
}
|
||||
|
||||
func (c *Conn) Read(b []byte) (n int, err error) {
|
||||
msg, ok := <-c.messageChan
|
||||
if !ok {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
n = copy(b, msg.Payload)
|
||||
msg.Free()
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (c *Conn) Close() error {
|
||||
return c.client.closeConn(c.dstStringID)
|
||||
}
|
||||
|
||||
func (c *Conn) LocalAddr() net.Addr {
|
||||
return c.client.relayConn.LocalAddr()
|
||||
}
|
||||
|
||||
func (c *Conn) RemoteAddr() net.Addr {
|
||||
return c.client.relayConn.RemoteAddr()
|
||||
}
|
||||
|
||||
func (c *Conn) SetDeadline(t time.Time) error {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (c *Conn) SetReadDeadline(t time.Time) error {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (c *Conn) SetWriteDeadline(t time.Time) error {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
52
relay/client/dialer/quic/conn.go
Normal file
52
relay/client/dialer/quic/conn.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type Conn struct {
|
||||
quic.Stream
|
||||
qConn quic.Connection
|
||||
}
|
||||
|
||||
func NewConn(stream quic.Stream, qConn quic.Connection) net.Conn {
|
||||
return &Conn{
|
||||
Stream: stream,
|
||||
qConn: qConn,
|
||||
}
|
||||
}
|
||||
|
||||
func (q *Conn) Write(b []byte) (n int, err error) {
|
||||
log.Debugf("writing: %d, %x\n", len(b), b)
|
||||
n, err = q.Stream.Write(b)
|
||||
if n != len(b) {
|
||||
log.Errorf("failed to write out the full message")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (q *Conn) Close() error {
|
||||
err := q.Stream.Close()
|
||||
if err != nil {
|
||||
log.Errorf("failed to close stream: %s", err)
|
||||
return err
|
||||
}
|
||||
err = q.qConn.CloseWithError(0, "")
|
||||
if err != nil {
|
||||
log.Errorf("failed to close connection: %s", err)
|
||||
return err
|
||||
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Conn) LocalAddr() net.Addr {
|
||||
return c.qConn.LocalAddr()
|
||||
}
|
||||
|
||||
func (c *Conn) RemoteAddr() net.Addr {
|
||||
return c.qConn.RemoteAddr()
|
||||
}
|
||||
32
relay/client/dialer/quic/quic.go
Normal file
32
relay/client/dialer/quic/quic.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func Dial(address string) (net.Conn, error) {
|
||||
tlsConf := &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
NextProtos: []string{"quic-echo-example"},
|
||||
}
|
||||
qConn, err := quic.DialAddr(context.Background(), address, tlsConf, &quic.Config{
|
||||
EnableDatagrams: true,
|
||||
})
|
||||
if err != nil {
|
||||
log.Errorf("dial quic address %s failed: %s", address, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
stream, err := qConn.OpenStreamSync(context.Background())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
conn := NewConn(stream, qConn)
|
||||
return conn, nil
|
||||
}
|
||||
7
relay/client/dialer/tcp/tcp.go
Normal file
7
relay/client/dialer/tcp/tcp.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package tcp
|
||||
|
||||
import "net"
|
||||
|
||||
func Dial(address string) (net.Conn, error) {
|
||||
return net.Dial("tcp", address)
|
||||
}
|
||||
14
relay/client/dialer/udp/udp.go
Normal file
14
relay/client/dialer/udp/udp.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package udp
|
||||
|
||||
import (
|
||||
"net"
|
||||
)
|
||||
|
||||
func Dial(address string) (net.Conn, error) {
|
||||
udpAddr, err := net.ResolveUDPAddr("udp", address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return net.DialUDP("udp", nil, udpAddr)
|
||||
}
|
||||
59
relay/client/dialer/ws/client_conn.go
Normal file
59
relay/client/dialer/ws/client_conn.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package ws
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
type Conn struct {
|
||||
*websocket.Conn
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func NewConn(wsConn *websocket.Conn) net.Conn {
|
||||
return &Conn{
|
||||
Conn: wsConn,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) Read(b []byte) (n int, err error) {
|
||||
t, r, err := c.NextReader()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if t != websocket.BinaryMessage {
|
||||
return 0, fmt.Errorf("unexpected message type")
|
||||
}
|
||||
|
||||
return r.Read(b)
|
||||
}
|
||||
|
||||
func (c *Conn) Write(b []byte) (int, error) {
|
||||
c.mu.Lock()
|
||||
err := c.WriteMessage(websocket.BinaryMessage, b)
|
||||
c.mu.Unlock()
|
||||
return len(b), err
|
||||
}
|
||||
|
||||
func (c *Conn) SetDeadline(t time.Time) error {
|
||||
errR := c.SetReadDeadline(t)
|
||||
errW := c.SetWriteDeadline(t)
|
||||
|
||||
if errR != nil {
|
||||
return errR
|
||||
}
|
||||
|
||||
if errW != nil {
|
||||
return errW
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) Close() error {
|
||||
return c.Conn.Close()
|
||||
}
|
||||
22
relay/client/dialer/ws/ws.go
Normal file
22
relay/client/dialer/ws/ws.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package ws
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
func Dial(address string) (net.Conn, error) {
|
||||
addr := fmt.Sprintf("ws://" + address)
|
||||
wsDialer := websocket.Dialer{
|
||||
HandshakeTimeout: 3 * time.Second,
|
||||
}
|
||||
wsConn, _, err := wsDialer.Dial(addr, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conn := NewConn(wsConn)
|
||||
return conn, nil
|
||||
}
|
||||
79
relay/client/dialer/wsnhooyr/client_conn.go
Normal file
79
relay/client/dialer/wsnhooyr/client_conn.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package wsnhooyr
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"nhooyr.io/websocket"
|
||||
)
|
||||
|
||||
type Conn struct {
|
||||
*websocket.Conn
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func NewConn(wsConn *websocket.Conn) net.Conn {
|
||||
return &Conn{
|
||||
Conn: wsConn,
|
||||
ctx: context.Background(),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) Read(b []byte) (n int, err error) {
|
||||
t, ioReader, err := c.Conn.Reader(c.ctx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if t != websocket.MessageBinary {
|
||||
return 0, fmt.Errorf("unexpected message type")
|
||||
}
|
||||
|
||||
return ioReader.Read(b)
|
||||
}
|
||||
|
||||
func (c *Conn) Write(b []byte) (n int, err error) {
|
||||
err = c.Conn.Write(c.ctx, websocket.MessageBinary, b)
|
||||
return len(b), err
|
||||
}
|
||||
|
||||
func (c *Conn) RemoteAddr() net.Addr {
|
||||
// todo: implement me
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) LocalAddr() net.Addr {
|
||||
// todo: implement me
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) SetReadDeadline(t time.Time) error {
|
||||
// todo: implement me
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) SetWriteDeadline(t time.Time) error {
|
||||
// todo: implement me
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) SetDeadline(t time.Time) error {
|
||||
// todo: implement me
|
||||
errR := c.SetReadDeadline(t)
|
||||
errW := c.SetWriteDeadline(t)
|
||||
|
||||
if errR != nil {
|
||||
return errR
|
||||
}
|
||||
|
||||
if errW != nil {
|
||||
return errW
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) Close() error {
|
||||
return c.Conn.CloseNow()
|
||||
}
|
||||
22
relay/client/dialer/wsnhooyr/ws.go
Normal file
22
relay/client/dialer/wsnhooyr/ws.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package wsnhooyr
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"nhooyr.io/websocket"
|
||||
)
|
||||
|
||||
func Dial(address string) (net.Conn, error) {
|
||||
|
||||
addr := fmt.Sprintf("ws://" + address)
|
||||
wsConn, _, err := websocket.Dial(context.Background(), addr, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
conn := NewConn(wsConn)
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
33
relay/client/guard.go
Normal file
33
relay/client/guard.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
reconnectingTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
type Guard struct {
|
||||
ctx context.Context
|
||||
relayClient *Client
|
||||
}
|
||||
|
||||
func NewGuard(context context.Context, relayClient *Client) *Guard {
|
||||
g := &Guard{
|
||||
ctx: context,
|
||||
relayClient: relayClient,
|
||||
}
|
||||
|
||||
return g
|
||||
}
|
||||
|
||||
func (g *Guard) OnDisconnected() {
|
||||
select {
|
||||
case <-time.After(time.Second):
|
||||
_ = g.relayClient.Connect()
|
||||
case <-g.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
208
relay/client/manager.go
Normal file
208
relay/client/manager.go
Normal file
@@ -0,0 +1,208 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var (
|
||||
relayCleanupInterval = 60 * time.Second
|
||||
)
|
||||
|
||||
// RelayTrack hold the relay clients for the foregin relay servers.
|
||||
// With the mutex can ensure we can open new connection in case the relay connection has been established with
|
||||
// the relay server.
|
||||
type RelayTrack struct {
|
||||
sync.RWMutex
|
||||
relayClient *Client
|
||||
}
|
||||
|
||||
func NewRelayTrack() *RelayTrack {
|
||||
return &RelayTrack{}
|
||||
}
|
||||
|
||||
// Manager is a manager for the relay client. It establish one persistent connection to the given relay server. In case
|
||||
// of network error the manager will try to reconnect to the server.
|
||||
// The manager also manage temproary relay connection. If a client wants to communicate with an another client on a
|
||||
// different relay server, the manager will establish a new connection to the relay server. The connection with these
|
||||
// relay servers will be closed if there is no active connection. Periodically the manager will check if there is any
|
||||
// unused relay connection and close it.
|
||||
type Manager struct {
|
||||
ctx context.Context
|
||||
srvAddress string
|
||||
peerID string
|
||||
|
||||
relayClient *Client
|
||||
reconnectGuard *Guard
|
||||
|
||||
relayClients map[string]*RelayTrack
|
||||
relayClientsMutex sync.RWMutex
|
||||
}
|
||||
|
||||
func NewManager(ctx context.Context, serverAddress string, peerID string) *Manager {
|
||||
return &Manager{
|
||||
ctx: ctx,
|
||||
srvAddress: serverAddress,
|
||||
peerID: peerID,
|
||||
relayClients: make(map[string]*RelayTrack),
|
||||
}
|
||||
}
|
||||
|
||||
// Serve starts the manager. It will establish a connection to the relay server and start the relay cleanup loop.
|
||||
// todo: consider to return an error if the initial connection to the relay server is not established.
|
||||
func (m *Manager) Serve() {
|
||||
m.relayClient = NewClient(m.ctx, m.srvAddress, m.peerID)
|
||||
m.reconnectGuard = NewGuard(m.ctx, m.relayClient)
|
||||
m.relayClient.SetOnDisconnectListener(m.reconnectGuard.OnDisconnected)
|
||||
err := m.relayClient.Connect()
|
||||
if err != nil {
|
||||
log.Errorf("failed to connect to relay server, keep try to reconnect: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
m.startCleanupLoop()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// OpenConn opens a connection to the given peer key. If the peer is on the same relay server, the connection will be
|
||||
// established via the relay server. If the peer is on a different relay server, the manager will establish a new
|
||||
// connection to the relay server.
|
||||
func (m *Manager) OpenConn(serverAddress, peerKey string) (net.Conn, error) {
|
||||
if m.relayClient == nil {
|
||||
return nil, fmt.Errorf("relay client not connected")
|
||||
}
|
||||
|
||||
foreign, err := m.isForeignServer(serverAddress)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !foreign {
|
||||
log.Debugf("open connection to permanent server: %s", peerKey)
|
||||
return m.relayClient.OpenConn(peerKey)
|
||||
} else {
|
||||
log.Debugf("open connection to foreign server: %s", serverAddress)
|
||||
return m.openConnVia(serverAddress, peerKey)
|
||||
}
|
||||
}
|
||||
|
||||
// RelayAddress returns the address of the permanent relay server. It could change if the network connection is lost.
|
||||
// This address will be sent to the target peer to choose the common relay server for the communication.
|
||||
func (m *Manager) RelayAddress() (net.Addr, error) {
|
||||
if m.relayClient == nil {
|
||||
return nil, fmt.Errorf("relay client not connected")
|
||||
}
|
||||
return m.relayClient.RelayRemoteAddress()
|
||||
}
|
||||
|
||||
func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) {
|
||||
// check if already has a connection to the desired relay server
|
||||
m.relayClientsMutex.RLock()
|
||||
rt, ok := m.relayClients[serverAddress]
|
||||
if ok {
|
||||
rt.RLock()
|
||||
m.relayClientsMutex.RUnlock()
|
||||
defer rt.RUnlock()
|
||||
return rt.relayClient.OpenConn(peerKey)
|
||||
}
|
||||
m.relayClientsMutex.RUnlock()
|
||||
|
||||
// if not, establish a new connection but check it again (because changed the lock type) before starting the
|
||||
// connection
|
||||
m.relayClientsMutex.Lock()
|
||||
rt, ok = m.relayClients[serverAddress]
|
||||
if ok {
|
||||
rt.RLock()
|
||||
m.relayClientsMutex.Unlock()
|
||||
defer rt.RUnlock()
|
||||
return rt.relayClient.OpenConn(peerKey)
|
||||
}
|
||||
|
||||
// create a new relay client and store it in the relayClients map
|
||||
rt = NewRelayTrack()
|
||||
rt.Lock()
|
||||
m.relayClients[serverAddress] = rt
|
||||
m.relayClientsMutex.Unlock()
|
||||
|
||||
relayClient := NewClient(m.ctx, serverAddress, m.peerID)
|
||||
err := relayClient.Connect()
|
||||
if err != nil {
|
||||
rt.Unlock()
|
||||
m.relayClientsMutex.Lock()
|
||||
delete(m.relayClients, serverAddress)
|
||||
m.relayClientsMutex.Unlock()
|
||||
return nil, err
|
||||
}
|
||||
// if connection closed then delete the relay client from the list
|
||||
relayClient.SetOnDisconnectListener(func() {
|
||||
m.deleteRelayConn(serverAddress)
|
||||
})
|
||||
rt.relayClient = relayClient
|
||||
rt.Unlock()
|
||||
|
||||
conn, err := relayClient.OpenConn(peerKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (m *Manager) deleteRelayConn(address string) {
|
||||
log.Infof("deleting relay client for %s", address)
|
||||
m.relayClientsMutex.Lock()
|
||||
delete(m.relayClients, address)
|
||||
m.relayClientsMutex.Unlock()
|
||||
}
|
||||
|
||||
func (m *Manager) isForeignServer(address string) (bool, error) {
|
||||
rAddr, err := m.relayClient.RelayRemoteAddress()
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("relay client not connected")
|
||||
}
|
||||
return rAddr.String() != address, nil
|
||||
}
|
||||
|
||||
func (m *Manager) startCleanupLoop() {
|
||||
if m.ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(relayCleanupInterval)
|
||||
go func() {
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-m.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
m.cleanUpUnusedRelays()
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (m *Manager) cleanUpUnusedRelays() {
|
||||
m.relayClientsMutex.Lock()
|
||||
defer m.relayClientsMutex.Unlock()
|
||||
|
||||
for addr, rt := range m.relayClients {
|
||||
rt.Lock()
|
||||
if rt.relayClient.HasConns() {
|
||||
rt.Unlock()
|
||||
continue
|
||||
}
|
||||
rt.relayClient.SetOnDisconnectListener(nil)
|
||||
go func() {
|
||||
_ = rt.relayClient.Close()
|
||||
}()
|
||||
log.Debugf("clean up relay client: %s", addr)
|
||||
delete(m.relayClients, addr)
|
||||
rt.Unlock()
|
||||
}
|
||||
}
|
||||
271
relay/client/manager_test.go
Normal file
271
relay/client/manager_test.go
Normal file
@@ -0,0 +1,271 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/relay/server"
|
||||
)
|
||||
|
||||
func TestForeignConn(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
addr1 := "localhost:1234"
|
||||
srv1 := server.NewServer()
|
||||
go func() {
|
||||
err := srv1.Listen(addr1)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
err := srv1.Close()
|
||||
if err != nil {
|
||||
t.Errorf("failed to close server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
addr2 := "localhost:2234"
|
||||
srv2 := server.NewServer()
|
||||
go func() {
|
||||
err := srv2.Listen(addr2)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
err := srv2.Close()
|
||||
if err != nil {
|
||||
t.Errorf("failed to close server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
idAlice := "alice"
|
||||
log.Debugf("connect by alice")
|
||||
mCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
clientAlice := NewManager(mCtx, addr1, idAlice)
|
||||
clientAlice.Serve()
|
||||
|
||||
idBob := "bob"
|
||||
log.Debugf("connect by bob")
|
||||
clientBob := NewManager(mCtx, addr2, idBob)
|
||||
clientBob.Serve()
|
||||
|
||||
bobsSrvAddr, err := clientBob.RelayAddress()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get relay address: %s", err)
|
||||
}
|
||||
connAliceToBob, err := clientAlice.OpenConn(bobsSrvAddr.String(), idBob)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind channel: %s", err)
|
||||
}
|
||||
connBobToAlice, err := clientBob.OpenConn(bobsSrvAddr.String(), idAlice)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind channel: %s", err)
|
||||
}
|
||||
|
||||
payload := "hello bob, I am alice"
|
||||
_, err = connAliceToBob.Write([]byte(payload))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to write to channel: %s", err)
|
||||
}
|
||||
|
||||
buf := make([]byte, 65535)
|
||||
n, err := connBobToAlice.Read(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read from channel: %s", err)
|
||||
}
|
||||
|
||||
_, err = connBobToAlice.Write(buf[:n])
|
||||
if err != nil {
|
||||
t.Fatalf("failed to write to channel: %s", err)
|
||||
}
|
||||
|
||||
n, err = connAliceToBob.Read(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read from channel: %s", err)
|
||||
}
|
||||
|
||||
if payload != string(buf[:n]) {
|
||||
t.Fatalf("expected %s, got %s", payload, string(buf[:n]))
|
||||
}
|
||||
}
|
||||
|
||||
func TestForeginConnClose(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
addr1 := "localhost:1234"
|
||||
srv1 := server.NewServer()
|
||||
go func() {
|
||||
err := srv1.Listen(addr1)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
err := srv1.Close()
|
||||
if err != nil {
|
||||
t.Errorf("failed to close server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
addr2 := "localhost:2234"
|
||||
srv2 := server.NewServer()
|
||||
go func() {
|
||||
err := srv2.Listen(addr2)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
err := srv2.Close()
|
||||
if err != nil {
|
||||
t.Errorf("failed to close server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
idAlice := "alice"
|
||||
log.Debugf("connect by alice")
|
||||
mCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
mgr := NewManager(mCtx, addr1, idAlice)
|
||||
mgr.Serve()
|
||||
|
||||
conn, err := mgr.OpenConn(addr2, "anotherpeer")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind channel: %s", err)
|
||||
}
|
||||
|
||||
err = conn.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to close connection: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestForeginAutoClose(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
relayCleanupInterval = 1 * time.Second
|
||||
addr1 := "localhost:1234"
|
||||
srv1 := server.NewServer()
|
||||
go func() {
|
||||
t.Log("binding server 1.")
|
||||
err := srv1.Listen(addr1)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
t.Logf("closing server 1.")
|
||||
err := srv1.Close()
|
||||
if err != nil {
|
||||
t.Errorf("failed to close server: %s", err)
|
||||
}
|
||||
t.Logf("server 1. closed")
|
||||
}()
|
||||
|
||||
addr2 := "localhost:2234"
|
||||
srv2 := server.NewServer()
|
||||
go func() {
|
||||
t.Log("binding server 2.")
|
||||
err := srv2.Listen(addr2)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind server: %s", err)
|
||||
}
|
||||
}()
|
||||
defer func() {
|
||||
t.Logf("closing server 2.")
|
||||
err := srv2.Close()
|
||||
if err != nil {
|
||||
t.Errorf("failed to close server: %s", err)
|
||||
}
|
||||
t.Logf("server 2 closed.")
|
||||
}()
|
||||
|
||||
idAlice := "alice"
|
||||
t.Log("connect to server 1.")
|
||||
mCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
mgr := NewManager(mCtx, addr1, idAlice)
|
||||
mgr.Serve()
|
||||
|
||||
t.Log("open connection to another peer")
|
||||
conn, err := mgr.OpenConn(addr2, "anotherpeer")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind channel: %s", err)
|
||||
}
|
||||
|
||||
t.Log("close conn")
|
||||
err = conn.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to close connection: %s", err)
|
||||
}
|
||||
|
||||
t.Logf("waiting for relay cleanup: %s", relayCleanupInterval+1*time.Second)
|
||||
time.Sleep(relayCleanupInterval + 1*time.Second)
|
||||
if len(mgr.relayClients) != 0 {
|
||||
t.Errorf("expected 0, got %d", len(mgr.relayClients))
|
||||
}
|
||||
|
||||
t.Logf("closing manager")
|
||||
}
|
||||
|
||||
func TestAutoReconnect(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
reconnectingTimeout = 2 * time.Second
|
||||
|
||||
addr := "localhost:1234"
|
||||
srv := server.NewServer()
|
||||
go func() {
|
||||
err := srv.Listen(addr)
|
||||
if err != nil {
|
||||
t.Errorf("failed to bind server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
err := srv.Close()
|
||||
if err != nil {
|
||||
log.Errorf("failed to close server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
mCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
clientAlice := NewManager(mCtx, addr, "alice")
|
||||
clientAlice.Serve()
|
||||
ra, err := clientAlice.RelayAddress()
|
||||
if err != nil {
|
||||
t.Errorf("failed to get relay address: %s", err)
|
||||
}
|
||||
conn, err := clientAlice.OpenConn(ra.String(), "bob")
|
||||
if err != nil {
|
||||
t.Errorf("failed to bind channel: %s", err)
|
||||
}
|
||||
|
||||
t.Log("closing client relay connection")
|
||||
// todo figure out moc server
|
||||
_ = clientAlice.relayClient.relayConn.Close()
|
||||
t.Log("start test reading")
|
||||
_, err = conn.Read(make([]byte, 1))
|
||||
if err == nil {
|
||||
t.Errorf("unexpected reading from closed connection")
|
||||
}
|
||||
|
||||
log.Infof("waiting for reconnection")
|
||||
time.Sleep(reconnectingTimeout + 1*time.Second)
|
||||
|
||||
log.Infof("reopent the connection")
|
||||
_, err = clientAlice.OpenConn(ra.String(), "bob")
|
||||
if err != nil {
|
||||
t.Errorf("failed to open channel: %s", err)
|
||||
}
|
||||
}
|
||||
40
relay/cmd/main.go
Normal file
40
relay/cmd/main.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/relay/server"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
func init() {
|
||||
util.InitLog("trace", "console")
|
||||
}
|
||||
|
||||
func waitForExitSignal() {
|
||||
osSigs := make(chan os.Signal, 1)
|
||||
signal.Notify(osSigs, syscall.SIGINT, syscall.SIGTERM)
|
||||
_ = <-osSigs
|
||||
}
|
||||
|
||||
func main() {
|
||||
address := "10.145.236.1:1235"
|
||||
srv := server.NewServer()
|
||||
err := srv.Listen(address)
|
||||
if err != nil {
|
||||
log.Errorf("failed to bind server: %s", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
waitForExitSignal()
|
||||
|
||||
err = srv.Close()
|
||||
if err != nil {
|
||||
log.Errorf("failed to close server: %s", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
20
relay/messages/id.go
Normal file
20
relay/messages/id.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package messages
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
)
|
||||
|
||||
const (
|
||||
IDSize = sha256.Size
|
||||
)
|
||||
|
||||
func HashID(peerID string) ([]byte, string) {
|
||||
idHash := sha256.Sum256([]byte(peerID))
|
||||
idHashString := base64.StdEncoding.EncodeToString(idHash[:])
|
||||
return idHash[:], idHashString
|
||||
}
|
||||
|
||||
func HashIDToString(idHash []byte) string {
|
||||
return base64.StdEncoding.EncodeToString(idHash[:])
|
||||
}
|
||||
137
relay/messages/message.go
Normal file
137
relay/messages/message.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package messages
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
MsgTypeHello MsgType = 0
|
||||
MsgTypeHelloResponse MsgType = 1
|
||||
MsgTypeTransport MsgType = 2
|
||||
MsgClose MsgType = 3
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidMessageLength = fmt.Errorf("invalid message length")
|
||||
)
|
||||
|
||||
type MsgType byte
|
||||
|
||||
func (m MsgType) String() string {
|
||||
switch m {
|
||||
case MsgTypeHello:
|
||||
return "hello"
|
||||
case MsgTypeHelloResponse:
|
||||
return "hello response"
|
||||
case MsgTypeTransport:
|
||||
return "transport"
|
||||
case MsgClose:
|
||||
return "close"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
func DetermineClientMsgType(msg []byte) (MsgType, error) {
|
||||
// todo: validate magic byte
|
||||
msgType := MsgType(msg[0])
|
||||
switch msgType {
|
||||
case MsgTypeHello:
|
||||
return msgType, nil
|
||||
case MsgTypeTransport:
|
||||
return msgType, nil
|
||||
case MsgClose:
|
||||
return msgType, nil
|
||||
default:
|
||||
return 0, fmt.Errorf("invalid msg type, len: %d", len(msg))
|
||||
}
|
||||
}
|
||||
|
||||
func DetermineServerMsgType(msg []byte) (MsgType, error) {
|
||||
// todo: validate magic byte
|
||||
msgType := MsgType(msg[0])
|
||||
switch msgType {
|
||||
case MsgTypeHelloResponse:
|
||||
return msgType, nil
|
||||
case MsgTypeTransport:
|
||||
return msgType, nil
|
||||
case MsgClose:
|
||||
return msgType, nil
|
||||
default:
|
||||
return 0, fmt.Errorf("invalid msg type (len: %d)", len(msg))
|
||||
}
|
||||
}
|
||||
|
||||
// MarshalHelloMsg initial hello message
|
||||
func MarshalHelloMsg(peerID []byte) ([]byte, error) {
|
||||
if len(peerID) != IDSize {
|
||||
return nil, fmt.Errorf("invalid peerID length")
|
||||
}
|
||||
msg := make([]byte, 1, 1+len(peerID))
|
||||
msg[0] = byte(MsgTypeHello)
|
||||
msg = append(msg, peerID...)
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func UnmarshalHelloMsg(msg []byte) ([]byte, error) {
|
||||
if len(msg) < 2 {
|
||||
return nil, fmt.Errorf("invalid 'hello' messge")
|
||||
}
|
||||
return msg[1:], nil
|
||||
}
|
||||
|
||||
func MarshalHelloResponse() []byte {
|
||||
msg := make([]byte, 1)
|
||||
msg[0] = byte(MsgTypeHelloResponse)
|
||||
return msg
|
||||
}
|
||||
|
||||
// Close message
|
||||
|
||||
func MarshalCloseMsg() []byte {
|
||||
msg := make([]byte, 1)
|
||||
msg[0] = byte(MsgClose)
|
||||
return msg
|
||||
}
|
||||
|
||||
// Transport message
|
||||
|
||||
func MarshalTransportMsg(peerID []byte, payload []byte) []byte {
|
||||
if len(peerID) != IDSize {
|
||||
return nil
|
||||
}
|
||||
|
||||
msg := make([]byte, 1+IDSize, 1+IDSize+len(payload))
|
||||
msg[0] = byte(MsgTypeTransport)
|
||||
copy(msg[1:], peerID)
|
||||
msg = append(msg, payload...)
|
||||
return msg
|
||||
}
|
||||
|
||||
func UnmarshalTransportMsg(buf []byte) ([]byte, []byte, error) {
|
||||
headerSize := 1 + IDSize
|
||||
if len(buf) < headerSize {
|
||||
return nil, nil, ErrInvalidMessageLength
|
||||
}
|
||||
|
||||
return buf[1:headerSize], buf[headerSize:], nil
|
||||
}
|
||||
|
||||
func UnmarshalTransportID(buf []byte) ([]byte, error) {
|
||||
headerSize := 1 + IDSize
|
||||
if len(buf) < headerSize {
|
||||
log.Debugf("invalid message length: %d, expected: %d, %x", len(buf), headerSize, buf)
|
||||
return nil, ErrInvalidMessageLength
|
||||
}
|
||||
return buf[1:headerSize], nil
|
||||
}
|
||||
|
||||
func UpdateTransportMsg(msg []byte, peerID []byte) error {
|
||||
if len(msg) < 1+len(peerID) {
|
||||
return ErrInvalidMessageLength
|
||||
}
|
||||
copy(msg[1:], peerID)
|
||||
return nil
|
||||
}
|
||||
9
relay/server/listener/listener.go
Normal file
9
relay/server/listener/listener.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package listener
|
||||
|
||||
import "net"
|
||||
|
||||
type Listener interface {
|
||||
Listen(func(conn net.Conn)) error
|
||||
Close() error
|
||||
WaitForExitAcceptedConns()
|
||||
}
|
||||
36
relay/server/listener/quic/conn.go
Normal file
36
relay/server/listener/quic/conn.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type QuicConn struct {
|
||||
quic.Stream
|
||||
qConn quic.Connection
|
||||
}
|
||||
|
||||
func NewConn(stream quic.Stream, qConn quic.Connection) net.Conn {
|
||||
return &QuicConn{
|
||||
Stream: stream,
|
||||
qConn: qConn,
|
||||
}
|
||||
}
|
||||
|
||||
func (q QuicConn) Write(b []byte) (n int, err error) {
|
||||
n, err = q.Stream.Write(b)
|
||||
if n != len(b) {
|
||||
log.Errorf("failed to write out the full message")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (q QuicConn) LocalAddr() net.Addr {
|
||||
return q.qConn.LocalAddr()
|
||||
}
|
||||
|
||||
func (q QuicConn) RemoteAddr() net.Addr {
|
||||
return q.qConn.RemoteAddr()
|
||||
}
|
||||
111
relay/server/listener/quic/listener.go
Normal file
111
relay/server/listener/quic/listener.go
Normal file
@@ -0,0 +1,111 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"math/big"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/relay/server/listener"
|
||||
)
|
||||
|
||||
type Listener struct {
|
||||
address string
|
||||
onAcceptFn func(conn net.Conn)
|
||||
|
||||
listener *quic.Listener
|
||||
quit chan struct{}
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
func NewListener(address string) listener.Listener {
|
||||
return &Listener{
|
||||
address: address,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Listener) Listen(onAcceptFn func(conn net.Conn)) error {
|
||||
ql, err := quic.ListenAddr(l.address, generateTLSConfig(), &quic.Config{
|
||||
EnableDatagrams: true,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
l.listener = ql
|
||||
l.quit = make(chan struct{})
|
||||
|
||||
log.Infof("quic server is listening on address: %s", l.address)
|
||||
l.wg.Add(1)
|
||||
go l.acceptLoop(onAcceptFn)
|
||||
|
||||
<-l.quit
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *Listener) Close() error {
|
||||
close(l.quit)
|
||||
err := l.listener.Close()
|
||||
l.wg.Wait()
|
||||
return err
|
||||
}
|
||||
|
||||
func (l *Listener) acceptLoop(acceptFn func(conn net.Conn)) {
|
||||
defer l.wg.Done()
|
||||
|
||||
for {
|
||||
qConn, err := l.listener.Accept(context.Background())
|
||||
if err != nil {
|
||||
select {
|
||||
case <-l.quit:
|
||||
return
|
||||
default:
|
||||
log.Errorf("failed to accept connection: %s", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
log.Infof("new connection from: %s", qConn.RemoteAddr())
|
||||
|
||||
stream, err := qConn.AcceptStream(context.Background())
|
||||
if err != nil {
|
||||
log.Errorf("failed to open stream: %s", err)
|
||||
continue
|
||||
}
|
||||
|
||||
conn := NewConn(stream, qConn)
|
||||
|
||||
go acceptFn(conn)
|
||||
}
|
||||
}
|
||||
|
||||
// Setup a bare-bones TLS config for the server
|
||||
func generateTLSConfig() *tls.Config {
|
||||
key, err := rsa.GenerateKey(rand.Reader, 1024)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
template := x509.Certificate{SerialNumber: big.NewInt(1)}
|
||||
certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)})
|
||||
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
|
||||
|
||||
tlsCert, err := tls.X509KeyPair(certPEM, keyPEM)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return &tls.Config{
|
||||
Certificates: []tls.Certificate{tlsCert},
|
||||
NextProtos: []string{"quic-echo-example"},
|
||||
}
|
||||
}
|
||||
80
relay/server/listener/tcp/listener.go
Normal file
80
relay/server/listener/tcp/listener.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/relay/server/listener"
|
||||
)
|
||||
|
||||
// Listener
|
||||
// Is it just demo code. It does not work in real life environment because the TCP is a streaming protocol, adn
|
||||
// it does not handle framing.
|
||||
type Listener struct {
|
||||
address string
|
||||
|
||||
onAcceptFn func(conn net.Conn)
|
||||
wg sync.WaitGroup
|
||||
quit chan struct{}
|
||||
listener net.Listener
|
||||
lock sync.Mutex
|
||||
}
|
||||
|
||||
func NewListener(address string) listener.Listener {
|
||||
return &Listener{
|
||||
address: address,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Listener) Listen(onAcceptFn func(conn net.Conn)) error {
|
||||
l.lock.Lock()
|
||||
|
||||
l.onAcceptFn = onAcceptFn
|
||||
l.quit = make(chan struct{})
|
||||
|
||||
li, err := net.Listen("tcp", l.address)
|
||||
if err != nil {
|
||||
log.Errorf("failed to listen on address: %s, %s", l.address, err)
|
||||
l.lock.Unlock()
|
||||
return err
|
||||
}
|
||||
log.Debugf("TCP server is listening on address: %s", l.address)
|
||||
l.listener = li
|
||||
l.wg.Add(1)
|
||||
go l.acceptLoop()
|
||||
|
||||
l.lock.Unlock()
|
||||
<-l.quit
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close todo: prevent multiple call (do not close two times the channel)
|
||||
func (l *Listener) Close() error {
|
||||
l.lock.Lock()
|
||||
defer l.lock.Unlock()
|
||||
|
||||
close(l.quit)
|
||||
err := l.listener.Close()
|
||||
l.wg.Wait()
|
||||
return err
|
||||
}
|
||||
|
||||
func (l *Listener) acceptLoop() {
|
||||
defer l.wg.Done()
|
||||
|
||||
for {
|
||||
conn, err := l.listener.Accept()
|
||||
if err != nil {
|
||||
select {
|
||||
case <-l.quit:
|
||||
return
|
||||
default:
|
||||
log.Errorf("failed to accept connection: %s", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
go l.onAcceptFn(conn)
|
||||
}
|
||||
}
|
||||
68
relay/server/listener/udp/conn.go
Normal file
68
relay/server/listener/udp/conn.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package udp
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
type UDPConn struct {
|
||||
*net.UDPConn
|
||||
addr *net.UDPAddr
|
||||
msgChannel chan []byte
|
||||
}
|
||||
|
||||
func NewConn(conn *net.UDPConn, addr *net.UDPAddr) *UDPConn {
|
||||
return &UDPConn{
|
||||
UDPConn: conn,
|
||||
addr: addr,
|
||||
msgChannel: make(chan []byte),
|
||||
}
|
||||
}
|
||||
|
||||
func (u *UDPConn) Read(b []byte) (n int, err error) {
|
||||
msg, ok := <-u.msgChannel
|
||||
if !ok {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
n = copy(b, msg)
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (u *UDPConn) Write(b []byte) (n int, err error) {
|
||||
return u.UDPConn.WriteTo(b, u.addr)
|
||||
}
|
||||
|
||||
func (u *UDPConn) Close() error {
|
||||
//TODO implement me
|
||||
//panic("implement me")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *UDPConn) LocalAddr() net.Addr {
|
||||
return u.UDPConn.LocalAddr()
|
||||
}
|
||||
|
||||
func (u *UDPConn) RemoteAddr() net.Addr {
|
||||
return u.addr
|
||||
}
|
||||
|
||||
func (u *UDPConn) SetDeadline(t time.Time) error {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (u *UDPConn) SetReadDeadline(t time.Time) error {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (u *UDPConn) SetWriteDeadline(t time.Time) error {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (u *UDPConn) onNewMsg(b []byte) {
|
||||
u.msgChannel <- b
|
||||
}
|
||||
109
relay/server/listener/udp/listener.go
Normal file
109
relay/server/listener/udp/listener.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package udp
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/relay/server/listener"
|
||||
)
|
||||
|
||||
type Listener struct {
|
||||
address string
|
||||
conns map[string]*UDPConn
|
||||
onAcceptFn func(conn net.Conn)
|
||||
|
||||
listener *net.UDPConn
|
||||
|
||||
wg sync.WaitGroup
|
||||
quit chan struct{}
|
||||
lock sync.Mutex
|
||||
}
|
||||
|
||||
func (l *Listener) WaitForExitAcceptedConns() {
|
||||
l.wg.Wait()
|
||||
return
|
||||
}
|
||||
|
||||
func NewListener(address string) listener.Listener {
|
||||
return &Listener{
|
||||
address: address,
|
||||
conns: make(map[string]*UDPConn),
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Listener) Listen(onAcceptFn func(conn net.Conn)) error {
|
||||
l.lock.Lock()
|
||||
|
||||
l.onAcceptFn = onAcceptFn
|
||||
l.quit = make(chan struct{})
|
||||
|
||||
addr, err := net.ResolveUDPAddr("udp", l.address)
|
||||
if err != nil {
|
||||
log.Errorf("invalid listen address '%s': %s", l.address, err)
|
||||
l.lock.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
li, err := net.ListenUDP("udp", addr)
|
||||
if err != nil {
|
||||
log.Fatalf("%s", err)
|
||||
l.lock.Unlock()
|
||||
return err
|
||||
}
|
||||
log.Debugf("udp server is listening on address: %s", addr.String())
|
||||
l.listener = li
|
||||
l.wg.Add(1)
|
||||
go l.readLoop()
|
||||
|
||||
l.lock.Unlock()
|
||||
<-l.quit
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *Listener) Close() error {
|
||||
l.lock.Lock()
|
||||
defer l.lock.Unlock()
|
||||
|
||||
if l.listener == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Infof("closing UDP listener")
|
||||
close(l.quit)
|
||||
err := l.listener.Close()
|
||||
l.wg.Wait()
|
||||
l.listener = nil
|
||||
return err
|
||||
}
|
||||
|
||||
func (l *Listener) readLoop() {
|
||||
defer l.wg.Done()
|
||||
|
||||
for {
|
||||
buf := make([]byte, 1500)
|
||||
n, addr, err := l.listener.ReadFromUDP(buf)
|
||||
if err != nil {
|
||||
select {
|
||||
case <-l.quit:
|
||||
return
|
||||
default:
|
||||
log.Errorf("failed to accept connection: %s", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
pConn, ok := l.conns[addr.String()]
|
||||
if ok {
|
||||
pConn.onNewMsg(buf[:n])
|
||||
continue
|
||||
}
|
||||
|
||||
pConn = NewConn(l.listener, addr)
|
||||
log.Infof("new connection from: %s", pConn.RemoteAddr())
|
||||
l.conns[addr.String()] = pConn
|
||||
go l.onAcceptFn(pConn)
|
||||
pConn.onNewMsg(buf[:n])
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user