mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-05 08:54:11 -04:00
Compare commits
87 Commits
feat-clien
...
fix-rollba
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2ecba370a3 | ||
|
|
251f2d7bc2 | ||
|
|
306e02d32b | ||
|
|
8375491708 | ||
|
|
e197b89ac3 | ||
|
|
6aba28ccb7 | ||
|
|
8f9826b207 | ||
|
|
0aad9169e9 | ||
|
|
1057cd211d | ||
|
|
32b345991a | ||
|
|
e903522f8c | ||
|
|
ea88ec6d27 | ||
|
|
2be1a82f4a | ||
|
|
fe1ea4a2d0 | ||
|
|
f14f34cf2b | ||
|
|
109481e26d | ||
|
|
18098e7a7d | ||
|
|
5993982cca | ||
|
|
86f9051a30 | ||
|
|
489892553a | ||
|
|
b05e30ac5a | ||
|
|
769388cd21 | ||
|
|
c54fb9643c | ||
|
|
5dc0ff42a5 | ||
|
|
45badd2c39 | ||
|
|
d3de035961 | ||
|
|
b2da0ae70f | ||
|
|
931c20c8fe | ||
|
|
2eaf4aa8d7 | ||
|
|
110067c00f | ||
|
|
32c96c15b8 | ||
|
|
ca1dc5ac88 | ||
|
|
ce775d59ae | ||
|
|
f273fe9f51 | ||
|
|
e08af7fcdf | ||
|
|
454240ca05 | ||
|
|
1343a3f00e | ||
|
|
2a79995706 | ||
|
|
e869882da1 | ||
|
|
6c8bb60632 | ||
|
|
4d7029d80c | ||
|
|
909f305728 | ||
|
|
5e2f66d591 | ||
|
|
a7519859bc | ||
|
|
9b000b89d5 | ||
|
|
5c1acdbf2f | ||
|
|
db3a9f0aa2 | ||
|
|
ecc4f8a10d | ||
|
|
03abdfa112 | ||
|
|
9746a7f61a | ||
|
|
4ec6d5d20b | ||
|
|
3bab745142 | ||
|
|
0ca3d27a80 | ||
|
|
c5942e6b33 | ||
|
|
726ffb5740 | ||
|
|
dfb7960cd4 | ||
|
|
ab0cf1b8aa | ||
|
|
8ebd6ce963 | ||
|
|
42ba0765c8 | ||
|
|
514403db37 | ||
|
|
488d338ce8 | ||
|
|
6a75ec4ab7 | ||
|
|
b66e984ddd | ||
|
|
c65a934107 | ||
|
|
55ebf93815 | ||
|
|
9e74f30d2f | ||
|
|
71d24e59e6 | ||
|
|
992cfe64e1 | ||
|
|
d1703479ff | ||
|
|
a27fe4326c | ||
|
|
e6292e3124 | ||
|
|
628b497e81 | ||
|
|
8f66dea11c | ||
|
|
de8608f99f | ||
|
|
9c5adfea2b | ||
|
|
8e4710763e | ||
|
|
82af60838e | ||
|
|
311b67fe5a | ||
|
|
94d39ab48c | ||
|
|
41a47be379 | ||
|
|
e30def175b | ||
|
|
e1ef091d45 | ||
|
|
511ba6d51f | ||
|
|
b852198f67 | ||
|
|
628a201e31 | ||
|
|
453643683d | ||
|
|
b8cab2882b |
58
.github/workflows/install-test-darwin.yml
vendored
Normal file
58
.github/workflows/install-test-darwin.yml
vendored
Normal file
@@ -0,0 +1,58 @@
|
||||
name: Test installation Darwin
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
paths:
|
||||
- "release_files/install.sh"
|
||||
|
||||
jobs:
|
||||
install-cli-only:
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: Rename brew package
|
||||
if: ${{ matrix.check_bin_install }}
|
||||
run: mv /opt/homebrew/bin/brew /opt/homebrew/bin/brew.bak
|
||||
|
||||
- name: Run install script
|
||||
run: |
|
||||
sh ./release_files/install.sh
|
||||
env:
|
||||
SKIP_UI_APP: true
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
if ! command -v netbird &> /dev/null; then
|
||||
echo "Error: netbird is not installed"
|
||||
exit 1
|
||||
fi
|
||||
install-all:
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: Rename brew package
|
||||
if: ${{ matrix.check_bin_install }}
|
||||
run: mv /opt/homebrew/bin/brew /opt/homebrew/bin/brew.bak
|
||||
|
||||
- name: Run install script
|
||||
run: |
|
||||
sh ./release_files/install.sh
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
if ! command -v netbird &> /dev/null; then
|
||||
echo "Error: netbird is not installed"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [[ $(mdfind "kMDItemContentType == 'com.apple.application-bundle' && kMDItemFSName == '*NetBird UI.app'") ]]; then
|
||||
echo "Error: NetBird UI is not installed"
|
||||
exit 1
|
||||
fi
|
||||
36
.github/workflows/install-test-linux.yml
vendored
Normal file
36
.github/workflows/install-test-linux.yml
vendored
Normal file
@@ -0,0 +1,36 @@
|
||||
name: Test installation Linux
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
paths:
|
||||
- "release_files/install.sh"
|
||||
|
||||
jobs:
|
||||
install-cli-only:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
check_bin_install: [true, false]
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: Rename apt package
|
||||
if: ${{ matrix.check_bin_install }}
|
||||
run: |
|
||||
sudo mv /usr/bin/apt /usr/bin/apt.bak
|
||||
sudo mv /usr/bin/apt-get /usr/bin/apt-get.bak
|
||||
|
||||
- name: Run install script
|
||||
run: |
|
||||
sh ./release_files/install.sh
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
if ! command -v netbird &> /dev/null; then
|
||||
echo "Error: netbird is not installed"
|
||||
exit 1
|
||||
fi
|
||||
11
.github/workflows/test-docker-compose-linux.yml
vendored
11
.github/workflows/test-docker-compose-linux.yml
vendored
@@ -59,6 +59,11 @@ jobs:
|
||||
CI_NETBIRD_AUTH_TOKEN_ENDPOINT: https://example.eu.auth0.com/oauth/token
|
||||
CI_NETBIRD_AUTH_DEVICE_AUTH_ENDPOINT: https://example.eu.auth0.com/oauth/device/code
|
||||
CI_NETBIRD_AUTH_REDIRECT_URI: "/peers"
|
||||
CI_NETBIRD_TOKEN_SOURCE: "idToken"
|
||||
CI_NETBIRD_AUTH_USER_ID_CLAIM: "email"
|
||||
CI_NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE: "super"
|
||||
CI_NETBIRD_AUTH_DEVICE_AUTH_SCOPE: "openid email"
|
||||
|
||||
run: |
|
||||
grep AUTH_CLIENT_ID docker-compose.yml | grep $CI_NETBIRD_AUTH_CLIENT_ID
|
||||
grep AUTH_AUTHORITY docker-compose.yml | grep $CI_NETBIRD_AUTH_AUTHORITY
|
||||
@@ -68,6 +73,12 @@ jobs:
|
||||
grep NETBIRD_MGMT_API_ENDPOINT docker-compose.yml | grep "$CI_NETBIRD_DOMAIN:33073"
|
||||
grep AUTH_REDIRECT_URI docker-compose.yml | grep $CI_NETBIRD_AUTH_REDIRECT_URI
|
||||
grep AUTH_SILENT_REDIRECT_URI docker-compose.yml | egrep 'AUTH_SILENT_REDIRECT_URI=$'
|
||||
grep LETSENCRYPT_DOMAIN docker-compose.yml | egrep 'LETSENCRYPT_DOMAIN=$'
|
||||
grep NETBIRD_TOKEN_SOURCE docker-compose.yml | grep $CI_NETBIRD_TOKEN_SOURCE
|
||||
grep AuthUserIDClaim management.json | grep $CI_NETBIRD_AUTH_USER_ID_CLAIM
|
||||
grep -A 1 ProviderConfig management.json | grep Audience | grep $CI_NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE
|
||||
grep Scope management.json | grep "$CI_NETBIRD_AUTH_DEVICE_AUTH_SCOPE"
|
||||
grep UseIDToken management.json | grep false
|
||||
|
||||
- name: run docker compose up
|
||||
working-directory: infrastructure_files
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/netbirdio/netbird/formatter"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
@@ -23,6 +24,11 @@ type TunAdapter interface {
|
||||
iface.TunAdapter
|
||||
}
|
||||
|
||||
// IFaceDiscover export internal IFaceDiscover for mobile
|
||||
type IFaceDiscover interface {
|
||||
stdnet.IFaceDiscover
|
||||
}
|
||||
|
||||
func init() {
|
||||
formatter.SetLogcatFormatter(log.StandardLogger())
|
||||
}
|
||||
@@ -31,6 +37,7 @@ func init() {
|
||||
type Client struct {
|
||||
cfgFile string
|
||||
tunAdapter iface.TunAdapter
|
||||
iFaceDiscover IFaceDiscover
|
||||
recorder *peer.Status
|
||||
ctxCancel context.CancelFunc
|
||||
ctxCancelLock *sync.Mutex
|
||||
@@ -38,7 +45,7 @@ type Client struct {
|
||||
}
|
||||
|
||||
// NewClient instantiate a new Client
|
||||
func NewClient(cfgFile, deviceName string, tunAdapter TunAdapter) *Client {
|
||||
func NewClient(cfgFile, deviceName string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover) *Client {
|
||||
lvl, _ := log.ParseLevel("trace")
|
||||
log.SetLevel(lvl)
|
||||
|
||||
@@ -46,6 +53,7 @@ func NewClient(cfgFile, deviceName string, tunAdapter TunAdapter) *Client {
|
||||
cfgFile: cfgFile,
|
||||
deviceName: deviceName,
|
||||
tunAdapter: tunAdapter,
|
||||
iFaceDiscover: iFaceDiscover,
|
||||
recorder: peer.NewRecorder(""),
|
||||
ctxCancelLock: &sync.Mutex{},
|
||||
}
|
||||
@@ -70,14 +78,14 @@ func (c *Client) Run(urlOpener URLOpener) error {
|
||||
c.ctxCancelLock.Unlock()
|
||||
|
||||
auth := NewAuthWithConfig(ctx, cfg)
|
||||
err = auth.Login(urlOpener)
|
||||
err = auth.login(urlOpener)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// todo do not throw error in case of cancelled context
|
||||
ctx = internal.CtxInitState(ctx)
|
||||
return internal.RunClient(ctx, cfg, c.recorder, c.tunAdapter)
|
||||
return internal.RunClient(ctx, cfg, c.recorder, c.tunAdapter, c.iFaceDiscover)
|
||||
}
|
||||
|
||||
// Stop the internal client and free the resources
|
||||
@@ -110,12 +118,12 @@ func (c *Client) PeersList() *PeerInfoArray {
|
||||
return &PeerInfoArray{items: peerInfos}
|
||||
}
|
||||
|
||||
// AddConnectionListener add new network connection listener
|
||||
func (c *Client) AddConnectionListener(listener ConnectionListener) {
|
||||
c.recorder.AddConnectionListener(listener)
|
||||
// SetConnectionListener set the network connection listener
|
||||
func (c *Client) SetConnectionListener(listener ConnectionListener) {
|
||||
c.recorder.SetConnectionListener(listener)
|
||||
}
|
||||
|
||||
// RemoveConnectionListener remove connection listener
|
||||
func (c *Client) RemoveConnectionListener(listener ConnectionListener) {
|
||||
c.recorder.RemoveConnectionListener(listener)
|
||||
func (c *Client) RemoveConnectionListener() {
|
||||
c.recorder.RemoveConnectionListener()
|
||||
}
|
||||
|
||||
@@ -3,16 +3,32 @@ package android
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
"github.com/netbirdio/netbird/client/cmd"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc/codes"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/cmd"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
)
|
||||
|
||||
// SSOListener is async listener for mobile framework
|
||||
type SSOListener interface {
|
||||
OnSuccess(bool)
|
||||
OnError(error)
|
||||
}
|
||||
|
||||
// ErrListener is async listener for mobile framework
|
||||
type ErrListener interface {
|
||||
OnSuccess()
|
||||
OnError(error)
|
||||
}
|
||||
|
||||
// URLOpener it is a callback interface. The Open function will be triggered if
|
||||
// the backend want to show an url for the user
|
||||
type URLOpener interface {
|
||||
@@ -52,32 +68,66 @@ func NewAuthWithConfig(ctx context.Context, config *internal.Config) *Auth {
|
||||
}
|
||||
}
|
||||
|
||||
// LoginAndSaveConfigIfSSOSupported test the connectivity with the management server.
|
||||
// If the SSO is supported than save the configuration. Return with the SSO login is supported or not.
|
||||
func (a *Auth) LoginAndSaveConfigIfSSOSupported() (bool, error) {
|
||||
var needsLogin bool
|
||||
// SaveConfigIfSSOSupported test the connectivity with the management server by retrieving the server device flow info.
|
||||
// If it returns a flow info than save the configuration and return true. If it gets a codes.NotFound, it means that SSO
|
||||
// is not supported and returns false without saving the configuration. For other errors return false.
|
||||
func (a *Auth) SaveConfigIfSSOSupported(listener SSOListener) {
|
||||
go func() {
|
||||
sso, err := a.saveConfigIfSSOSupported()
|
||||
if err != nil {
|
||||
listener.OnError(err)
|
||||
} else {
|
||||
listener.OnSuccess(sso)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (a *Auth) saveConfigIfSSOSupported() (bool, error) {
|
||||
supportsSSO := true
|
||||
err := a.withBackOff(a.ctx, func() (err error) {
|
||||
needsLogin, err = internal.IsLoginRequired(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config.SSHKey)
|
||||
return
|
||||
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
|
||||
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.NotFound {
|
||||
supportsSSO = false
|
||||
err = nil
|
||||
}
|
||||
return err
|
||||
})
|
||||
|
||||
if !supportsSSO {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("backoff cycle failed: %v", err)
|
||||
}
|
||||
if !needsLogin {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
err = internal.WriteOutConfig(a.cfgPath, a.config)
|
||||
return needsLogin, err
|
||||
return true, err
|
||||
}
|
||||
|
||||
// LoginWithSetupKeyAndSaveConfig test the connectivity with the management server with the setup key.
|
||||
func (a *Auth) LoginWithSetupKeyAndSaveConfig(setupKey string) error {
|
||||
err := a.withBackOff(a.ctx, func() error {
|
||||
err := internal.Login(a.ctx, a.config, setupKey, "")
|
||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
||||
return nil
|
||||
func (a *Auth) LoginWithSetupKeyAndSaveConfig(resultListener ErrListener, setupKey string, deviceName string) {
|
||||
go func() {
|
||||
err := a.loginWithSetupKeyAndSaveConfig(setupKey, deviceName)
|
||||
if err != nil {
|
||||
resultListener.OnError(err)
|
||||
} else {
|
||||
resultListener.OnSuccess()
|
||||
}
|
||||
return err
|
||||
}()
|
||||
}
|
||||
|
||||
func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string) error {
|
||||
//nolint
|
||||
ctxWithValues := context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName)
|
||||
|
||||
err := a.withBackOff(a.ctx, func() error {
|
||||
backoffErr := internal.Login(ctxWithValues, a.config, setupKey, "")
|
||||
if s, ok := gstatus.FromError(backoffErr); ok && (s.Code() == codes.PermissionDenied) {
|
||||
// we got an answer from management, exit backoff earlier
|
||||
return backoff.Permanent(backoffErr)
|
||||
}
|
||||
return backoffErr
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
||||
@@ -87,7 +137,18 @@ func (a *Auth) LoginWithSetupKeyAndSaveConfig(setupKey string) error {
|
||||
}
|
||||
|
||||
// Login try register the client on the server
|
||||
func (a *Auth) Login(urlOpener URLOpener) error {
|
||||
func (a *Auth) Login(resultListener ErrListener, urlOpener URLOpener) {
|
||||
go func() {
|
||||
err := a.login(urlOpener)
|
||||
if err != nil {
|
||||
resultListener.OnError(err)
|
||||
} else {
|
||||
resultListener.OnSuccess()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (a *Auth) login(urlOpener URLOpener) error {
|
||||
var needsLogin bool
|
||||
|
||||
// check if we need to generate JWT token
|
||||
@@ -105,7 +166,7 @@ func (a *Auth) Login(urlOpener URLOpener) error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("interactive sso login failed: %v", err)
|
||||
}
|
||||
jwtToken = tokenInfo.AccessToken
|
||||
jwtToken = tokenInfo.GetTokenToUse()
|
||||
}
|
||||
|
||||
err = a.withBackOff(a.ctx, func() error {
|
||||
@@ -138,12 +199,7 @@ func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*internal.TokenInfo,
|
||||
}
|
||||
}
|
||||
|
||||
hostedClient := internal.NewHostedDeviceFlow(
|
||||
providerConfig.ProviderConfig.Audience,
|
||||
providerConfig.ProviderConfig.ClientID,
|
||||
providerConfig.ProviderConfig.TokenEndpoint,
|
||||
providerConfig.ProviderConfig.DeviceAuthEndpoint,
|
||||
)
|
||||
hostedClient := internal.NewHostedDeviceFlow(providerConfig.ProviderConfig)
|
||||
|
||||
flowInfo, err := hostedClient.RequestDeviceCode(context.TODO())
|
||||
if err != nil {
|
||||
|
||||
@@ -135,7 +135,7 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *internal.C
|
||||
if err != nil {
|
||||
return fmt.Errorf("interactive sso login failed: %v", err)
|
||||
}
|
||||
jwtToken = tokenInfo.AccessToken
|
||||
jwtToken = tokenInfo.GetTokenToUse()
|
||||
}
|
||||
|
||||
err = WithBackOff(func() error {
|
||||
@@ -172,12 +172,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *int
|
||||
}
|
||||
}
|
||||
|
||||
hostedClient := internal.NewHostedDeviceFlow(
|
||||
providerConfig.ProviderConfig.Audience,
|
||||
providerConfig.ProviderConfig.ClientID,
|
||||
providerConfig.ProviderConfig.TokenEndpoint,
|
||||
providerConfig.ProviderConfig.DeviceAuthEndpoint,
|
||||
)
|
||||
hostedClient := internal.NewHostedDeviceFlow(providerConfig.ProviderConfig)
|
||||
|
||||
flowInfo, err := hostedClient.RequestDeviceCode(context.TODO())
|
||||
if err != nil {
|
||||
|
||||
@@ -94,7 +94,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithCancel(ctx)
|
||||
SetupCloseHandler(ctx, cancel)
|
||||
return internal.RunClient(ctx, config, peer.NewRecorder(config.ManagementURL.String()), nil)
|
||||
return internal.RunClient(ctx, config, peer.NewRecorder(config.ManagementURL.String()), nil, nil)
|
||||
}
|
||||
|
||||
func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
|
||||
@@ -39,12 +39,8 @@ const (
|
||||
// Netbird client for ACL and routing functionality
|
||||
type Manager interface {
|
||||
// AddFiltering rule to the firewall
|
||||
//
|
||||
// If comment argument is empty firewall manager should set
|
||||
// rule ID as comment for the rule
|
||||
AddFiltering(
|
||||
ip net.IP,
|
||||
proto Protocol,
|
||||
port *Port,
|
||||
direction Direction,
|
||||
action Action,
|
||||
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
"github.com/google/uuid"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -18,33 +17,17 @@ const (
|
||||
ChainFilterName = "NETBIRD-ACL"
|
||||
)
|
||||
|
||||
// jumpNetbirdDefaultRule always added by manager to the input chain for all trafic from the Netbird interface
|
||||
var jumpNetbirdDefaultRule = []string{
|
||||
"-j", ChainFilterName, "-m", "comment", "--comment", "Netbird traffic chain jump"}
|
||||
|
||||
// pingSupportDefaultRule always added by the manager to the Netbird ACL chain
|
||||
var pingSupportDefaultRule = []string{
|
||||
"-p", "icmp", "--icmp-type", "echo-request", "-j",
|
||||
"ACCEPT", "-m", "comment", "--comment", "Allow pings from the Netbird Devices"}
|
||||
|
||||
// dropAllDefaultRule in the Netbird chain
|
||||
var dropAllDefaultRule = []string{"-j", "DROP"}
|
||||
|
||||
// Manager of iptables firewall
|
||||
type Manager struct {
|
||||
mutex sync.Mutex
|
||||
|
||||
ipv4Client *iptables.IPTables
|
||||
ipv6Client *iptables.IPTables
|
||||
|
||||
wgIfaceName string
|
||||
}
|
||||
|
||||
// Create iptables firewall manager
|
||||
func Create(wgIfaceName string) (*Manager, error) {
|
||||
m := &Manager{
|
||||
wgIfaceName: wgIfaceName,
|
||||
}
|
||||
func Create() (*Manager, error) {
|
||||
m := &Manager{}
|
||||
|
||||
// init clients for booth ipv4 and ipv6
|
||||
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
@@ -55,23 +38,20 @@ func Create(wgIfaceName string) (*Manager, error) {
|
||||
|
||||
ipv6Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv6)
|
||||
if err != nil {
|
||||
log.Errorf("ip6tables is not installed in the system or not supported")
|
||||
} else {
|
||||
m.ipv6Client = ipv6Client
|
||||
return nil, fmt.Errorf("ip6tables is not installed in the system or not supported")
|
||||
}
|
||||
m.ipv6Client = ipv6Client
|
||||
|
||||
if err := m.Reset(); err != nil {
|
||||
return nil, fmt.Errorf("failed to reset firewall: %v", err)
|
||||
return nil, fmt.Errorf("failed to reset firewall: %s", err)
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// AddFiltering rule to the firewall
|
||||
//
|
||||
// If comment is empty rule ID is used as comment
|
||||
func (m *Manager) AddFiltering(
|
||||
ip net.IP,
|
||||
protocol fw.Protocol,
|
||||
port *fw.Port,
|
||||
direction fw.Direction,
|
||||
action fw.Action,
|
||||
@@ -79,64 +59,45 @@ func (m *Manager) AddFiltering(
|
||||
) (fw.Rule, error) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
client, err := m.client(ip)
|
||||
client := m.client(ip)
|
||||
ok, err := client.ChainExists("filter", ChainFilterName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check if chain exists: %s", err)
|
||||
}
|
||||
if !ok {
|
||||
if err := client.NewChain("filter", ChainFilterName); err != nil {
|
||||
return nil, fmt.Errorf("failed to create chain: %s", err)
|
||||
}
|
||||
}
|
||||
if port == nil || port.Values == nil || (port.IsRange && len(port.Values) != 2) {
|
||||
return nil, fmt.Errorf("invalid port definition")
|
||||
}
|
||||
pv := strconv.Itoa(port.Values[0])
|
||||
if port.IsRange {
|
||||
pv += ":" + strconv.Itoa(port.Values[1])
|
||||
}
|
||||
specs := m.filterRuleSpecs("filter", ChainFilterName, ip, pv, direction, action, comment)
|
||||
if err := client.AppendUnique("filter", ChainFilterName, specs...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var portValue string
|
||||
if port != nil && port.Values != nil {
|
||||
// TODO: we support only one port per rule in current implementation of ACLs
|
||||
portValue = strconv.Itoa(port.Values[0])
|
||||
rule := &Rule{
|
||||
id: uuid.New().String(),
|
||||
specs: specs,
|
||||
v6: ip.To4() == nil,
|
||||
}
|
||||
|
||||
ruleID := uuid.New().String()
|
||||
if comment == "" {
|
||||
comment = ruleID
|
||||
}
|
||||
|
||||
specs := m.filterRuleSpecs(
|
||||
"filter",
|
||||
ChainFilterName,
|
||||
ip,
|
||||
string(protocol),
|
||||
portValue,
|
||||
direction,
|
||||
action,
|
||||
comment,
|
||||
)
|
||||
|
||||
ok, err := client.Exists("filter", ChainFilterName, specs...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("check is rule already exists: %w", err)
|
||||
}
|
||||
if ok {
|
||||
return nil, fmt.Errorf("rule already exists")
|
||||
}
|
||||
|
||||
if err := client.Insert("filter", ChainFilterName, 1, specs...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Rule{id: ruleID, specs: specs, v6: ip.To4() == nil}, nil
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
// DeleteRule from the firewall by rule definition
|
||||
func (m *Manager) DeleteRule(rule fw.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
r, ok := rule.(*Rule)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid rule type")
|
||||
}
|
||||
|
||||
client := m.ipv4Client
|
||||
if r.v6 {
|
||||
if m.ipv6Client == nil {
|
||||
return fmt.Errorf("ipv6 is not supported")
|
||||
}
|
||||
client = m.ipv6Client
|
||||
}
|
||||
return client.Delete("filter", ChainFilterName, r.specs...)
|
||||
@@ -146,14 +107,11 @@ func (m *Manager) DeleteRule(rule fw.Rule) error {
|
||||
func (m *Manager) Reset() error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if err := m.reset(m.ipv4Client, "filter", ChainFilterName); err != nil {
|
||||
return fmt.Errorf("clean ipv4 firewall ACL chain: %w", err)
|
||||
}
|
||||
if m.ipv6Client != nil {
|
||||
if err := m.reset(m.ipv6Client, "filter", ChainFilterName); err != nil {
|
||||
return fmt.Errorf("clean ipv6 firewall ACL chain: %w", err)
|
||||
}
|
||||
if err := m.reset(m.ipv6Client, "filter", ChainFilterName); err != nil {
|
||||
return fmt.Errorf("clean ipv6 firewall ACL chain: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -167,79 +125,31 @@ func (m *Manager) reset(client *iptables.IPTables, table, chain string) error {
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
specs := append([]string{"-i", m.wgIfaceName}, jumpNetbirdDefaultRule...)
|
||||
if err := client.Delete("filter", "INPUT", specs...); err != nil {
|
||||
return fmt.Errorf("failed to delete default rule: %w", err)
|
||||
if err := client.ClearChain(table, ChainFilterName); err != nil {
|
||||
return fmt.Errorf("failed to clear chain: %w", err)
|
||||
}
|
||||
|
||||
return client.ClearAndDeleteChain(table, chain)
|
||||
return client.DeleteChain(table, ChainFilterName)
|
||||
}
|
||||
|
||||
// filterRuleSpecs returns the specs of a filtering rule
|
||||
func (m *Manager) filterRuleSpecs(
|
||||
table string, chain string, ip net.IP, protocol string, port string,
|
||||
table string, chain string, ip net.IP, port string,
|
||||
direction fw.Direction, action fw.Action, comment string,
|
||||
) (specs []string) {
|
||||
switch direction {
|
||||
case fw.DirectionSrc:
|
||||
if direction == fw.DirectionSrc {
|
||||
specs = append(specs, "-s", ip.String())
|
||||
case fw.DirectionDst:
|
||||
specs = append(specs, "-d", ip.String())
|
||||
}
|
||||
if protocol != "" {
|
||||
specs = append(specs, "-p", protocol)
|
||||
}
|
||||
if port != "" {
|
||||
specs = append(specs, "--dport", port)
|
||||
}
|
||||
specs = append(specs, "-p", "tcp", "--dport", port)
|
||||
specs = append(specs, "-j", m.actionToStr(action))
|
||||
return append(specs, "-m", "comment", "--comment", comment)
|
||||
}
|
||||
|
||||
// rawClient returns corresponding iptables client for the given ip
|
||||
func (m *Manager) rawClient(ip net.IP) (*iptables.IPTables, error) {
|
||||
// client returns corresponding iptables client for the given ip
|
||||
func (m *Manager) client(ip net.IP) *iptables.IPTables {
|
||||
if ip.To4() != nil {
|
||||
return m.ipv4Client, nil
|
||||
return m.ipv4Client
|
||||
}
|
||||
if m.ipv6Client == nil {
|
||||
return nil, fmt.Errorf("ipv6 is not supported")
|
||||
}
|
||||
return m.ipv6Client, nil
|
||||
}
|
||||
|
||||
// client returns client with initialized chain and default rules
|
||||
func (m *Manager) client(ip net.IP) (*iptables.IPTables, error) {
|
||||
client, err := m.rawClient(ip)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ok, err := client.ChainExists("filter", ChainFilterName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check if chain exists: %w", err)
|
||||
}
|
||||
|
||||
if !ok {
|
||||
if err := client.NewChain("filter", ChainFilterName); err != nil {
|
||||
return nil, fmt.Errorf("failed to create chain: %w", err)
|
||||
}
|
||||
|
||||
if err := client.AppendUnique("filter", ChainFilterName, pingSupportDefaultRule...); err != nil {
|
||||
return nil, fmt.Errorf("failed to create default ping allow rule: %w", err)
|
||||
}
|
||||
|
||||
if err := client.AppendUnique("filter", ChainFilterName, dropAllDefaultRule...); err != nil {
|
||||
return nil, fmt.Errorf("failed to create default drop all in netbird chain: %w", err)
|
||||
}
|
||||
|
||||
specs := append([]string{"-i", m.wgIfaceName}, jumpNetbirdDefaultRule...)
|
||||
if err := client.AppendUnique("filter", "INPUT", specs...); err != nil {
|
||||
return nil, fmt.Errorf("failed to create chain: %w", err)
|
||||
}
|
||||
|
||||
}
|
||||
return client, nil
|
||||
return m.ipv6Client
|
||||
}
|
||||
|
||||
func (m *Manager) actionToStr(action fw.Action) string {
|
||||
|
||||
@@ -14,8 +14,7 @@ func TestNewManager(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// just check on the local interface
|
||||
manager, err := Create("lo")
|
||||
manager, err := Create()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -23,11 +22,10 @@ func TestNewManager(t *testing.T) {
|
||||
var rule1 fw.Rule
|
||||
t.Run("add first rule", func(t *testing.T) {
|
||||
ip := net.ParseIP("10.20.0.2")
|
||||
port := &fw.Port{Values: []int{8080}}
|
||||
rule1, err = manager.AddFiltering(ip, "tcp", port, fw.DirectionDst, fw.ActionAccept, "accept HTTP traffic")
|
||||
port := &fw.Port{Proto: fw.PortProtocolTCP, Values: []int{8080}}
|
||||
rule1, err = manager.AddFiltering(ip, port, fw.DirectionDst, fw.ActionAccept, "accept HTTP traffic")
|
||||
if err != nil {
|
||||
t.Errorf("failed to add rule: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
checkRuleSpecs(t, ipv4Client, true, rule1.(*Rule).specs...)
|
||||
@@ -37,13 +35,13 @@ func TestNewManager(t *testing.T) {
|
||||
t.Run("add second rule", func(t *testing.T) {
|
||||
ip := net.ParseIP("10.20.0.3")
|
||||
port := &fw.Port{
|
||||
Proto: fw.PortProtocolTCP,
|
||||
Values: []int{8043: 8046},
|
||||
}
|
||||
rule2, err = manager.AddFiltering(
|
||||
ip, "tcp", port, fw.DirectionDst, fw.ActionAccept, "accept HTTPS traffic from ports range")
|
||||
ip, port, fw.DirectionDst, fw.ActionAccept, "accept HTTPS traffic from ports range")
|
||||
if err != nil {
|
||||
t.Errorf("failed to add rule: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
checkRuleSpecs(t, ipv4Client, true, rule2.(*Rule).specs...)
|
||||
@@ -60,7 +58,6 @@ func TestNewManager(t *testing.T) {
|
||||
t.Run("delete second rule", func(t *testing.T) {
|
||||
if err := manager.DeleteRule(rule2); err != nil {
|
||||
t.Errorf("failed to delete rule: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
checkRuleSpecs(t, ipv4Client, false, rule2.(*Rule).specs...)
|
||||
@@ -69,22 +66,19 @@ func TestNewManager(t *testing.T) {
|
||||
t.Run("reset check", func(t *testing.T) {
|
||||
// add second rule
|
||||
ip := net.ParseIP("10.20.0.3")
|
||||
port := &fw.Port{Values: []int{5353}}
|
||||
_, err = manager.AddFiltering(ip, "udp", port, fw.DirectionDst, fw.ActionAccept, "accept Fake DNS traffic")
|
||||
port := &fw.Port{Proto: fw.PortProtocolUDP, Values: []int{5353}}
|
||||
_, err = manager.AddFiltering(ip, port, fw.DirectionDst, fw.ActionAccept, "accept Fake DNS traffic")
|
||||
if err != nil {
|
||||
t.Errorf("failed to add rule: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := manager.Reset(); err != nil {
|
||||
t.Errorf("failed to reset: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
ok, err := ipv4Client.ChainExists("filter", ChainFilterName)
|
||||
if err != nil {
|
||||
t.Errorf("failed to drop chain: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if ok {
|
||||
|
||||
@@ -1,17 +1,14 @@
|
||||
package firewall
|
||||
|
||||
// Protocol is the protocol of the port
|
||||
type Protocol string
|
||||
// PortProtocol is the protocol of the port
|
||||
type PortProtocol string
|
||||
|
||||
const (
|
||||
// ProtocolTCP is the TCP protocol
|
||||
ProtocolTCP Protocol = "tcp"
|
||||
// PortProtocolTCP is the TCP protocol
|
||||
PortProtocolTCP PortProtocol = "tcp"
|
||||
|
||||
// ProtocolUDP is the UDP protocol
|
||||
ProtocolUDP Protocol = "udp"
|
||||
|
||||
// ProtocolICMP is the ICMP protocol
|
||||
ProtocolICMP Protocol = "icmp"
|
||||
// PortProtocolUDP is the UDP protocol
|
||||
PortProtocolUDP PortProtocol = "udp"
|
||||
)
|
||||
|
||||
// Port of the address for firewall rule
|
||||
@@ -21,4 +18,7 @@ type Port struct {
|
||||
|
||||
// Values contains one value for single port, multiple values for the list of ports, or two values for the range of ports
|
||||
Values []int
|
||||
|
||||
// Proto is the protocol of the port
|
||||
Proto PortProtocol
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
@@ -22,7 +23,7 @@ import (
|
||||
)
|
||||
|
||||
// RunClient with main logic.
|
||||
func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status, tunAdapter iface.TunAdapter) error {
|
||||
func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status, tunAdapter iface.TunAdapter, iFaceDiscover stdnet.IFaceDiscover) error {
|
||||
backOff := &backoff.ExponentialBackOff{
|
||||
InitialInterval: time.Second,
|
||||
RandomizationFactor: 1,
|
||||
@@ -58,9 +59,6 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
|
||||
return err
|
||||
}
|
||||
|
||||
statusRecorder.MarkManagementDisconnected()
|
||||
|
||||
statusRecorder.ClientStart()
|
||||
defer statusRecorder.ClientStop()
|
||||
operation := func() error {
|
||||
// if context cancelled we not start new backoff cycle
|
||||
@@ -81,12 +79,12 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
|
||||
|
||||
log.Debugf("conecting to the Management service %s", config.ManagementURL.Host)
|
||||
mgmClient, err := mgm.NewClient(engineCtx, config.ManagementURL.Host, myPrivateKey, mgmTlsEnabled)
|
||||
mgmNotifier := statusRecorderToMgmConnStateNotifier(statusRecorder)
|
||||
mgmClient.SetConnStateListener(mgmNotifier)
|
||||
|
||||
if err != nil {
|
||||
return wrapErr(gstatus.Errorf(codes.FailedPrecondition, "failed connecting to Management Service : %s", err))
|
||||
}
|
||||
mgmNotifier := statusRecorderToMgmConnStateNotifier(statusRecorder)
|
||||
mgmClient.SetConnStateListener(mgmNotifier)
|
||||
|
||||
log.Debugf("connected to the Management service %s", config.ManagementURL.Host)
|
||||
defer func() {
|
||||
err = mgmClient.Close()
|
||||
@@ -146,7 +144,7 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
|
||||
|
||||
peerConfig := loginResp.GetPeerConfig()
|
||||
|
||||
engineConfig, err := createEngineConfig(myPrivateKey, config, peerConfig, tunAdapter)
|
||||
engineConfig, err := createEngineConfig(myPrivateKey, config, peerConfig, tunAdapter, iFaceDiscover)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return wrapErr(err)
|
||||
@@ -162,7 +160,10 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
|
||||
log.Print("Netbird engine started, my IP is: ", peerConfig.Address)
|
||||
state.Set(StatusConnected)
|
||||
|
||||
statusRecorder.ClientStart()
|
||||
|
||||
<-engineCtx.Done()
|
||||
statusRecorder.ClientTeardown()
|
||||
|
||||
backOff.Reset()
|
||||
|
||||
@@ -193,12 +194,13 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
|
||||
}
|
||||
|
||||
// createEngineConfig converts configuration received from Management Service to EngineConfig
|
||||
func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig, tunAdapter iface.TunAdapter) (*EngineConfig, error) {
|
||||
func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig, tunAdapter iface.TunAdapter, iFaceDiscover stdnet.IFaceDiscover) (*EngineConfig, error) {
|
||||
|
||||
engineConf := &EngineConfig{
|
||||
WgIfaceName: config.WgIface,
|
||||
WgAddr: peerConfig.Address,
|
||||
TunAdapter: tunAdapter,
|
||||
IFaceDiscover: iFaceDiscover,
|
||||
IFaceBlackList: config.IFaceBlackList,
|
||||
DisableIPv6Discovery: config.DisableIPv6Discovery,
|
||||
WgPrivateKey: key,
|
||||
|
||||
@@ -34,6 +34,10 @@ type ProviderConfig struct {
|
||||
TokenEndpoint string
|
||||
// DeviceAuthEndpoint is the endpoint of an IDP manager where clients can obtain device authorization code
|
||||
DeviceAuthEndpoint string
|
||||
// Scopes provides the scopes to be included in the token request
|
||||
Scope string
|
||||
// UseIDToken indicates if the id token should be used for authentication
|
||||
UseIDToken bool
|
||||
}
|
||||
|
||||
// GetDeviceAuthorizationFlowInfo initialize a DeviceAuthorizationFlow instance and return with it
|
||||
@@ -91,9 +95,16 @@ func GetDeviceAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmU
|
||||
Domain: protoDeviceAuthorizationFlow.GetProviderConfig().Domain,
|
||||
TokenEndpoint: protoDeviceAuthorizationFlow.GetProviderConfig().GetTokenEndpoint(),
|
||||
DeviceAuthEndpoint: protoDeviceAuthorizationFlow.GetProviderConfig().GetDeviceAuthEndpoint(),
|
||||
Scope: protoDeviceAuthorizationFlow.GetProviderConfig().GetScope(),
|
||||
UseIDToken: protoDeviceAuthorizationFlow.GetProviderConfig().GetUseIDToken(),
|
||||
},
|
||||
}
|
||||
|
||||
// keep compatibility with older management versions
|
||||
if deviceAuthorizationFlow.ProviderConfig.Scope == "" {
|
||||
deviceAuthorizationFlow.ProviderConfig.Scope = "openid"
|
||||
}
|
||||
|
||||
err = isProviderConfigValid(deviceAuthorizationFlow.ProviderConfig)
|
||||
if err != nil {
|
||||
return DeviceAuthorizationFlow{}, err
|
||||
@@ -116,5 +127,8 @@ func isProviderConfigValid(config ProviderConfig) error {
|
||||
if config.DeviceAuthEndpoint == "" {
|
||||
return fmt.Errorf(errorMSGFormat, "Device Auth Endpoint")
|
||||
}
|
||||
if config.Scope == "" {
|
||||
return fmt.Errorf(errorMSGFormat, "Device Auth Scopes")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"net/netip"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -17,11 +16,11 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/proxy"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
@@ -51,6 +50,8 @@ type EngineConfig struct {
|
||||
// TunAdapter is option. It is necessary for mobile version.
|
||||
TunAdapter iface.TunAdapter
|
||||
|
||||
IFaceDiscover stdnet.IFaceDiscover
|
||||
|
||||
// WgAddr is a Wireguard local address (Netbird Network IP)
|
||||
WgAddr string
|
||||
|
||||
@@ -114,9 +115,7 @@ type Engine struct {
|
||||
|
||||
statusRecorder *peer.Status
|
||||
|
||||
routeManager routemanager.Manager
|
||||
firewallManager firewall.Manager
|
||||
firewallRules map[string]firewall.Rule
|
||||
routeManager routemanager.Manager
|
||||
|
||||
dnsServer dns.Server
|
||||
}
|
||||
@@ -190,12 +189,22 @@ func (e *Engine) Start() error {
|
||||
networkName = "udp4"
|
||||
}
|
||||
|
||||
transportNet, err := e.newStdNet()
|
||||
if err != nil {
|
||||
log.Warnf("failed to create pion's stdnet: %s", err)
|
||||
}
|
||||
|
||||
e.udpMuxConn, err = net.ListenUDP(networkName, &net.UDPAddr{Port: e.config.UDPMuxPort})
|
||||
if err != nil {
|
||||
log.Errorf("failed listening on UDP port %d: [%s]", e.config.UDPMuxPort, err.Error())
|
||||
e.close()
|
||||
return err
|
||||
}
|
||||
udpMuxParams := ice.UDPMuxParams{
|
||||
UDPConn: e.udpMuxConn,
|
||||
Net: transportNet,
|
||||
}
|
||||
e.udpMux = ice.NewUDPMuxDefault(udpMuxParams)
|
||||
|
||||
e.udpMuxConnSrflx, err = net.ListenUDP(networkName, &net.UDPAddr{Port: e.config.UDPMuxSrflxPort})
|
||||
if err != nil {
|
||||
@@ -203,9 +212,7 @@ func (e *Engine) Start() error {
|
||||
e.close()
|
||||
return err
|
||||
}
|
||||
|
||||
e.udpMux = ice.NewUDPMuxDefault(ice.UDPMuxParams{UDPConn: e.udpMuxConn})
|
||||
e.udpMuxSrflx = ice.NewUniversalUDPMuxDefault(ice.UniversalUDPMuxParams{UDPConn: e.udpMuxConnSrflx})
|
||||
e.udpMuxSrflx = ice.NewUniversalUDPMuxDefault(ice.UniversalUDPMuxParams{UDPConn: e.udpMuxConnSrflx, Net: transportNet})
|
||||
|
||||
err = e.wgInterface.Create()
|
||||
if err != nil {
|
||||
@@ -223,16 +230,6 @@ func (e *Engine) Start() error {
|
||||
|
||||
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder)
|
||||
|
||||
e.firewallManager, err = buildFirewallManager(wgIfaceName)
|
||||
if err != nil {
|
||||
log.Errorf("failed to create firewall manager, ACL policy will not work: %s", err.Error())
|
||||
} else {
|
||||
if err := e.firewallManager.Reset(); err != nil {
|
||||
log.Tracef("failed to reset firewall manager on the start: %v", err.Error())
|
||||
}
|
||||
}
|
||||
e.firewallRules = make(map[string]firewall.Rule)
|
||||
|
||||
if e.dnsServer == nil {
|
||||
// todo fix custom address
|
||||
dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress)
|
||||
@@ -636,7 +633,6 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
log.Errorf("failed to update dns server, err: %v", err)
|
||||
}
|
||||
|
||||
e.applyFirewallRules(networkMap.FirewallRules)
|
||||
e.networkSerial = serial
|
||||
return nil
|
||||
}
|
||||
@@ -828,7 +824,7 @@ func (e Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, er
|
||||
NATExternalIPs: e.parseNATExternalIPMappings(),
|
||||
}
|
||||
|
||||
peerConn, err := peer.NewConn(config, e.statusRecorder)
|
||||
peerConn, err := peer.NewConn(config, e.statusRecorder, e.config.TunAdapter, e.config.IFaceDiscover)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1043,92 +1039,6 @@ func (e *Engine) close() {
|
||||
e.dnsServer.Stop()
|
||||
}
|
||||
|
||||
if e.firewallManager != nil {
|
||||
if err := e.firewallManager.Reset(); err != nil {
|
||||
log.Warnf("failed resetting firewall: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// applyFirewallRules to the local firewall manager processed by ACL policy.
|
||||
func (e *Engine) applyFirewallRules(rules []*mgmProto.FirewallRule) {
|
||||
if e.firewallManager == nil {
|
||||
log.Debug("firewall manager is not supported, skipping firewall rules")
|
||||
return
|
||||
}
|
||||
|
||||
for ruleID, rule := range e.firewallRules {
|
||||
if err := e.firewallManager.DeleteRule(rule); err != nil {
|
||||
log.Errorf("failed to delete firewall rule: %v", err)
|
||||
continue
|
||||
}
|
||||
delete(e.firewallRules, ruleID)
|
||||
}
|
||||
|
||||
for _, r := range rules {
|
||||
if rule := e.protoRuleToFirewallRule(r); rule == nil {
|
||||
log.Errorf("failed to apply firewall rule: %v", r)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Engine) protoRuleToFirewallRule(r *mgmProto.FirewallRule) firewall.Rule {
|
||||
ip := net.ParseIP(r.PeerIP)
|
||||
if ip == nil {
|
||||
log.Error("invalid IP address, skipping firewall rule")
|
||||
return nil
|
||||
}
|
||||
|
||||
var port *firewall.Port
|
||||
if r.Port != "" {
|
||||
split := strings.Split(r.Port, "/")
|
||||
// port can be empty, so ignore conversion error
|
||||
value, _ := strconv.Atoi(split[0])
|
||||
port = &firewall.Port{}
|
||||
if value != 0 {
|
||||
port.Values = []int{value}
|
||||
}
|
||||
}
|
||||
|
||||
var protocol firewall.Protocol
|
||||
switch r.Protocol {
|
||||
case "tcp":
|
||||
protocol = firewall.ProtocolTCP
|
||||
case "udp":
|
||||
protocol = firewall.ProtocolUDP
|
||||
case "icmp":
|
||||
protocol = firewall.ProtocolICMP
|
||||
}
|
||||
|
||||
var direction firewall.Direction
|
||||
switch r.Direction {
|
||||
case "src":
|
||||
direction = firewall.DirectionSrc
|
||||
case "dst":
|
||||
direction = firewall.DirectionDst
|
||||
default:
|
||||
log.Error("invalid direction, skipping firewall rule")
|
||||
return nil
|
||||
}
|
||||
|
||||
var action firewall.Action
|
||||
switch r.Action {
|
||||
case "accept":
|
||||
action = firewall.ActionAccept
|
||||
case "drop":
|
||||
action = firewall.ActionDrop
|
||||
default:
|
||||
log.Error("invalid action, skipping firewall rule")
|
||||
return nil
|
||||
}
|
||||
|
||||
rule, err := e.firewallManager.AddFiltering(ip, protocol, port, direction, action, "")
|
||||
if err != nil {
|
||||
log.Errorf("failed to add firewall rule: %v", err)
|
||||
return nil
|
||||
}
|
||||
e.firewallRules[rule.GetRuleID()] = rule
|
||||
return rule
|
||||
}
|
||||
|
||||
func findIPFromInterfaceName(ifaceName string) (net.IP, error) {
|
||||
|
||||
11
client/internal/engine_stdnet.go
Normal file
11
client/internal/engine_stdnet.go
Normal file
@@ -0,0 +1,11 @@
|
||||
//go:build !android
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
"github.com/pion/transport/v2/stdnet"
|
||||
)
|
||||
|
||||
func (e *Engine) newStdNet() (*stdnet.Net, error) {
|
||||
return stdnet.NewNet()
|
||||
}
|
||||
7
client/internal/engine_stdnet_android.go
Normal file
7
client/internal/engine_stdnet_android.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package internal
|
||||
|
||||
import "github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
|
||||
func (e *Engine) newStdNet() (*stdnet.Net, error) {
|
||||
return stdnet.NewNet(e.config.IFaceDiscover)
|
||||
}
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
@@ -21,7 +20,6 @@ import (
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
@@ -30,7 +28,6 @@ import (
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
mgmt "github.com/netbirdio/netbird/management/client"
|
||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||
mgmtProto "github.com/netbirdio/netbird/management/proto"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
@@ -941,139 +938,6 @@ func Test_ParseNATExternalIPMappings(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestEngine_firewallManager(t *testing.T) {
|
||||
// TODO: enable when other platform will be added
|
||||
if runtime.GOOS != "linux" {
|
||||
t.Skipf("firewall manager not supported in the: %s", runtime.GOOS)
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := exec.LookPath("iptables"); err != nil {
|
||||
t.Skipf("iptables not found: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(CtxInitState(context.Background()), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
dir := t.TempDir()
|
||||
|
||||
err := util.CopyFileContents("../testdata/store.json", filepath.Join(dir, "store.json"))
|
||||
if err != nil {
|
||||
t.Errorf("copy temporary store file: %v", err)
|
||||
}
|
||||
|
||||
sigServer, signalAddr, err := startSignal()
|
||||
if err != nil {
|
||||
t.Errorf("start signal server: %v", err)
|
||||
return
|
||||
}
|
||||
defer sigServer.GracefulStop()
|
||||
|
||||
mgmtServer, mgmtAddr, err := startManagement(dir)
|
||||
if err != nil {
|
||||
t.Errorf("start management server: %v", err)
|
||||
return
|
||||
}
|
||||
defer mgmtServer.GracefulStop()
|
||||
|
||||
setupKey := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"
|
||||
|
||||
engine, err := createEngine(ctx, cancel, setupKey, 0, mgmtAddr, signalAddr)
|
||||
if err != nil {
|
||||
t.Errorf("create engine: %v", err)
|
||||
return
|
||||
}
|
||||
engine.dnsServer = &dns.MockServer{}
|
||||
|
||||
if err := engine.Start(); err != nil {
|
||||
t.Logf("start engine: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := engine.mgmClient.Close(); err != nil {
|
||||
t.Logf("close management client: %v", err)
|
||||
}
|
||||
|
||||
if err := engine.Stop(); err != nil {
|
||||
t.Logf("stop engine: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// wait 2 seconds until all management and signal processing will be finished
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
if engine.firewallManager == nil {
|
||||
t.Errorf("firewall manager is nil")
|
||||
return
|
||||
}
|
||||
|
||||
fwRules := []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerID: "test",
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: "dst",
|
||||
Action: "accept",
|
||||
Protocol: "tcp",
|
||||
Port: "80",
|
||||
},
|
||||
{
|
||||
PeerID: "test2",
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: "dst",
|
||||
Action: "drop",
|
||||
Protocol: "udp",
|
||||
Port: "53",
|
||||
},
|
||||
}
|
||||
|
||||
// we receive one rule from the management so for testing purposes ignore it
|
||||
engine.firewallRules = make(map[string]firewall.Rule)
|
||||
|
||||
t.Run("apply firewall rules", func(t *testing.T) {
|
||||
engine.applyFirewallRules(fwRules)
|
||||
|
||||
if len(engine.firewallRules) != 2 {
|
||||
t.Errorf("firewall rules not applied: %v", engine.firewallRules)
|
||||
return
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("add extra rules", func(t *testing.T) {
|
||||
// remove first rule
|
||||
fwRules = fwRules[1:]
|
||||
fwRules = append(fwRules, &mgmtProto.FirewallRule{
|
||||
PeerID: "test3",
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: "src",
|
||||
Action: "drop",
|
||||
Protocol: "icmp",
|
||||
})
|
||||
|
||||
existedRulesID := map[string]struct{}{}
|
||||
for id := range engine.firewallRules {
|
||||
existedRulesID[id] = struct{}{}
|
||||
}
|
||||
|
||||
engine.applyFirewallRules(fwRules)
|
||||
|
||||
// we should have one old and one new rule in the existed rules
|
||||
if len(engine.firewallRules) != 2 {
|
||||
t.Errorf("firewall rules not applied")
|
||||
return
|
||||
}
|
||||
|
||||
// check that old rules was removed
|
||||
for id := range existedRulesID {
|
||||
if _, ok := engine.firewallRules[id]; ok {
|
||||
t.Errorf("old rule was not removed")
|
||||
return
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey string, i int, mgmtAddr string, signalAddr string) (*Engine, error) {
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
|
||||
@@ -1,14 +0,0 @@
|
||||
//go:build !linux
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"runtime"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall"
|
||||
)
|
||||
|
||||
func buildFirewallManager(wgIfaceName string) (fw firewall.Manager, err error) {
|
||||
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
|
||||
}
|
||||
@@ -1,18 +0,0 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall"
|
||||
"github.com/netbirdio/netbird/client/firewall/iptables"
|
||||
)
|
||||
|
||||
// buildFirewallManager creates a firewall manager instance for the Linux
|
||||
func buildFirewallManager(wgIfaceName string) (firewall.Manager, error) {
|
||||
fw, err := iptables.Create(wgIfaceName)
|
||||
if err != nil {
|
||||
log.Debugf("failed to create iptables manager: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
return fw, nil
|
||||
}
|
||||
@@ -35,15 +35,6 @@ type DeviceAuthInfo struct {
|
||||
Interval int `json:"interval"`
|
||||
}
|
||||
|
||||
// TokenInfo holds information of issued access token
|
||||
type TokenInfo struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
IDToken string `json:"id_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
}
|
||||
|
||||
// HostedGrantType grant type for device flow on Hosted
|
||||
const (
|
||||
HostedGrantType = "urn:ietf:params:oauth:grant-type:device_code"
|
||||
@@ -52,16 +43,7 @@ const (
|
||||
|
||||
// Hosted client
|
||||
type Hosted struct {
|
||||
// Hosted API Audience for validation
|
||||
Audience string
|
||||
// Hosted Native application client id
|
||||
ClientID string
|
||||
// Hosted Native application request scope
|
||||
Scope string
|
||||
// TokenEndpoint to request access token
|
||||
TokenEndpoint string
|
||||
// DeviceAuthEndpoint to request device authorization code
|
||||
DeviceAuthEndpoint string
|
||||
providerConfig ProviderConfig
|
||||
|
||||
HTTPClient HTTPClient
|
||||
}
|
||||
@@ -70,7 +52,7 @@ type Hosted struct {
|
||||
type RequestDeviceCodePayload struct {
|
||||
Audience string `json:"audience"`
|
||||
ClientID string `json:"client_id"`
|
||||
Scope string `json:"scope"`
|
||||
Scope string `json:"scope"`
|
||||
}
|
||||
|
||||
// TokenRequestPayload used for requesting the auth0 token
|
||||
@@ -93,8 +75,26 @@ type Claims struct {
|
||||
Audience interface{} `json:"aud"`
|
||||
}
|
||||
|
||||
// TokenInfo holds information of issued access token
|
||||
type TokenInfo struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
IDToken string `json:"id_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
UseIDToken bool `json:"-"`
|
||||
}
|
||||
|
||||
// GetTokenToUse returns either the access or id token based on UseIDToken field
|
||||
func (t TokenInfo) GetTokenToUse() string {
|
||||
if t.UseIDToken {
|
||||
return t.IDToken
|
||||
}
|
||||
return t.AccessToken
|
||||
}
|
||||
|
||||
// NewHostedDeviceFlow returns an Hosted OAuth client
|
||||
func NewHostedDeviceFlow(audience string, clientID string, tokenEndpoint string, deviceAuthEndpoint string) *Hosted {
|
||||
func NewHostedDeviceFlow(config ProviderConfig) *Hosted {
|
||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
httpTransport.MaxIdleConns = 5
|
||||
|
||||
@@ -104,27 +104,23 @@ func NewHostedDeviceFlow(audience string, clientID string, tokenEndpoint string,
|
||||
}
|
||||
|
||||
return &Hosted{
|
||||
Audience: audience,
|
||||
ClientID: clientID,
|
||||
Scope: "openid",
|
||||
TokenEndpoint: tokenEndpoint,
|
||||
HTTPClient: httpClient,
|
||||
DeviceAuthEndpoint: deviceAuthEndpoint,
|
||||
providerConfig: config,
|
||||
HTTPClient: httpClient,
|
||||
}
|
||||
}
|
||||
|
||||
// GetClientID returns the provider client id
|
||||
func (h *Hosted) GetClientID(ctx context.Context) string {
|
||||
return h.ClientID
|
||||
return h.providerConfig.ClientID
|
||||
}
|
||||
|
||||
// RequestDeviceCode requests a device code login flow information from Hosted
|
||||
func (h *Hosted) RequestDeviceCode(ctx context.Context) (DeviceAuthInfo, error) {
|
||||
form := url.Values{}
|
||||
form.Add("client_id", h.ClientID)
|
||||
form.Add("audience", h.Audience)
|
||||
form.Add("scope", h.Scope)
|
||||
req, err := http.NewRequest("POST", h.DeviceAuthEndpoint,
|
||||
form.Add("client_id", h.providerConfig.ClientID)
|
||||
form.Add("audience", h.providerConfig.Audience)
|
||||
form.Add("scope", h.providerConfig.Scope)
|
||||
req, err := http.NewRequest("POST", h.providerConfig.DeviceAuthEndpoint,
|
||||
strings.NewReader(form.Encode()))
|
||||
if err != nil {
|
||||
return DeviceAuthInfo{}, fmt.Errorf("creating request failed with error: %v", err)
|
||||
@@ -157,10 +153,10 @@ func (h *Hosted) RequestDeviceCode(ctx context.Context) (DeviceAuthInfo, error)
|
||||
|
||||
func (h *Hosted) requestToken(info DeviceAuthInfo) (TokenRequestResponse, error) {
|
||||
form := url.Values{}
|
||||
form.Add("client_id", h.ClientID)
|
||||
form.Add("client_id", h.providerConfig.ClientID)
|
||||
form.Add("grant_type", HostedGrantType)
|
||||
form.Add("device_code", info.DeviceCode)
|
||||
req, err := http.NewRequest("POST", h.TokenEndpoint, strings.NewReader(form.Encode()))
|
||||
req, err := http.NewRequest("POST", h.providerConfig.TokenEndpoint, strings.NewReader(form.Encode()))
|
||||
if err != nil {
|
||||
return TokenRequestResponse{}, fmt.Errorf("failed to create request access token: %v", err)
|
||||
}
|
||||
@@ -225,18 +221,20 @@ func (h *Hosted) WaitToken(ctx context.Context, info DeviceAuthInfo) (TokenInfo,
|
||||
return TokenInfo{}, fmt.Errorf(tokenResponse.ErrorDescription)
|
||||
}
|
||||
|
||||
err = isValidAccessToken(tokenResponse.AccessToken, h.Audience)
|
||||
if err != nil {
|
||||
return TokenInfo{}, fmt.Errorf("validate access token failed with error: %v", err)
|
||||
}
|
||||
|
||||
tokenInfo := TokenInfo{
|
||||
AccessToken: tokenResponse.AccessToken,
|
||||
TokenType: tokenResponse.TokenType,
|
||||
RefreshToken: tokenResponse.RefreshToken,
|
||||
IDToken: tokenResponse.IDToken,
|
||||
ExpiresIn: tokenResponse.ExpiresIn,
|
||||
UseIDToken: h.providerConfig.UseIDToken,
|
||||
}
|
||||
|
||||
err = isValidAccessToken(tokenInfo.GetTokenToUse(), h.providerConfig.Audience)
|
||||
if err != nil {
|
||||
return TokenInfo{}, fmt.Errorf("validate access token failed with error: %v", err)
|
||||
}
|
||||
|
||||
return tokenInfo, err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,14 +3,15 @@ package internal
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/stretchr/testify/require"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type mockHTTPClient struct {
|
||||
@@ -113,12 +114,15 @@ func TestHosted_RequestDeviceCode(t *testing.T) {
|
||||
}
|
||||
|
||||
hosted := Hosted{
|
||||
Audience: expectedAudience,
|
||||
ClientID: expectedClientID,
|
||||
Scope: expectedScope,
|
||||
TokenEndpoint: "test.hosted.com/token",
|
||||
DeviceAuthEndpoint: "test.hosted.com/device/auth",
|
||||
HTTPClient: &httpClient,
|
||||
providerConfig: ProviderConfig{
|
||||
Audience: expectedAudience,
|
||||
ClientID: expectedClientID,
|
||||
Scope: expectedScope,
|
||||
TokenEndpoint: "test.hosted.com/token",
|
||||
DeviceAuthEndpoint: "test.hosted.com/device/auth",
|
||||
UseIDToken: false,
|
||||
},
|
||||
HTTPClient: &httpClient,
|
||||
}
|
||||
|
||||
authInfo, err := hosted.RequestDeviceCode(context.TODO())
|
||||
@@ -275,12 +279,15 @@ func TestHosted_WaitToken(t *testing.T) {
|
||||
}
|
||||
|
||||
hosted := Hosted{
|
||||
Audience: testCase.inputAudience,
|
||||
ClientID: clientID,
|
||||
TokenEndpoint: "test.hosted.com/token",
|
||||
DeviceAuthEndpoint: "test.hosted.com/device/auth",
|
||||
HTTPClient: &httpClient,
|
||||
}
|
||||
providerConfig: ProviderConfig{
|
||||
Audience: testCase.inputAudience,
|
||||
ClientID: clientID,
|
||||
TokenEndpoint: "test.hosted.com/token",
|
||||
DeviceAuthEndpoint: "test.hosted.com/device/auth",
|
||||
Scope: "openid",
|
||||
UseIDToken: false,
|
||||
},
|
||||
HTTPClient: &httpClient}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), testCase.inputTimeout)
|
||||
defer cancel()
|
||||
|
||||
@@ -9,15 +9,15 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/pion/ice/v2"
|
||||
"github.com/pion/transport/v2/stdnet"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/proxy"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
signal "github.com/netbirdio/netbird/signal/client"
|
||||
sProto "github.com/netbirdio/netbird/signal/proto"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
// ConnConfig is a peer Connection configuration
|
||||
@@ -93,6 +93,9 @@ type Conn struct {
|
||||
proxy proxy.Proxy
|
||||
remoteModeCh chan ModeMessage
|
||||
meta meta
|
||||
|
||||
adapter iface.TunAdapter
|
||||
iFaceDiscover stdnet.IFaceDiscover
|
||||
}
|
||||
|
||||
// meta holds meta information about a connection
|
||||
@@ -118,7 +121,7 @@ func (conn *Conn) UpdateConf(conf ConnConfig) {
|
||||
|
||||
// NewConn creates a new not opened Conn to the remote peer.
|
||||
// To establish a connection run Conn.Open
|
||||
func NewConn(config ConnConfig, statusRecorder *Status) (*Conn, error) {
|
||||
func NewConn(config ConnConfig, statusRecorder *Status, adapter iface.TunAdapter, iFaceDiscover stdnet.IFaceDiscover) (*Conn, error) {
|
||||
return &Conn{
|
||||
config: config,
|
||||
mu: sync.Mutex{},
|
||||
@@ -128,6 +131,8 @@ func NewConn(config ConnConfig, statusRecorder *Status) (*Conn, error) {
|
||||
remoteAnswerCh: make(chan OfferAnswer),
|
||||
statusRecorder: statusRecorder,
|
||||
remoteModeCh: make(chan ModeMessage, 1),
|
||||
adapter: adapter,
|
||||
iFaceDiscover: iFaceDiscover,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -162,7 +167,9 @@ func (conn *Conn) reCreateAgent() error {
|
||||
defer conn.mu.Unlock()
|
||||
|
||||
failedTimeout := 6 * time.Second
|
||||
transportNet, err := stdnet.NewNet()
|
||||
|
||||
var err error
|
||||
transportNet, err := conn.newStdNet()
|
||||
if err != nil {
|
||||
log.Warnf("failed to create pion's stdnet: %s", err)
|
||||
}
|
||||
|
||||
@@ -37,7 +37,7 @@ func TestNewConn_interfaceFilter(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConn_GetKey(t *testing.T) {
|
||||
conn, err := NewConn(connConf, nil)
|
||||
conn, err := NewConn(connConf, nil, nil, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -49,7 +49,7 @@ func TestConn_GetKey(t *testing.T) {
|
||||
|
||||
func TestConn_OnRemoteOffer(t *testing.T) {
|
||||
|
||||
conn, err := NewConn(connConf, NewRecorder("https://mgm"))
|
||||
conn, err := NewConn(connConf, NewRecorder("https://mgm"), nil, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -83,7 +83,7 @@ func TestConn_OnRemoteOffer(t *testing.T) {
|
||||
|
||||
func TestConn_OnRemoteAnswer(t *testing.T) {
|
||||
|
||||
conn, err := NewConn(connConf, NewRecorder("https://mgm"))
|
||||
conn, err := NewConn(connConf, NewRecorder("https://mgm"), nil, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -116,7 +116,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
|
||||
}
|
||||
func TestConn_Status(t *testing.T) {
|
||||
|
||||
conn, err := NewConn(connConf, NewRecorder("https://mgm"))
|
||||
conn, err := NewConn(connConf, NewRecorder("https://mgm"), nil, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -143,7 +143,7 @@ func TestConn_Status(t *testing.T) {
|
||||
|
||||
func TestConn_Close(t *testing.T) {
|
||||
|
||||
conn, err := NewConn(connConf, NewRecorder("https://mgm"))
|
||||
conn, err := NewConn(connConf, NewRecorder("https://mgm"), nil, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -411,7 +411,7 @@ func TestGetProxyWithMessageExchange(t *testing.T) {
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
g := errgroup.Group{}
|
||||
conn, err := NewConn(connConf, nil)
|
||||
conn, err := NewConn(connConf, nil, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -5,5 +5,7 @@ type Listener interface {
|
||||
OnConnected()
|
||||
OnDisconnected()
|
||||
OnConnecting()
|
||||
OnDisconnecting()
|
||||
OnAddressChanged(string, string)
|
||||
OnPeersListChanged(int)
|
||||
}
|
||||
|
||||
@@ -8,117 +8,135 @@ const (
|
||||
stateDisconnected = iota
|
||||
stateConnected
|
||||
stateConnecting
|
||||
stateDisconnecting
|
||||
)
|
||||
|
||||
type notifier struct {
|
||||
serverStateLock sync.Mutex
|
||||
listenersLock sync.Mutex
|
||||
listeners map[Listener]struct{}
|
||||
currentServerState bool
|
||||
listener Listener
|
||||
currentClientState bool
|
||||
lastNotification int
|
||||
}
|
||||
|
||||
func newNotifier() *notifier {
|
||||
return ¬ifier{
|
||||
listeners: make(map[Listener]struct{}),
|
||||
}
|
||||
return ¬ifier{}
|
||||
}
|
||||
|
||||
func (n *notifier) addListener(listener Listener) {
|
||||
func (n *notifier) setListener(listener Listener) {
|
||||
n.listenersLock.Lock()
|
||||
defer n.listenersLock.Unlock()
|
||||
|
||||
n.serverStateLock.Lock()
|
||||
go n.notifyListener(listener, n.lastNotification)
|
||||
n.notifyListener(listener, n.lastNotification)
|
||||
n.serverStateLock.Unlock()
|
||||
n.listeners[listener] = struct{}{}
|
||||
|
||||
n.listener = listener
|
||||
}
|
||||
|
||||
func (n *notifier) removeListener(listener Listener) {
|
||||
func (n *notifier) removeListener() {
|
||||
n.listenersLock.Lock()
|
||||
defer n.listenersLock.Unlock()
|
||||
delete(n.listeners, listener)
|
||||
n.listener = nil
|
||||
}
|
||||
|
||||
func (n *notifier) updateServerStates(mgmState bool, signalState bool) {
|
||||
n.serverStateLock.Lock()
|
||||
defer n.serverStateLock.Unlock()
|
||||
|
||||
var newState bool
|
||||
if mgmState && signalState {
|
||||
newState = true
|
||||
} else {
|
||||
newState = false
|
||||
}
|
||||
calculatedState := n.calculateState(mgmState, signalState)
|
||||
|
||||
if !n.isServerStateChanged(newState) {
|
||||
if !n.isServerStateChanged(calculatedState) {
|
||||
return
|
||||
}
|
||||
|
||||
n.currentServerState = newState
|
||||
n.lastNotification = n.calculateState(newState, n.currentClientState)
|
||||
n.lastNotification = calculatedState
|
||||
|
||||
go n.notifyAll(n.lastNotification)
|
||||
n.notify(n.lastNotification)
|
||||
}
|
||||
|
||||
func (n *notifier) clientStart() {
|
||||
n.serverStateLock.Lock()
|
||||
defer n.serverStateLock.Unlock()
|
||||
n.currentClientState = true
|
||||
n.lastNotification = n.calculateState(n.currentServerState, true)
|
||||
go n.notifyAll(n.lastNotification)
|
||||
n.lastNotification = stateConnected
|
||||
n.notify(n.lastNotification)
|
||||
}
|
||||
|
||||
func (n *notifier) clientStop() {
|
||||
n.serverStateLock.Lock()
|
||||
defer n.serverStateLock.Unlock()
|
||||
n.currentClientState = false
|
||||
n.lastNotification = n.calculateState(n.currentServerState, false)
|
||||
go n.notifyAll(n.lastNotification)
|
||||
n.lastNotification = stateDisconnected
|
||||
n.notify(n.lastNotification)
|
||||
}
|
||||
|
||||
func (n *notifier) isServerStateChanged(newState bool) bool {
|
||||
return n.currentServerState != newState
|
||||
func (n *notifier) clientTearDown() {
|
||||
n.serverStateLock.Lock()
|
||||
defer n.serverStateLock.Unlock()
|
||||
n.currentClientState = false
|
||||
n.lastNotification = stateDisconnecting
|
||||
n.notify(n.lastNotification)
|
||||
}
|
||||
|
||||
func (n *notifier) notifyAll(state int) {
|
||||
func (n *notifier) isServerStateChanged(newState int) bool {
|
||||
return n.lastNotification != newState
|
||||
}
|
||||
|
||||
func (n *notifier) notify(state int) {
|
||||
n.listenersLock.Lock()
|
||||
defer n.listenersLock.Unlock()
|
||||
|
||||
for l := range n.listeners {
|
||||
n.notifyListener(l, state)
|
||||
if n.listener == nil {
|
||||
return
|
||||
}
|
||||
n.notifyListener(n.listener, state)
|
||||
}
|
||||
|
||||
func (n *notifier) notifyListener(l Listener, state int) {
|
||||
switch state {
|
||||
case stateDisconnected:
|
||||
l.OnDisconnected()
|
||||
case stateConnected:
|
||||
l.OnConnected()
|
||||
case stateConnecting:
|
||||
l.OnConnecting()
|
||||
}
|
||||
go func() {
|
||||
switch state {
|
||||
case stateDisconnected:
|
||||
l.OnDisconnected()
|
||||
case stateConnected:
|
||||
l.OnConnected()
|
||||
case stateConnecting:
|
||||
l.OnConnecting()
|
||||
case stateDisconnecting:
|
||||
l.OnDisconnecting()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (n *notifier) calculateState(serverState bool, clientState bool) int {
|
||||
if serverState && clientState {
|
||||
func (n *notifier) calculateState(managementConn, signalConn bool) int {
|
||||
if managementConn && signalConn {
|
||||
return stateConnected
|
||||
}
|
||||
|
||||
if !clientState {
|
||||
if !managementConn && !signalConn {
|
||||
return stateDisconnected
|
||||
}
|
||||
|
||||
if n.lastNotification == stateDisconnecting {
|
||||
return stateDisconnecting
|
||||
}
|
||||
|
||||
return stateConnecting
|
||||
}
|
||||
|
||||
func (n *notifier) peerListChanged(numOfPeers int) {
|
||||
n.listenersLock.Lock()
|
||||
defer n.listenersLock.Unlock()
|
||||
|
||||
for l := range n.listeners {
|
||||
l.OnPeersListChanged(numOfPeers)
|
||||
if n.listener == nil {
|
||||
return
|
||||
}
|
||||
n.listener.OnPeersListChanged(numOfPeers)
|
||||
}
|
||||
|
||||
func (n *notifier) localAddressChanged(fqdn, address string) {
|
||||
n.listenersLock.Lock()
|
||||
defer n.listenersLock.Unlock()
|
||||
if n.listener == nil {
|
||||
return
|
||||
}
|
||||
n.listener.OnAddressChanged(fqdn, address)
|
||||
}
|
||||
|
||||
@@ -1,32 +1,97 @@
|
||||
package peer
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type mocListener struct {
|
||||
lastState int
|
||||
wg sync.WaitGroup
|
||||
peers int
|
||||
}
|
||||
|
||||
func (l *mocListener) OnConnected() {
|
||||
l.lastState = stateConnected
|
||||
l.wg.Done()
|
||||
}
|
||||
func (l *mocListener) OnDisconnected() {
|
||||
l.lastState = stateDisconnected
|
||||
l.wg.Done()
|
||||
}
|
||||
func (l *mocListener) OnConnecting() {
|
||||
l.lastState = stateConnecting
|
||||
l.wg.Done()
|
||||
}
|
||||
func (l *mocListener) OnDisconnecting() {
|
||||
l.lastState = stateDisconnecting
|
||||
l.wg.Done()
|
||||
}
|
||||
|
||||
func (l *mocListener) OnAddressChanged(host, addr string) {
|
||||
|
||||
}
|
||||
func (l *mocListener) OnPeersListChanged(size int) {
|
||||
l.peers = size
|
||||
}
|
||||
|
||||
func (l *mocListener) setWaiter() {
|
||||
l.wg.Add(1)
|
||||
}
|
||||
|
||||
func (l *mocListener) wait() {
|
||||
l.wg.Wait()
|
||||
}
|
||||
|
||||
func Test_notifier_serverState(t *testing.T) {
|
||||
|
||||
type scenario struct {
|
||||
name string
|
||||
expected bool
|
||||
expected int
|
||||
mgmState bool
|
||||
signalState bool
|
||||
}
|
||||
scenarios := []scenario{
|
||||
{"connected", true, true, true},
|
||||
{"mgm down", false, false, true},
|
||||
{"signal down", false, true, false},
|
||||
{"disconnected", false, false, false},
|
||||
{"connected", stateConnected, true, true},
|
||||
{"mgm down", stateConnecting, false, true},
|
||||
{"signal down", stateConnecting, true, false},
|
||||
{"disconnected", stateDisconnected, false, false},
|
||||
}
|
||||
|
||||
for _, tt := range scenarios {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
n := newNotifier()
|
||||
n.updateServerStates(tt.mgmState, tt.signalState)
|
||||
if n.currentServerState != tt.expected {
|
||||
t.Errorf("invalid serverstate: %t, expected: %t", n.currentServerState, tt.expected)
|
||||
if n.lastNotification != tt.expected {
|
||||
t.Errorf("invalid serverstate: %d, expected: %d", n.lastNotification, tt.expected)
|
||||
}
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_notifier_SetListener(t *testing.T) {
|
||||
listener := &mocListener{}
|
||||
listener.setWaiter()
|
||||
|
||||
n := newNotifier()
|
||||
n.lastNotification = stateConnecting
|
||||
n.setListener(listener)
|
||||
listener.wait()
|
||||
if listener.lastState != n.lastNotification {
|
||||
t.Errorf("invalid state: %d, expected: %d", listener.lastState, n.lastNotification)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_notifier_RemoveListener(t *testing.T) {
|
||||
listener := &mocListener{}
|
||||
listener.setWaiter()
|
||||
n := newNotifier()
|
||||
n.lastNotification = stateConnecting
|
||||
n.setListener(listener)
|
||||
n.removeListener()
|
||||
n.peerListChanged(1)
|
||||
|
||||
if listener.peers != 0 {
|
||||
t.Errorf("invalid state: %d", listener.peers)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -190,6 +190,7 @@ func (d *Status) UpdateLocalPeerState(localPeerState LocalPeerState) {
|
||||
defer d.mux.Unlock()
|
||||
|
||||
d.localPeer = localPeerState
|
||||
d.notifyAddressChanged()
|
||||
}
|
||||
|
||||
// CleanLocalPeerState cleans local peer status
|
||||
@@ -198,6 +199,7 @@ func (d *Status) CleanLocalPeerState() {
|
||||
defer d.mux.Unlock()
|
||||
|
||||
d.localPeer = LocalPeerState{}
|
||||
d.notifyAddressChanged()
|
||||
}
|
||||
|
||||
// MarkManagementDisconnected sets ManagementState to disconnected
|
||||
@@ -215,7 +217,7 @@ func (d *Status) MarkManagementConnected() {
|
||||
defer d.mux.Unlock()
|
||||
defer d.onConnectionChanged()
|
||||
|
||||
d.managementState = true
|
||||
d.managementState = true
|
||||
}
|
||||
|
||||
// UpdateSignalAddress update the address of the signal server
|
||||
@@ -238,7 +240,7 @@ func (d *Status) MarkSignalDisconnected() {
|
||||
defer d.mux.Unlock()
|
||||
defer d.onConnectionChanged()
|
||||
|
||||
d.signalState = false
|
||||
d.signalState = false
|
||||
}
|
||||
|
||||
// MarkSignalConnected sets SignalState to connected
|
||||
@@ -286,14 +288,19 @@ func (d *Status) ClientStop() {
|
||||
d.notifier.clientStop()
|
||||
}
|
||||
|
||||
// AddConnectionListener add a listener to the notifier
|
||||
func (d *Status) AddConnectionListener(listener Listener) {
|
||||
d.notifier.addListener(listener)
|
||||
// ClientTeardown will notify all listeners about the service is under teardown
|
||||
func (d *Status) ClientTeardown() {
|
||||
d.notifier.clientTearDown()
|
||||
}
|
||||
|
||||
// RemoveConnectionListener remove a listener from the notifier
|
||||
func (d *Status) RemoveConnectionListener(listener Listener) {
|
||||
d.notifier.removeListener(listener)
|
||||
// SetConnectionListener set a listener to the notifier
|
||||
func (d *Status) SetConnectionListener(listener Listener) {
|
||||
d.notifier.setListener(listener)
|
||||
}
|
||||
|
||||
// RemoveConnectionListener remove the listener from the notifier
|
||||
func (d *Status) RemoveConnectionListener() {
|
||||
d.notifier.removeListener()
|
||||
}
|
||||
|
||||
func (d *Status) onConnectionChanged() {
|
||||
@@ -303,3 +310,7 @@ func (d *Status) onConnectionChanged() {
|
||||
func (d *Status) notifyPeerListChanged() {
|
||||
d.notifier.peerListChanged(len(d.peers))
|
||||
}
|
||||
|
||||
func (d *Status) notifyAddressChanged() {
|
||||
d.notifier.localAddressChanged(d.localPeer.FQDN, d.localPeer.IP)
|
||||
}
|
||||
|
||||
11
client/internal/peer/stdnet.go
Normal file
11
client/internal/peer/stdnet.go
Normal file
@@ -0,0 +1,11 @@
|
||||
//go:build !android
|
||||
|
||||
package peer
|
||||
|
||||
import (
|
||||
"github.com/pion/transport/v2/stdnet"
|
||||
)
|
||||
|
||||
func (conn *Conn) newStdNet() (*stdnet.Net, error) {
|
||||
return stdnet.NewNet()
|
||||
}
|
||||
7
client/internal/peer/stdnet_android.go
Normal file
7
client/internal/peer/stdnet_android.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package peer
|
||||
|
||||
import "github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
|
||||
func (conn *Conn) newStdNet() (*stdnet.Net, error) {
|
||||
return stdnet.NewNet(conn.iFaceDiscover)
|
||||
}
|
||||
8
client/internal/stdnet/iface_discover.go
Normal file
8
client/internal/stdnet/iface_discover.go
Normal file
@@ -0,0 +1,8 @@
|
||||
package stdnet
|
||||
|
||||
// IFaceDiscover provide an option for external services (mobile)
|
||||
// to collect network interface information
|
||||
type IFaceDiscover interface {
|
||||
// IFaces return with the description of the interfaces
|
||||
IFaces() (string, error)
|
||||
}
|
||||
137
client/internal/stdnet/stdnet.go
Normal file
137
client/internal/stdnet/stdnet.go
Normal file
@@ -0,0 +1,137 @@
|
||||
// Package stdnet is an extension of the pion's stdnet.
|
||||
// With it the list of the interface can come from external source.
|
||||
// More info: https://github.com/golang/go/issues/40569
|
||||
package stdnet
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/pion/transport/v2"
|
||||
"github.com/pion/transport/v2/stdnet"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// Net is an implementation of the net.Net interface
|
||||
// based on functions of the standard net package.
|
||||
type Net struct {
|
||||
stdnet.Net
|
||||
interfaces []*transport.Interface
|
||||
}
|
||||
|
||||
// NewNet creates a new StdNet instance.
|
||||
func NewNet(iFaceDiscover IFaceDiscover) (*Net, error) {
|
||||
n := &Net{}
|
||||
|
||||
return n, n.UpdateInterfaces(iFaceDiscover)
|
||||
}
|
||||
|
||||
// UpdateInterfaces updates the internal list of network interfaces
|
||||
// and associated addresses.
|
||||
func (n *Net) UpdateInterfaces(iFaceDiscover IFaceDiscover) error {
|
||||
ifacesString, err := iFaceDiscover.IFaces()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
n.interfaces = parseInterfacesString(ifacesString)
|
||||
return err
|
||||
}
|
||||
|
||||
// Interfaces returns a slice of interfaces which are available on the
|
||||
// system
|
||||
func (n *Net) Interfaces() ([]*transport.Interface, error) {
|
||||
return n.interfaces, nil
|
||||
}
|
||||
|
||||
// InterfaceByIndex returns the interface specified by index.
|
||||
//
|
||||
// On Solaris, it returns one of the logical network interfaces
|
||||
// sharing the logical data link; for more precision use
|
||||
// InterfaceByName.
|
||||
func (n *Net) InterfaceByIndex(index int) (*transport.Interface, error) {
|
||||
for _, ifc := range n.interfaces {
|
||||
if ifc.Index == index {
|
||||
return ifc, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("%w: index=%d", transport.ErrInterfaceNotFound, index)
|
||||
}
|
||||
|
||||
// InterfaceByName returns the interface specified by name.
|
||||
func (n *Net) InterfaceByName(name string) (*transport.Interface, error) {
|
||||
for _, ifc := range n.interfaces {
|
||||
if ifc.Name == name {
|
||||
return ifc, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("%w: %s", transport.ErrInterfaceNotFound, name)
|
||||
}
|
||||
|
||||
func parseInterfacesString(interfaces string) []*transport.Interface {
|
||||
ifs := []*transport.Interface{}
|
||||
|
||||
for _, iface := range strings.Split(interfaces, "\n") {
|
||||
if strings.TrimSpace(iface) == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
fields := strings.Split(iface, "|")
|
||||
if len(fields) != 2 {
|
||||
log.Warnf("parseInterfacesString: unable to split %q", iface)
|
||||
continue
|
||||
}
|
||||
|
||||
var name string
|
||||
var index, mtu int
|
||||
var up, broadcast, loopback, pointToPoint, multicast bool
|
||||
_, err := fmt.Sscanf(fields[0], "%s %d %d %t %t %t %t %t",
|
||||
&name, &index, &mtu, &up, &broadcast, &loopback, &pointToPoint, &multicast)
|
||||
if err != nil {
|
||||
log.Warnf("parseInterfacesString: unable to parse %q: %v", iface, err)
|
||||
continue
|
||||
}
|
||||
|
||||
newIf := net.Interface{
|
||||
Name: name,
|
||||
Index: index,
|
||||
MTU: mtu,
|
||||
}
|
||||
if up {
|
||||
newIf.Flags |= net.FlagUp
|
||||
}
|
||||
if broadcast {
|
||||
newIf.Flags |= net.FlagBroadcast
|
||||
}
|
||||
if loopback {
|
||||
newIf.Flags |= net.FlagLoopback
|
||||
}
|
||||
if pointToPoint {
|
||||
newIf.Flags |= net.FlagPointToPoint
|
||||
}
|
||||
if multicast {
|
||||
newIf.Flags |= net.FlagMulticast
|
||||
}
|
||||
|
||||
ifc := transport.NewInterface(newIf)
|
||||
|
||||
addrs := strings.Trim(fields[1], " \n")
|
||||
foundAddress := false
|
||||
for _, addr := range strings.Split(addrs, " ") {
|
||||
ip, ipNet, err := net.ParseCIDR(addr)
|
||||
if err != nil {
|
||||
log.Warnf("%s", err)
|
||||
continue
|
||||
}
|
||||
ipNet.IP = ip
|
||||
ifc.AddAddress(ipNet)
|
||||
foundAddress = true
|
||||
}
|
||||
if foundAddress {
|
||||
ifs = append(ifs, ifc)
|
||||
}
|
||||
}
|
||||
return ifs
|
||||
}
|
||||
66
client/internal/stdnet/stdnet_test.go
Normal file
66
client/internal/stdnet/stdnet_test.go
Normal file
@@ -0,0 +1,66 @@
|
||||
package stdnet
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_parseInterfacesString(t *testing.T) {
|
||||
testData := []struct {
|
||||
name string
|
||||
index int
|
||||
mtu int
|
||||
up bool
|
||||
broadcast bool
|
||||
loopBack bool
|
||||
pointToPoint bool
|
||||
multicast bool
|
||||
addr string
|
||||
}{
|
||||
{"wlan0", 30, 1500, true, true, false, false, true, "10.1.10.131/24"},
|
||||
{"rmnet0", 30, 1500, true, true, false, false, true, "192.168.0.56/24"},
|
||||
{"rmnet_data1", 30, 1500, true, true, false, false, true, "fec0::118c:faf7:8d97:3cb2/64"},
|
||||
}
|
||||
|
||||
var exampleString string
|
||||
for _, d := range testData {
|
||||
exampleString = fmt.Sprintf("%s\n%s %d %d %t %t %t %t %t | %s", exampleString,
|
||||
d.name,
|
||||
d.index,
|
||||
d.mtu,
|
||||
d.up,
|
||||
d.broadcast,
|
||||
d.loopBack,
|
||||
d.pointToPoint,
|
||||
d.multicast,
|
||||
d.addr)
|
||||
}
|
||||
nets := parseInterfacesString(exampleString)
|
||||
if len(nets) == 0 {
|
||||
t.Fatalf("failed to parse interfaces")
|
||||
}
|
||||
|
||||
for i, net := range nets {
|
||||
if net.MTU != testData[i].mtu {
|
||||
t.Errorf("invalid mtu: %d, expected: %d", net.MTU, testData[0].mtu)
|
||||
|
||||
}
|
||||
|
||||
if net.Interface.Name != testData[i].name {
|
||||
t.Errorf("invalid interface name: %s, expected: %s", net.Interface.Name, testData[i].name)
|
||||
}
|
||||
|
||||
addr, err := net.Addrs()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(addr) == 0 {
|
||||
t.Errorf("invalid address parsing")
|
||||
}
|
||||
|
||||
if addr[0].String() != testData[i].addr {
|
||||
t.Errorf("invalid address: %s, expected: %s", addr[0].String(), testData[i].addr)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -78,7 +78,7 @@ func (s *Server) Start() error {
|
||||
// on failure we return error to retry
|
||||
config, err := internal.UpdateConfig(s.latestConfigInput)
|
||||
if errorStatus, ok := gstatus.FromError(err); ok && errorStatus.Code() == codes.NotFound {
|
||||
config, err = internal.UpdateOrCreateConfig(s.latestConfigInput)
|
||||
s.config, err = internal.UpdateOrCreateConfig(s.latestConfigInput)
|
||||
if err != nil {
|
||||
log.Warnf("unable to create configuration file: %v", err)
|
||||
return err
|
||||
@@ -102,7 +102,7 @@ func (s *Server) Start() error {
|
||||
}
|
||||
|
||||
go func() {
|
||||
if err := internal.RunClient(ctx, config, s.statusRecorder, nil); err != nil {
|
||||
if err := internal.RunClient(ctx, config, s.statusRecorder, nil, nil); err != nil {
|
||||
log.Errorf("init connections: %v", err)
|
||||
}
|
||||
}()
|
||||
@@ -223,12 +223,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
|
||||
}
|
||||
}
|
||||
|
||||
hostedClient := internal.NewHostedDeviceFlow(
|
||||
providerConfig.ProviderConfig.Audience,
|
||||
providerConfig.ProviderConfig.ClientID,
|
||||
providerConfig.ProviderConfig.TokenEndpoint,
|
||||
providerConfig.ProviderConfig.DeviceAuthEndpoint,
|
||||
)
|
||||
hostedClient := internal.NewHostedDeviceFlow(providerConfig.ProviderConfig)
|
||||
|
||||
if s.oauthAuthFlow.client != nil && s.oauthAuthFlow.client.GetClientID(ctx) == hostedClient.GetClientID(context.TODO()) {
|
||||
if s.oauthAuthFlow.expiresAt.After(time.Now().Add(90 * time.Second)) {
|
||||
@@ -344,7 +339,7 @@ func (s *Server) WaitSSOLogin(callerCtx context.Context, msg *proto.WaitSSOLogin
|
||||
s.oauthAuthFlow.expiresAt = time.Now()
|
||||
s.mutex.Unlock()
|
||||
|
||||
if loginStatus, err := s.loginAttempt(ctx, "", tokenInfo.AccessToken); err != nil {
|
||||
if loginStatus, err := s.loginAttempt(ctx, "", tokenInfo.GetTokenToUse()); err != nil {
|
||||
state.Set(loginStatus)
|
||||
return nil, err
|
||||
}
|
||||
@@ -394,7 +389,7 @@ func (s *Server) Up(callerCtx context.Context, _ *proto.UpRequest) (*proto.UpRes
|
||||
}
|
||||
|
||||
go func() {
|
||||
if err := internal.RunClient(ctx, s.config, s.statusRecorder, nil); err != nil {
|
||||
if err := internal.RunClient(ctx, s.config, s.statusRecorder, nil, nil); err != nil {
|
||||
log.Errorf("run client connection: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -34,7 +34,7 @@ func GetInfo(ctx context.Context) *Info {
|
||||
func extractDeviceName(ctx context.Context) string {
|
||||
v, ok := ctx.Value(DeviceNameCtxKey).(string)
|
||||
if !ok {
|
||||
return ""
|
||||
return "android"
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
@@ -3,10 +3,13 @@ package encryption
|
||||
import (
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
|
||||
"golang.org/x/crypto/nacl/box"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
const nonceSize = 24
|
||||
|
||||
// A set of tools to encrypt/decrypt messages being sent through the Signal Exchange Service or Management Service
|
||||
// These tools use Golang crypto package (Curve25519, XSalsa20 and Poly1305 to encrypt and authenticate)
|
||||
// Wireguard keys are used for encryption
|
||||
@@ -26,8 +29,11 @@ func Decrypt(encryptedMsg []byte, peerPublicKey wgtypes.Key, privateKey wgtypes.
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
copy(nonce[:], encryptedMsg[:24])
|
||||
opened, ok := box.Open(nil, encryptedMsg[24:], nonce, toByte32(peerPublicKey), toByte32(privateKey))
|
||||
if len(encryptedMsg) < nonceSize {
|
||||
return nil, fmt.Errorf("invalid encrypted message lenght")
|
||||
}
|
||||
copy(nonce[:], encryptedMsg[:nonceSize])
|
||||
opened, ok := box.Open(nil, encryptedMsg[nonceSize:], nonce, toByte32(peerPublicKey), toByte32(privateKey))
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to decrypt message from peer %s", peerPublicKey.String())
|
||||
}
|
||||
@@ -36,8 +42,8 @@ func Decrypt(encryptedMsg []byte, peerPublicKey wgtypes.Key, privateKey wgtypes.
|
||||
}
|
||||
|
||||
// Generates nonce of size 24
|
||||
func genNonce() (*[24]byte, error) {
|
||||
var nonce [24]byte
|
||||
func genNonce() (*[nonceSize]byte, error) {
|
||||
var nonce [nonceSize]byte
|
||||
if _, err := rand.Read(nonce[:]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
4
go.mod
4
go.mod
@@ -37,6 +37,7 @@ require (
|
||||
github.com/getlantern/systray v1.2.1
|
||||
github.com/gliderlabs/ssh v0.3.4
|
||||
github.com/godbus/dbus/v5 v5.1.0
|
||||
github.com/google/go-cmp v0.5.9
|
||||
github.com/google/nftables v0.0.0-20220808154552-2eca00135732
|
||||
github.com/hashicorp/go-secure-stdlib/base62 v0.1.2
|
||||
github.com/hashicorp/go-version v1.6.0
|
||||
@@ -55,6 +56,7 @@ require (
|
||||
go.opentelemetry.io/otel/exporters/prometheus v0.33.0
|
||||
go.opentelemetry.io/otel/metric v0.33.0
|
||||
go.opentelemetry.io/otel/sdk/metric v0.33.0
|
||||
golang.org/x/exp v0.0.0-20220518171630-0b5c67f07fdf
|
||||
golang.org/x/net v0.8.0
|
||||
golang.org/x/sync v0.1.0
|
||||
golang.org/x/term v0.6.0
|
||||
@@ -89,7 +91,6 @@ require (
|
||||
github.com/go-stack/stack v1.8.0 // indirect
|
||||
github.com/gobwas/glob v0.2.3 // indirect
|
||||
github.com/goki/freetype v0.0.0-20181231101311-fa8a33aabaff // indirect
|
||||
github.com/google/go-cmp v0.5.9 // indirect
|
||||
github.com/google/gopacket v1.1.19 // indirect
|
||||
github.com/hashicorp/go-uuid v1.0.2 // indirect
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
@@ -126,7 +127,6 @@ require (
|
||||
go.opentelemetry.io/otel v1.11.1 // indirect
|
||||
go.opentelemetry.io/otel/sdk v1.11.1 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.11.1 // indirect
|
||||
golang.org/x/exp v0.0.0-20220518171630-0b5c67f07fdf // indirect
|
||||
golang.org/x/image v0.0.0-20200430140353-33d19683fad8 // indirect
|
||||
golang.org/x/mod v0.8.0 // indirect
|
||||
golang.org/x/text v0.8.0 // indirect
|
||||
|
||||
@@ -33,7 +33,7 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string {
|
||||
|
||||
if p.PresharedKey != nil {
|
||||
preSharedHexKey := hex.EncodeToString(p.PresharedKey[:])
|
||||
sb.WriteString(fmt.Sprintf("public_key=%s\n", preSharedHexKey))
|
||||
sb.WriteString(fmt.Sprintf("preshared_key=%s\n", preSharedHexKey))
|
||||
}
|
||||
|
||||
if p.Remove {
|
||||
|
||||
@@ -3,26 +3,31 @@
|
||||
# Management API
|
||||
|
||||
# Management API port
|
||||
NETBIRD_MGMT_API_PORT=33073
|
||||
NETBIRD_MGMT_API_PORT=${NETBIRD_MGMT_API_PORT:-33073}
|
||||
# Management API endpoint address, used by the Dashboard
|
||||
NETBIRD_MGMT_API_ENDPOINT=https://$NETBIRD_DOMAIN:$NETBIRD_MGMT_API_PORT
|
||||
# Management Certficate file path. These are generated by the Dashboard container
|
||||
NETBIRD_MGMT_API_CERT_FILE="/etc/letsencrypt/live/$NETBIRD_DOMAIN/fullchain.pem"
|
||||
NETBIRD_LETSENCRYPT_DOMAIN=$NETBIRD_DOMAIN
|
||||
NETBIRD_MGMT_API_CERT_FILE="/etc/letsencrypt/live/$NETBIRD_LETSENCRYPT_DOMAIN/fullchain.pem"
|
||||
# Management Certficate key file path.
|
||||
NETBIRD_MGMT_API_CERT_KEY_FILE="/etc/letsencrypt/live/$NETBIRD_DOMAIN/privkey.pem"
|
||||
NETBIRD_MGMT_API_CERT_KEY_FILE="/etc/letsencrypt/live/$NETBIRD_LETSENCRYPT_DOMAIN/privkey.pem"
|
||||
# By default Management single account mode is enabled and domain set to $NETBIRD_DOMAIN, you may want to set this to your user's email domain
|
||||
NETBIRD_MGMT_SINGLE_ACCOUNT_MODE_DOMAIN=$NETBIRD_DOMAIN
|
||||
NETBIRD_MGMT_DNS_DOMAIN=${NETBIRD_MGMT_DNS_DOMAIN:-netbird.selfhosted}
|
||||
# Turn credentials
|
||||
|
||||
# Signal
|
||||
NETBIRD_SIGNAL_PROTOCOL="http"
|
||||
NETBIRD_SIGNAL_PORT=${NETBIRD_SIGNAL_PORT:-10000}
|
||||
|
||||
# Turn credentials
|
||||
# User
|
||||
TURN_USER=self
|
||||
# Password. If empty, the configure.sh will generate one with openssl
|
||||
TURN_PASSWORD=
|
||||
# Min port
|
||||
TURN_MIN_PORT=49152
|
||||
TURN_MIN_PORT=${TURN_MIN_PORT:-49152}
|
||||
# Max port
|
||||
TURN_MAX_PORT=65535
|
||||
TURN_MAX_PORT=${TURN_MAX_PORT:-65535}
|
||||
|
||||
VOLUME_PREFIX="netbird-"
|
||||
MGMT_VOLUMESUFFIX="mgmt"
|
||||
@@ -30,8 +35,13 @@ SIGNAL_VOLUMESUFFIX="signal"
|
||||
LETSENCRYPT_VOLUMESUFFIX="letsencrypt"
|
||||
|
||||
NETBIRD_AUTH_DEVICE_AUTH_PROVIDER="none"
|
||||
NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE=${NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE:-$NETBIRD_AUTH_AUDIENCE}
|
||||
NETBIRD_AUTH_DEVICE_AUTH_SCOPE=${NETBIRD_AUTH_DEVICE_AUTH_SCOPE:-openid}
|
||||
NETBIRD_AUTH_DEVICE_AUTH_USE_ID_TOKEN=${NETBIRD_AUTH_DEVICE_AUTH_USE_ID_TOKEN:-false}
|
||||
|
||||
|
||||
NETBIRD_DISABLE_ANONYMOUS_METRICS=${NETBIRD_DISABLE_ANONYMOUS_METRICS:-false}
|
||||
NETBIRD_TOKEN_SOURCE=${NETBIRD_TOKEN_SOURCE:-accessToken}
|
||||
|
||||
# exports
|
||||
export NETBIRD_DOMAIN
|
||||
@@ -44,6 +54,7 @@ export NETBIRD_AUTH_JWT_CERTS
|
||||
export NETBIRD_LETSENCRYPT_EMAIL
|
||||
export NETBIRD_MGMT_API_PORT
|
||||
export NETBIRD_MGMT_API_ENDPOINT
|
||||
export NETBIRD_LETSENCRYPT_DOMAIN
|
||||
export NETBIRD_MGMT_API_CERT_FILE
|
||||
export NETBIRD_MGMT_API_CERT_KEY_FILE
|
||||
export NETBIRD_AUTH_DEVICE_AUTH_PROVIDER
|
||||
@@ -61,4 +72,11 @@ export SIGNAL_VOLUMESUFFIX
|
||||
export LETSENCRYPT_VOLUMESUFFIX
|
||||
export NETBIRD_DISABLE_ANONYMOUS_METRICS
|
||||
export NETBIRD_MGMT_SINGLE_ACCOUNT_MODE_DOMAIN
|
||||
export NETBIRD_MGMT_DNS_DOMAIN
|
||||
export NETBIRD_MGMT_DNS_DOMAIN
|
||||
export NETBIRD_SIGNAL_PROTOCOL
|
||||
export NETBIRD_SIGNAL_PORT
|
||||
export NETBIRD_AUTH_USER_ID_CLAIM
|
||||
export NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE
|
||||
export NETBIRD_TOKEN_SOURCE
|
||||
export NETBIRD_AUTH_DEVICE_AUTH_SCOPE
|
||||
export NETBIRD_AUTH_DEVICE_AUTH_USE_ID_TOKEN
|
||||
@@ -121,6 +121,30 @@ if [[ ! -z "${NETBIRD_AUTH_DEVICE_AUTH_CLIENT_ID}" ]]; then
|
||||
export NETBIRD_AUTH_DEVICE_AUTH_PROVIDER="hosted"
|
||||
fi
|
||||
|
||||
# Check if letsencrypt was disabled
|
||||
if [[ "$NETBIRD_DISABLE_LETSENCRYPT" == "true" ]]
|
||||
then
|
||||
export NETBIRD_DASHBOARD_ENDPOINT="https://$NETBIRD_DOMAIN:443"
|
||||
export NETBIRD_SIGNAL_ENDPOINT="https://$NETBIRD_DOMAIN:$NETBIRD_SIGNAL_PORT"
|
||||
|
||||
echo "Letsencrypt was disabled, the Https-endpoints cannot be used anymore"
|
||||
echo " and a reverse-proxy with Https needs to be placed in front of netbird!"
|
||||
echo "The following forwards have to be setup:"
|
||||
echo "- $NETBIRD_DASHBOARD_ENDPOINT -http-> dashboard:80"
|
||||
echo "- $NETBIRD_MGMT_API_ENDPOINT/api -http-> management:$NETBIRD_MGMT_API_PORT"
|
||||
echo "- $NETBIRD_MGMT_API_ENDPOINT/management.ManagementService/ -grpc-> management:$NETBIRD_MGMT_API_PORT"
|
||||
echo "- $NETBIRD_SIGNAL_ENDPOINT/signalexchange.SignalExchange/ -grpc-> signal:80"
|
||||
echo "You most likely also have to change NETBIRD_MGMT_API_ENDPOINT in base.setup.env and port-mappings in docker-compose.yml.tmpl and rerun this script."
|
||||
echo " The target of the forwards depends on your setup. Beware of the gRPC protocol instead of http for management and signal!"
|
||||
echo "You are also free to remove any occurences of the Letsencrypt-volume $LETSENCRYPT_VOLUMENAME"
|
||||
echo ""
|
||||
|
||||
export NETBIRD_SIGNAL_PROTOCOL="https"
|
||||
unset NETBIRD_LETSENCRYPT_DOMAIN
|
||||
unset NETBIRD_MGMT_API_CERT_FILE
|
||||
unset NETBIRD_MGMT_API_CERT_KEY_FILE
|
||||
fi
|
||||
|
||||
env | grep NETBIRD
|
||||
|
||||
envsubst < docker-compose.yml.tmpl > docker-compose.yml
|
||||
|
||||
@@ -8,20 +8,26 @@ services:
|
||||
- 80:80
|
||||
- 443:443
|
||||
environment:
|
||||
# Endpoints
|
||||
- NETBIRD_MGMT_API_ENDPOINT=$NETBIRD_MGMT_API_ENDPOINT
|
||||
- NETBIRD_MGMT_GRPC_API_ENDPOINT=$NETBIRD_MGMT_API_ENDPOINT
|
||||
# OIDC
|
||||
- AUTH_AUDIENCE=$NETBIRD_AUTH_AUDIENCE
|
||||
- AUTH_CLIENT_ID=$NETBIRD_AUTH_CLIENT_ID
|
||||
- AUTH_AUTHORITY=$NETBIRD_AUTH_AUTHORITY
|
||||
- USE_AUTH0=$NETBIRD_USE_AUTH0
|
||||
- AUTH_SUPPORTED_SCOPES=$NETBIRD_AUTH_SUPPORTED_SCOPES
|
||||
- NETBIRD_MGMT_API_ENDPOINT=$NETBIRD_MGMT_API_ENDPOINT
|
||||
- NETBIRD_MGMT_GRPC_API_ENDPOINT=$NETBIRD_MGMT_API_ENDPOINT
|
||||
- NGINX_SSL_PORT=443
|
||||
- LETSENCRYPT_DOMAIN=$NETBIRD_DOMAIN
|
||||
- LETSENCRYPT_EMAIL=$NETBIRD_LETSENCRYPT_EMAIL
|
||||
- AUTH_REDIRECT_URI=$NETBIRD_AUTH_REDIRECT_URI
|
||||
- AUTH_SILENT_REDIRECT_URI=$NETBIRD_AUTH_SILENT_REDIRECT_URI
|
||||
- NETBIRD_TOKEN_SOURCE=$NETBIRD_TOKEN_SOURCE
|
||||
# SSL
|
||||
- NGINX_SSL_PORT=443
|
||||
# Letsencrypt
|
||||
- LETSENCRYPT_DOMAIN=$NETBIRD_LETSENCRYPT_DOMAIN
|
||||
- LETSENCRYPT_EMAIL=$NETBIRD_LETSENCRYPT_EMAIL
|
||||
volumes:
|
||||
- $LETSENCRYPT_VOLUMENAME:/etc/letsencrypt/
|
||||
|
||||
# Signal
|
||||
signal:
|
||||
image: netbirdio/signal:latest
|
||||
@@ -32,7 +38,8 @@ services:
|
||||
- 10000:80
|
||||
# # port and command for Let's Encrypt validation
|
||||
# - 443:443
|
||||
# command: ["--letsencrypt-domain", "$NETBIRD_DOMAIN", "--log-file", "console"]
|
||||
# command: ["--letsencrypt-domain", "$NETBIRD_LETSENCRYPT_DOMAIN", "--log-file", "console"]
|
||||
|
||||
# Management
|
||||
management:
|
||||
image: netbirdio/management:latest
|
||||
@@ -46,8 +53,15 @@ services:
|
||||
ports:
|
||||
- $NETBIRD_MGMT_API_PORT:443 #API port
|
||||
# # command for Let's Encrypt validation without dashboard container
|
||||
# command: ["--letsencrypt-domain", "$NETBIRD_DOMAIN", "--log-file", "console"]
|
||||
command: ["--port", "443", "--log-file", "console", "--disable-anonymous-metrics=$NETBIRD_DISABLE_ANONYMOUS_METRICS", "--single-account-mode-domain=$NETBIRD_MGMT_SINGLE_ACCOUNT_MODE_DOMAIN", "--dns-domain=$NETBIRD_MGMT_DNS_DOMAIN"]
|
||||
# command: ["--letsencrypt-domain", "$NETBIRD_LETSENCRYPT_DOMAIN", "--log-file", "console"]
|
||||
command: [
|
||||
"--port", "443",
|
||||
"--log-file", "console",
|
||||
"--disable-anonymous-metrics=$NETBIRD_DISABLE_ANONYMOUS_METRICS",
|
||||
"--single-account-mode-domain=$NETBIRD_MGMT_SINGLE_ACCOUNT_MODE_DOMAIN",
|
||||
"--dns-domain=$NETBIRD_MGMT_DNS_DOMAIN"
|
||||
]
|
||||
|
||||
# Coturn
|
||||
coturn:
|
||||
image: coturn/coturn
|
||||
@@ -60,6 +74,7 @@ services:
|
||||
network_mode: host
|
||||
command:
|
||||
- -c /etc/turnserver.conf
|
||||
|
||||
volumes:
|
||||
$MGMT_VOLUMENAME:
|
||||
$SIGNAL_VOLUMENAME:
|
||||
|
||||
99
infrastructure_files/docker-compose.yml.tmpl.traefik
Normal file
99
infrastructure_files/docker-compose.yml.tmpl.traefik
Normal file
@@ -0,0 +1,99 @@
|
||||
version: "3"
|
||||
services:
|
||||
#UI dashboard
|
||||
dashboard:
|
||||
image: wiretrustee/dashboard:latest
|
||||
restart: unless-stopped
|
||||
#ports:
|
||||
# - 80:80
|
||||
# - 443:443
|
||||
environment:
|
||||
# Endpoints
|
||||
- NETBIRD_MGMT_API_ENDPOINT=$NETBIRD_MGMT_API_ENDPOINT
|
||||
- NETBIRD_MGMT_GRPC_API_ENDPOINT=$NETBIRD_MGMT_API_ENDPOINT
|
||||
# OIDC
|
||||
- AUTH_AUDIENCE=$NETBIRD_AUTH_AUDIENCE
|
||||
- AUTH_CLIENT_ID=$NETBIRD_AUTH_CLIENT_ID
|
||||
- AUTH_AUTHORITY=$NETBIRD_AUTH_AUTHORITY
|
||||
- USE_AUTH0=$NETBIRD_USE_AUTH0
|
||||
- AUTH_SUPPORTED_SCOPES=$NETBIRD_AUTH_SUPPORTED_SCOPES
|
||||
- AUTH_REDIRECT_URI=$NETBIRD_AUTH_REDIRECT_URI
|
||||
- AUTH_SILENT_REDIRECT_URI=$NETBIRD_AUTH_SILENT_REDIRECT_URI
|
||||
# SSL
|
||||
- NGINX_SSL_PORT=443
|
||||
# Letsencrypt
|
||||
- LETSENCRYPT_DOMAIN=$NETBIRD_LETSENCRYPT_DOMAIN
|
||||
- LETSENCRYPT_EMAIL=$NETBIRD_LETSENCRYPT_EMAIL
|
||||
volumes:
|
||||
- $LETSENCRYPT_VOLUMENAME:/etc/letsencrypt/
|
||||
labels:
|
||||
- traefik.enable=true
|
||||
- traefik.http.routers.netbird-dashboard.rule=Host(`$NETBIRD_DOMAIN`)
|
||||
- traefik.http.services.netbird-dashboard.loadbalancer.server.port=80
|
||||
|
||||
# Signal
|
||||
signal:
|
||||
image: netbirdio/signal:latest
|
||||
restart: unless-stopped
|
||||
volumes:
|
||||
- $SIGNAL_VOLUMENAME:/var/lib/netbird
|
||||
#ports:
|
||||
# - 10000:80
|
||||
# # port and command for Let's Encrypt validation
|
||||
# - 443:443
|
||||
# command: ["--letsencrypt-domain", "$NETBIRD_LETSENCRYPT_DOMAIN", "--log-file", "console"]
|
||||
labels:
|
||||
- traefik.enable=true
|
||||
- traefik.http.routers.netbird-signal.rule=Host(`$NETBIRD_DOMAIN`) && PathPrefix(`/signalexchange.SignalExchange/`)
|
||||
- traefik.http.services.netbird-signal.loadbalancer.server.port=80
|
||||
- traefik.http.services.netbird-signal.loadbalancer.server.scheme=h2c
|
||||
|
||||
# Management
|
||||
management:
|
||||
image: netbirdio/management:latest
|
||||
restart: unless-stopped
|
||||
depends_on:
|
||||
- dashboard
|
||||
volumes:
|
||||
- $MGMT_VOLUMENAME:/var/lib/netbird
|
||||
- $LETSENCRYPT_VOLUMENAME:/etc/letsencrypt:ro
|
||||
- ./management.json:/etc/netbird/management.json
|
||||
#ports:
|
||||
# - $NETBIRD_MGMT_API_PORT:443 #API port
|
||||
# # command for Let's Encrypt validation without dashboard container
|
||||
# command: ["--letsencrypt-domain", "$NETBIRD_LETSENCRYPT_DOMAIN", "--log-file", "console"]
|
||||
command: [
|
||||
"--port", "443",
|
||||
"--log-file", "console",
|
||||
"--disable-anonymous-metrics=$NETBIRD_DISABLE_ANONYMOUS_METRICS",
|
||||
"--single-account-mode-domain=$NETBIRD_MGMT_SINGLE_ACCOUNT_MODE_DOMAIN",
|
||||
"--dns-domain=$NETBIRD_MGMT_DNS_DOMAIN"
|
||||
]
|
||||
labels:
|
||||
- traefik.enable=true
|
||||
- traefik.http.routers.netbird-api.rule=Host(`$NETBIRD_DOMAIN`) && PathPrefix(`/api`)
|
||||
- traefik.http.routers.netbird-api.service=netbird-api
|
||||
- traefik.http.services.netbird-api.loadbalancer.server.port=443
|
||||
|
||||
- traefik.http.routers.netbird-management.rule=Host(`$NETBIRD_DOMAIN`) && PathPrefix(`/management.ManagementService/`)
|
||||
- traefik.http.routers.netbird-management.service=netbird-management
|
||||
- traefik.http.services.netbird-management.loadbalancer.server.port=443
|
||||
- traefik.http.services.netbird-management.loadbalancer.server.scheme=h2c
|
||||
|
||||
# Coturn
|
||||
coturn:
|
||||
image: coturn/coturn
|
||||
restart: unless-stopped
|
||||
domainname: $NETBIRD_DOMAIN
|
||||
volumes:
|
||||
- ./turnserver.conf:/etc/turnserver.conf:ro
|
||||
# - ./privkey.pem:/etc/coturn/private/privkey.pem:ro
|
||||
# - ./cert.pem:/etc/coturn/certs/cert.pem:ro
|
||||
network_mode: host
|
||||
command:
|
||||
- -c /etc/turnserver.conf
|
||||
|
||||
volumes:
|
||||
$MGMT_VOLUMENAME:
|
||||
$SIGNAL_VOLUMENAME:
|
||||
$LETSENCRYPT_VOLUMENAME:
|
||||
@@ -21,8 +21,8 @@
|
||||
"TimeBasedCredentials": false
|
||||
},
|
||||
"Signal": {
|
||||
"Proto": "http",
|
||||
"URI": "$NETBIRD_DOMAIN:10000",
|
||||
"Proto": "$NETBIRD_SIGNAL_PROTOCOL",
|
||||
"URI": "$NETBIRD_DOMAIN:$NETBIRD_SIGNAL_PORT",
|
||||
"Username": "",
|
||||
"Password": null
|
||||
},
|
||||
@@ -43,11 +43,13 @@
|
||||
"DeviceAuthorizationFlow": {
|
||||
"Provider": "$NETBIRD_AUTH_DEVICE_AUTH_PROVIDER",
|
||||
"ProviderConfig": {
|
||||
"Audience": "$NETBIRD_AUTH_AUDIENCE",
|
||||
"Audience": "$NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE",
|
||||
"Domain": "$NETBIRD_AUTH0_DOMAIN",
|
||||
"ClientID": "$NETBIRD_AUTH_DEVICE_AUTH_CLIENT_ID",
|
||||
"TokenEndpoint": "$NETBIRD_AUTH_TOKEN_ENDPOINT",
|
||||
"DeviceAuthEndpoint": "$NETBIRD_AUTH_DEVICE_AUTH_ENDPOINT"
|
||||
"DeviceAuthEndpoint": "$NETBIRD_AUTH_DEVICE_AUTH_ENDPOINT",
|
||||
"Scope": "$NETBIRD_AUTH_DEVICE_AUTH_SCOPE",
|
||||
"UseIDToken": $NETBIRD_AUTH_DEVICE_AUTH_USE_ID_TOKEN
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,7 +2,11 @@
|
||||
##
|
||||
# Dashboard domain. e.g. app.mydomain.com
|
||||
NETBIRD_DOMAIN=""
|
||||
# OIDC configuration e.g., https://example.eu.auth0.com/.well-known/openid-configuration
|
||||
|
||||
# -------------------------------------------
|
||||
# OIDC
|
||||
# e.g., https://example.eu.auth0.com/.well-known/openid-configuration
|
||||
# -------------------------------------------
|
||||
NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT=""
|
||||
NETBIRD_AUTH_AUDIENCE=""
|
||||
# e.g. netbird-client
|
||||
@@ -13,14 +17,30 @@ NETBIRD_AUTH_CLIENT_ID=""
|
||||
NETBIRD_USE_AUTH0="false"
|
||||
NETBIRD_AUTH_DEVICE_AUTH_PROVIDER="none"
|
||||
NETBIRD_AUTH_DEVICE_AUTH_CLIENT_ID=""
|
||||
# e.g. hello@mydomain.com
|
||||
NETBIRD_LETSENCRYPT_EMAIL=""
|
||||
# Some IDPs requires different audience, scopes and to use id token for device authorization flow
|
||||
# you can customize here:
|
||||
NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE=$NETBIRD_AUTH_AUDIENCE
|
||||
NETBIRD_AUTH_DEVICE_AUTH_SCOPE="openid"
|
||||
NETBIRD_AUTH_DEVICE_AUTH_USE_ID_TOKEN=false
|
||||
|
||||
# if your IDP provider doesn't support fragmented URIs, configure custom
|
||||
# redirect and silent redirect URIs, these will be concatenated into your NETBIRD_DOMAIN domain.
|
||||
# NETBIRD_AUTH_REDIRECT_URI="/peers"
|
||||
# NETBIRD_AUTH_SILENT_REDIRECT_URI="/add-peers"
|
||||
# Updates the preference to use id tokens instead of access token on dashboard
|
||||
# Okta and Gitlab IDPs can benefit from this
|
||||
# NETBIRD_TOKEN_SOURCE="idToken"
|
||||
|
||||
# -------------------------------------------
|
||||
# Letsencrypt
|
||||
# -------------------------------------------
|
||||
# Disable letsencrypt
|
||||
# if disabled, cannot use HTTPS anymore and requires setting up a reverse-proxy to do it instead
|
||||
NETBIRD_DISABLE_LETSENCRYPT=false
|
||||
# e.g. hello@mydomain.com
|
||||
NETBIRD_LETSENCRYPT_EMAIL=""
|
||||
|
||||
# Disable anonymous metrics collection, see more information at https://netbird.io/docs/FAQ/metrics-collection
|
||||
NETBIRD_DISABLE_ANONYMOUS_METRICS=false
|
||||
# DNS DOMAIN configures the domain name used for peer resolution. By default it is netbird.selfhosted
|
||||
NETBIRD_MGMT_DNS_DOMAIN=netbird.selfhosted
|
||||
NETBIRD_MGMT_DNS_DOMAIN=netbird.selfhosted
|
||||
@@ -11,4 +11,9 @@ NETBIRD_USE_AUTH0=$CI_NETBIRD_USE_AUTH0
|
||||
NETBIRD_AUTH_AUDIENCE=$CI_NETBIRD_AUTH_AUDIENCE
|
||||
# e.g. hello@mydomain.com
|
||||
NETBIRD_LETSENCRYPT_EMAIL=""
|
||||
NETBIRD_AUTH_REDIRECT_URI="/peers"
|
||||
NETBIRD_AUTH_REDIRECT_URI="/peers"
|
||||
NETBIRD_DISABLE_LETSENCRYPT=true
|
||||
NETBIRD_TOKEN_SOURCE="idToken"
|
||||
NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE="super"
|
||||
NETBIRD_AUTH_USER_ID_CLAIM="email"
|
||||
NETBIRD_AUTH_DEVICE_AUTH_SCOPE="openid email"
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
"google.golang.org/grpc/keepalive"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/netbirdio/netbird/encryption"
|
||||
"github.com/netbirdio/netbird/management/proto"
|
||||
@@ -144,15 +145,19 @@ func (c *GrpcClient) Sync(msgHandler func(msg *proto.SyncResponse) error) error
|
||||
// blocking until error
|
||||
err = c.receiveEvents(stream, *serverPubKey, msgHandler)
|
||||
if err != nil {
|
||||
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.PermissionDenied {
|
||||
s, _ := gstatus.FromError(err)
|
||||
switch s.Code() {
|
||||
case codes.PermissionDenied:
|
||||
return backoff.Permanent(err) // unrecoverable error, propagate to the upper layer
|
||||
case codes.Canceled:
|
||||
log.Debugf("management connection context has been canceled, this usually indicates shutdown")
|
||||
return nil
|
||||
default:
|
||||
backOff.Reset() // reset backoff counter after successful connection
|
||||
c.notifyDisconnected()
|
||||
log.Warnf("disconnected from the Management service but will retry silently. Reason: %v", err)
|
||||
return err
|
||||
}
|
||||
// we need this reset because after a successful connection and a consequent error, backoff lib doesn't
|
||||
// reset times and next try will start with a long delay
|
||||
backOff.Reset()
|
||||
c.notifyDisconnected()
|
||||
log.Warnf("disconnected from the Management service but will retry silently. Reason: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -19,25 +19,28 @@ import (
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/netbirdio/netbird/management/server/activity/sqlite"
|
||||
httpapi "github.com/netbirdio/netbird/management/server/http"
|
||||
"github.com/netbirdio/netbird/management/server/metrics"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"golang.org/x/crypto/acme/autocert"
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/http2/h2c"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity/sqlite"
|
||||
httpapi "github.com/netbirdio/netbird/management/server/http"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/metrics"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
|
||||
"github.com/netbirdio/netbird/encryption"
|
||||
mgmtProto "github.com/netbirdio/netbird/management/proto"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
|
||||
"github.com/netbirdio/netbird/encryption"
|
||||
mgmtProto "github.com/netbirdio/netbird/management/proto"
|
||||
)
|
||||
|
||||
// ManagementLegacyPort is the port that was used before by the Management gRPC server.
|
||||
@@ -179,13 +182,22 @@ var (
|
||||
tlsEnabled = true
|
||||
}
|
||||
|
||||
jwtValidator, err := jwtclaims.NewJWTValidator(
|
||||
config.HttpConfig.AuthIssuer,
|
||||
config.GetAuthAudiences(),
|
||||
config.HttpConfig.AuthKeysLocation,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed creating JWT validator: %v", err)
|
||||
}
|
||||
|
||||
httpAPIAuthCfg := httpapi.AuthCfg{
|
||||
Issuer: config.HttpConfig.AuthIssuer,
|
||||
Audience: config.HttpConfig.AuthAudience,
|
||||
UserIDClaim: config.HttpConfig.AuthUserIDClaim,
|
||||
KeysLocation: config.HttpConfig.AuthKeysLocation,
|
||||
}
|
||||
httpAPIHandler, err := httpapi.APIHandler(accountManager, appMetrics, httpAPIAuthCfg)
|
||||
httpAPIHandler, err := httpapi.APIHandler(accountManager, *jwtValidator, appMetrics, httpAPIAuthCfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed creating HTTP API handler: %v", err)
|
||||
}
|
||||
@@ -405,6 +417,10 @@ func loadMgmtConfig(mgmtConfigPath string) (*server.Config, error) {
|
||||
log.Infof("overriding DeviceAuthorizationFlow.ProviderConfig.Domain with a new value: %s, previously configured value: %s",
|
||||
u.Host, config.DeviceAuthorizationFlow.ProviderConfig.Domain)
|
||||
config.DeviceAuthorizationFlow.ProviderConfig.Domain = u.Host
|
||||
|
||||
if config.DeviceAuthorizationFlow.ProviderConfig.Scope == "" {
|
||||
config.DeviceAuthorizationFlow.ProviderConfig.Scope = server.DefaultDeviceAuthFlowScope
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// versions:
|
||||
// protoc-gen-go v1.26.0
|
||||
// protoc v3.21.0
|
||||
// protoc v3.21.9
|
||||
// source: management.proto
|
||||
|
||||
package proto
|
||||
@@ -995,8 +995,6 @@ type NetworkMap struct {
|
||||
DNSConfig *DNSConfig `protobuf:"bytes,6,opt,name=DNSConfig,proto3" json:"DNSConfig,omitempty"`
|
||||
// RemotePeerConfig represents a list of remote peers that the receiver can connect to
|
||||
OfflinePeers []*RemotePeerConfig `protobuf:"bytes,7,rep,name=offlinePeers,proto3" json:"offlinePeers,omitempty"`
|
||||
// FirewallRule represents a list of firewall rules to be applied to peer
|
||||
FirewallRules []*FirewallRule `protobuf:"bytes,8,rep,name=FirewallRules,proto3" json:"FirewallRules,omitempty"`
|
||||
}
|
||||
|
||||
func (x *NetworkMap) Reset() {
|
||||
@@ -1080,13 +1078,6 @@ func (x *NetworkMap) GetOfflinePeers() []*RemotePeerConfig {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *NetworkMap) GetFirewallRules() []*FirewallRule {
|
||||
if x != nil {
|
||||
return x.FirewallRules
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemotePeerConfig represents a configuration of a remote peer.
|
||||
// The properties are used to configure WireGuard Peers sections
|
||||
type RemotePeerConfig struct {
|
||||
@@ -1340,6 +1331,10 @@ type ProviderConfig struct {
|
||||
DeviceAuthEndpoint string `protobuf:"bytes,5,opt,name=DeviceAuthEndpoint,proto3" json:"DeviceAuthEndpoint,omitempty"`
|
||||
// TokenEndpoint is an endpoint to request auth token.
|
||||
TokenEndpoint string `protobuf:"bytes,6,opt,name=TokenEndpoint,proto3" json:"TokenEndpoint,omitempty"`
|
||||
// Scopes provides the scopes to be included in the token request
|
||||
Scope string `protobuf:"bytes,7,opt,name=Scope,proto3" json:"Scope,omitempty"`
|
||||
// UseIDToken indicates if the id token should be used for authentication
|
||||
UseIDToken bool `protobuf:"varint,8,opt,name=UseIDToken,proto3" json:"UseIDToken,omitempty"`
|
||||
}
|
||||
|
||||
func (x *ProviderConfig) Reset() {
|
||||
@@ -1416,6 +1411,20 @@ func (x *ProviderConfig) GetTokenEndpoint() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *ProviderConfig) GetScope() string {
|
||||
if x != nil {
|
||||
return x.Scope
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *ProviderConfig) GetUseIDToken() bool {
|
||||
if x != nil {
|
||||
return x.UseIDToken
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Route represents a route.Route object
|
||||
type Route struct {
|
||||
state protoimpl.MessageState
|
||||
@@ -1840,94 +1849,6 @@ func (x *NameServer) GetPort() int64 {
|
||||
return 0
|
||||
}
|
||||
|
||||
// FirewallRule represents a firewall rule
|
||||
type FirewallRule struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
PeerID string `protobuf:"bytes,1,opt,name=PeerID,proto3" json:"PeerID,omitempty"`
|
||||
PeerIP string `protobuf:"bytes,2,opt,name=PeerIP,proto3" json:"PeerIP,omitempty"`
|
||||
Direction string `protobuf:"bytes,3,opt,name=Direction,proto3" json:"Direction,omitempty"`
|
||||
Action string `protobuf:"bytes,4,opt,name=Action,proto3" json:"Action,omitempty"`
|
||||
Protocol string `protobuf:"bytes,5,opt,name=Protocol,proto3" json:"Protocol,omitempty"`
|
||||
Port string `protobuf:"bytes,6,opt,name=Port,proto3" json:"Port,omitempty"`
|
||||
}
|
||||
|
||||
func (x *FirewallRule) Reset() {
|
||||
*x = FirewallRule{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_management_proto_msgTypes[25]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
}
|
||||
|
||||
func (x *FirewallRule) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*FirewallRule) ProtoMessage() {}
|
||||
|
||||
func (x *FirewallRule) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_management_proto_msgTypes[25]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use FirewallRule.ProtoReflect.Descriptor instead.
|
||||
func (*FirewallRule) Descriptor() ([]byte, []int) {
|
||||
return file_management_proto_rawDescGZIP(), []int{25}
|
||||
}
|
||||
|
||||
func (x *FirewallRule) GetPeerID() string {
|
||||
if x != nil {
|
||||
return x.PeerID
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *FirewallRule) GetPeerIP() string {
|
||||
if x != nil {
|
||||
return x.PeerIP
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *FirewallRule) GetDirection() string {
|
||||
if x != nil {
|
||||
return x.Direction
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *FirewallRule) GetAction() string {
|
||||
if x != nil {
|
||||
return x.Action
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *FirewallRule) GetProtocol() string {
|
||||
if x != nil {
|
||||
return x.Protocol
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *FirewallRule) GetPort() string {
|
||||
if x != nil {
|
||||
return x.Port
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
var File_management_proto protoreflect.FileDescriptor
|
||||
|
||||
var file_management_proto_rawDesc = []byte{
|
||||
@@ -2045,7 +1966,7 @@ var file_management_proto_rawDesc = []byte{
|
||||
0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e,
|
||||
0x53, 0x53, 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f,
|
||||
0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18, 0x04, 0x20, 0x01,
|
||||
0x28, 0x09, 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x22, 0xae, 0x03, 0x0a, 0x0a, 0x4e, 0x65, 0x74,
|
||||
0x28, 0x09, 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x22, 0xee, 0x02, 0x0a, 0x0a, 0x4e, 0x65, 0x74,
|
||||
0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x12, 0x16, 0x0a, 0x06, 0x53, 0x65, 0x72, 0x69, 0x61,
|
||||
0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x04, 0x52, 0x06, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x12,
|
||||
0x36, 0x0a, 0x0a, 0x70, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x02, 0x20,
|
||||
@@ -2068,11 +1989,7 @@ var file_management_proto_rawDesc = []byte{
|
||||
0x6e, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x18, 0x07, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1c, 0x2e,
|
||||
0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x6d, 0x6f, 0x74,
|
||||
0x65, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0c, 0x6f, 0x66, 0x66,
|
||||
0x6c, 0x69, 0x6e, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x12, 0x3e, 0x0a, 0x0d, 0x46, 0x69, 0x72,
|
||||
0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x18, 0x08, 0x20, 0x03, 0x28, 0x0b,
|
||||
0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x46, 0x69,
|
||||
0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x52, 0x0d, 0x46, 0x69, 0x72, 0x65,
|
||||
0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x22, 0x97, 0x01, 0x0a, 0x10, 0x52, 0x65,
|
||||
0x6c, 0x69, 0x6e, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x22, 0x97, 0x01, 0x0a, 0x10, 0x52, 0x65,
|
||||
0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1a,
|
||||
0x0a, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09,
|
||||
0x52, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x1e, 0x0a, 0x0a, 0x61, 0x6c,
|
||||
@@ -2101,7 +2018,7 @@ var file_management_proto_rawDesc = []byte{
|
||||
0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x76,
|
||||
0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0x16, 0x0a, 0x08, 0x70, 0x72,
|
||||
0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x0a, 0x0a, 0x06, 0x48, 0x4f, 0x53, 0x54, 0x45, 0x44,
|
||||
0x10, 0x00, 0x22, 0xda, 0x01, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43,
|
||||
0x10, 0x00, 0x22, 0x90, 0x02, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43,
|
||||
0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1a, 0x0a, 0x08, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x49,
|
||||
0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x49,
|
||||
0x44, 0x12, 0x22, 0x0a, 0x0c, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65,
|
||||
@@ -2114,91 +2031,84 @@ var file_management_proto_rawDesc = []byte{
|
||||
0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x12, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74,
|
||||
0x68, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x24, 0x0a, 0x0d, 0x54, 0x6f, 0x6b,
|
||||
0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09,
|
||||
0x52, 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x22,
|
||||
0xb5, 0x01, 0x0a, 0x05, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x44, 0x18,
|
||||
0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, 0x4e, 0x65, 0x74,
|
||||
0x77, 0x6f, 0x72, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x4e, 0x65, 0x74, 0x77,
|
||||
0x6f, 0x72, 0x6b, 0x12, 0x20, 0x0a, 0x0b, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x54, 0x79,
|
||||
0x70, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72,
|
||||
0x6b, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x65, 0x65, 0x72, 0x18, 0x04, 0x20,
|
||||
0x01, 0x28, 0x09, 0x52, 0x04, 0x50, 0x65, 0x65, 0x72, 0x12, 0x16, 0x0a, 0x06, 0x4d, 0x65, 0x74,
|
||||
0x72, 0x69, 0x63, 0x18, 0x05, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x4d, 0x65, 0x74, 0x72, 0x69,
|
||||
0x63, 0x12, 0x1e, 0x0a, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, 0x65, 0x72, 0x61, 0x64, 0x65, 0x18,
|
||||
0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, 0x65, 0x72, 0x61, 0x64,
|
||||
0x65, 0x12, 0x14, 0x0a, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x18, 0x07, 0x20, 0x01, 0x28, 0x09,
|
||||
0x52, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x22, 0xb4, 0x01, 0x0a, 0x09, 0x44, 0x4e, 0x53, 0x43,
|
||||
0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x24, 0x0a, 0x0d, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65,
|
||||
0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x53, 0x65,
|
||||
0x72, 0x76, 0x69, 0x63, 0x65, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x12, 0x47, 0x0a, 0x10, 0x4e,
|
||||
0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x18,
|
||||
0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1b, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65,
|
||||
0x6e, 0x74, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f,
|
||||
0x75, 0x70, 0x52, 0x10, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72,
|
||||
0x6f, 0x75, 0x70, 0x73, 0x12, 0x38, 0x0a, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f,
|
||||
0x6e, 0x65, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61,
|
||||
0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e,
|
||||
0x65, 0x52, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73, 0x22, 0x58,
|
||||
0x0a, 0x0a, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x12, 0x16, 0x0a, 0x06,
|
||||
0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f,
|
||||
0x6d, 0x61, 0x69, 0x6e, 0x12, 0x32, 0x0a, 0x07, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x18,
|
||||
0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65,
|
||||
0x6e, 0x74, 0x2e, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x52,
|
||||
0x07, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x22, 0x74, 0x0a, 0x0c, 0x53, 0x69, 0x6d, 0x70,
|
||||
0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x4e, 0x61, 0x6d, 0x65,
|
||||
0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x12, 0x0a, 0x04,
|
||||
0x54, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x54, 0x79, 0x70, 0x65,
|
||||
0x12, 0x14, 0x0a, 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52,
|
||||
0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x12, 0x10, 0x0a, 0x03, 0x54, 0x54, 0x4c, 0x18, 0x04, 0x20,
|
||||
0x01, 0x28, 0x03, 0x52, 0x03, 0x54, 0x54, 0x4c, 0x12, 0x14, 0x0a, 0x05, 0x52, 0x44, 0x61, 0x74,
|
||||
0x61, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x22, 0x7f,
|
||||
0x0a, 0x0f, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75,
|
||||
0x70, 0x12, 0x38, 0x0a, 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73,
|
||||
0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d,
|
||||
0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x52, 0x0b,
|
||||
0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x50,
|
||||
0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x50, 0x72,
|
||||
0x69, 0x6d, 0x61, 0x72, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73,
|
||||
0x18, 0x03, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x22,
|
||||
0x48, 0x0a, 0x0a, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x12, 0x0e, 0x0a,
|
||||
0x02, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a,
|
||||
0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x4e,
|
||||
0x53, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x03, 0x20,
|
||||
0x01, 0x28, 0x03, 0x52, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x22, 0xa4, 0x01, 0x0a, 0x0c, 0x46, 0x69,
|
||||
0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x50, 0x65,
|
||||
0x65, 0x72, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x50, 0x65, 0x65, 0x72,
|
||||
0x49, 0x44, 0x12, 0x16, 0x0a, 0x06, 0x50, 0x65, 0x65, 0x72, 0x49, 0x50, 0x18, 0x02, 0x20, 0x01,
|
||||
0x28, 0x09, 0x52, 0x06, 0x50, 0x65, 0x65, 0x72, 0x49, 0x50, 0x12, 0x1c, 0x0a, 0x09, 0x44, 0x69,
|
||||
0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x44,
|
||||
0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x16, 0x0a, 0x06, 0x41, 0x63, 0x74, 0x69,
|
||||
0x6f, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e,
|
||||
0x12, 0x1a, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x05, 0x20, 0x01,
|
||||
0x28, 0x09, 0x52, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x12, 0x0a, 0x04,
|
||||
0x50, 0x6f, 0x72, 0x74, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x50, 0x6f, 0x72, 0x74,
|
||||
0x32, 0xf7, 0x02, 0x0a, 0x11, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53,
|
||||
0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x45, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12,
|
||||
0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63,
|
||||
0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e,
|
||||
0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79,
|
||||
0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x46, 0x0a,
|
||||
0x04, 0x53, 0x79, 0x6e, 0x63, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65,
|
||||
0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73,
|
||||
0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74,
|
||||
0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67,
|
||||
0x65, 0x22, 0x00, 0x30, 0x01, 0x12, 0x42, 0x0a, 0x0c, 0x47, 0x65, 0x74, 0x53, 0x65, 0x72, 0x76,
|
||||
0x65, 0x72, 0x4b, 0x65, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65,
|
||||
0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67,
|
||||
0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x52,
|
||||
0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, 0x09, 0x69, 0x73, 0x48,
|
||||
0x65, 0x61, 0x6c, 0x74, 0x68, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d,
|
||||
0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61,
|
||||
0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x12, 0x5a,
|
||||
0x0a, 0x1a, 0x47, 0x65, 0x74, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f,
|
||||
0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d,
|
||||
0x52, 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12,
|
||||
0x14, 0x0a, 0x05, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05,
|
||||
0x53, 0x63, 0x6f, 0x70, 0x65, 0x12, 0x1e, 0x0a, 0x0a, 0x55, 0x73, 0x65, 0x49, 0x44, 0x54, 0x6f,
|
||||
0x6b, 0x65, 0x6e, 0x18, 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x55, 0x73, 0x65, 0x49, 0x44,
|
||||
0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x22, 0xb5, 0x01, 0x0a, 0x05, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12,
|
||||
0x0e, 0x0a, 0x02, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x44, 0x12,
|
||||
0x18, 0x0a, 0x07, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09,
|
||||
0x52, 0x07, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x12, 0x20, 0x0a, 0x0b, 0x4e, 0x65, 0x74,
|
||||
0x77, 0x6f, 0x72, 0x6b, 0x54, 0x79, 0x70, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b,
|
||||
0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50,
|
||||
0x65, 0x65, 0x72, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x50, 0x65, 0x65, 0x72, 0x12,
|
||||
0x16, 0x0a, 0x06, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x18, 0x05, 0x20, 0x01, 0x28, 0x03, 0x52,
|
||||
0x06, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x12, 0x1e, 0x0a, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75,
|
||||
0x65, 0x72, 0x61, 0x64, 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x4d, 0x61, 0x73,
|
||||
0x71, 0x75, 0x65, 0x72, 0x61, 0x64, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44,
|
||||
0x18, 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x22, 0xb4, 0x01,
|
||||
0x0a, 0x09, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x24, 0x0a, 0x0d, 0x53,
|
||||
0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x18, 0x01, 0x20, 0x01,
|
||||
0x28, 0x08, 0x52, 0x0d, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x45, 0x6e, 0x61, 0x62, 0x6c,
|
||||
0x65, 0x12, 0x47, 0x0a, 0x10, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47,
|
||||
0x72, 0x6f, 0x75, 0x70, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1b, 0x2e, 0x6d, 0x61,
|
||||
0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72,
|
||||
0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x52, 0x10, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65,
|
||||
0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x12, 0x38, 0x0a, 0x0b, 0x43, 0x75,
|
||||
0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32,
|
||||
0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x75, 0x73,
|
||||
0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x52, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a,
|
||||
0x6f, 0x6e, 0x65, 0x73, 0x22, 0x58, 0x0a, 0x0a, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f,
|
||||
0x6e, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01,
|
||||
0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x32, 0x0a, 0x07, 0x52, 0x65,
|
||||
0x63, 0x6f, 0x72, 0x64, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61,
|
||||
0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52,
|
||||
0x65, 0x63, 0x6f, 0x72, 0x64, 0x52, 0x07, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x22, 0x74,
|
||||
0x0a, 0x0c, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x12, 0x12,
|
||||
0x0a, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x4e, 0x61,
|
||||
0x6d, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03,
|
||||
0x52, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x18,
|
||||
0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x12, 0x10, 0x0a, 0x03,
|
||||
0x54, 0x54, 0x4c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, 0x52, 0x03, 0x54, 0x54, 0x4c, 0x12, 0x14,
|
||||
0x0a, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x52,
|
||||
0x44, 0x61, 0x74, 0x61, 0x22, 0x7f, 0x0a, 0x0f, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76,
|
||||
0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x12, 0x38, 0x0a, 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53,
|
||||
0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d,
|
||||
0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65,
|
||||
0x72, 0x76, 0x65, 0x72, 0x52, 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72,
|
||||
0x73, 0x12, 0x18, 0x0a, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, 0x18, 0x02, 0x20, 0x01,
|
||||
0x28, 0x08, 0x52, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x44,
|
||||
0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, 0x6f,
|
||||
0x6d, 0x61, 0x69, 0x6e, 0x73, 0x22, 0x48, 0x0a, 0x0a, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72,
|
||||
0x76, 0x65, 0x72, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52,
|
||||
0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, 0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20,
|
||||
0x01, 0x28, 0x03, 0x52, 0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50,
|
||||
0x6f, 0x72, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x32,
|
||||
0xf7, 0x02, 0x0a, 0x11, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x65,
|
||||
0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x45, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1c,
|
||||
0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72,
|
||||
0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d,
|
||||
0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70,
|
||||
0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e,
|
||||
0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65,
|
||||
0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70,
|
||||
0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
|
||||
0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x46, 0x0a, 0x04,
|
||||
0x53, 0x79, 0x6e, 0x63, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e,
|
||||
0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61,
|
||||
0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e,
|
||||
0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65,
|
||||
0x22, 0x00, 0x30, 0x01, 0x12, 0x42, 0x0a, 0x0c, 0x47, 0x65, 0x74, 0x53, 0x65, 0x72, 0x76, 0x65,
|
||||
0x72, 0x4b, 0x65, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e,
|
||||
0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65,
|
||||
0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x52, 0x65,
|
||||
0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, 0x09, 0x69, 0x73, 0x48, 0x65,
|
||||
0x61, 0x6c, 0x74, 0x68, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65,
|
||||
0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67,
|
||||
0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x12, 0x5a, 0x0a,
|
||||
0x1a, 0x47, 0x65, 0x74, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72,
|
||||
0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61,
|
||||
0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74,
|
||||
0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61,
|
||||
0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64,
|
||||
0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72,
|
||||
0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -2214,7 +2124,7 @@ func file_management_proto_rawDescGZIP() []byte {
|
||||
}
|
||||
|
||||
var file_management_proto_enumTypes = make([]protoimpl.EnumInfo, 2)
|
||||
var file_management_proto_msgTypes = make([]protoimpl.MessageInfo, 26)
|
||||
var file_management_proto_msgTypes = make([]protoimpl.MessageInfo, 25)
|
||||
var file_management_proto_goTypes = []interface{}{
|
||||
(HostConfig_Protocol)(0), // 0: management.HostConfig.Protocol
|
||||
(DeviceAuthorizationFlowProvider)(0), // 1: management.DeviceAuthorizationFlow.provider
|
||||
@@ -2243,8 +2153,7 @@ var file_management_proto_goTypes = []interface{}{
|
||||
(*SimpleRecord)(nil), // 24: management.SimpleRecord
|
||||
(*NameServerGroup)(nil), // 25: management.NameServerGroup
|
||||
(*NameServer)(nil), // 26: management.NameServer
|
||||
(*FirewallRule)(nil), // 27: management.FirewallRule
|
||||
(*timestamppb.Timestamp)(nil), // 28: google.protobuf.Timestamp
|
||||
(*timestamppb.Timestamp)(nil), // 27: google.protobuf.Timestamp
|
||||
}
|
||||
var file_management_proto_depIdxs = []int32{
|
||||
11, // 0: management.SyncResponse.wiretrusteeConfig:type_name -> management.WiretrusteeConfig
|
||||
@@ -2255,7 +2164,7 @@ var file_management_proto_depIdxs = []int32{
|
||||
6, // 5: management.LoginRequest.peerKeys:type_name -> management.PeerKeys
|
||||
11, // 6: management.LoginResponse.wiretrusteeConfig:type_name -> management.WiretrusteeConfig
|
||||
14, // 7: management.LoginResponse.peerConfig:type_name -> management.PeerConfig
|
||||
28, // 8: management.ServerKeyResponse.expiresAt:type_name -> google.protobuf.Timestamp
|
||||
27, // 8: management.ServerKeyResponse.expiresAt:type_name -> google.protobuf.Timestamp
|
||||
12, // 9: management.WiretrusteeConfig.stuns:type_name -> management.HostConfig
|
||||
13, // 10: management.WiretrusteeConfig.turns:type_name -> management.ProtectedHostConfig
|
||||
12, // 11: management.WiretrusteeConfig.signal:type_name -> management.HostConfig
|
||||
@@ -2267,29 +2176,28 @@ var file_management_proto_depIdxs = []int32{
|
||||
21, // 17: management.NetworkMap.Routes:type_name -> management.Route
|
||||
22, // 18: management.NetworkMap.DNSConfig:type_name -> management.DNSConfig
|
||||
16, // 19: management.NetworkMap.offlinePeers:type_name -> management.RemotePeerConfig
|
||||
27, // 20: management.NetworkMap.FirewallRules:type_name -> management.FirewallRule
|
||||
17, // 21: management.RemotePeerConfig.sshConfig:type_name -> management.SSHConfig
|
||||
1, // 22: management.DeviceAuthorizationFlow.Provider:type_name -> management.DeviceAuthorizationFlow.provider
|
||||
20, // 23: management.DeviceAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig
|
||||
25, // 24: management.DNSConfig.NameServerGroups:type_name -> management.NameServerGroup
|
||||
23, // 25: management.DNSConfig.CustomZones:type_name -> management.CustomZone
|
||||
24, // 26: management.CustomZone.Records:type_name -> management.SimpleRecord
|
||||
26, // 27: management.NameServerGroup.NameServers:type_name -> management.NameServer
|
||||
2, // 28: management.ManagementService.Login:input_type -> management.EncryptedMessage
|
||||
2, // 29: management.ManagementService.Sync:input_type -> management.EncryptedMessage
|
||||
10, // 30: management.ManagementService.GetServerKey:input_type -> management.Empty
|
||||
10, // 31: management.ManagementService.isHealthy:input_type -> management.Empty
|
||||
2, // 32: management.ManagementService.GetDeviceAuthorizationFlow:input_type -> management.EncryptedMessage
|
||||
2, // 33: management.ManagementService.Login:output_type -> management.EncryptedMessage
|
||||
2, // 34: management.ManagementService.Sync:output_type -> management.EncryptedMessage
|
||||
9, // 35: management.ManagementService.GetServerKey:output_type -> management.ServerKeyResponse
|
||||
10, // 36: management.ManagementService.isHealthy:output_type -> management.Empty
|
||||
2, // 37: management.ManagementService.GetDeviceAuthorizationFlow:output_type -> management.EncryptedMessage
|
||||
33, // [33:38] is the sub-list for method output_type
|
||||
28, // [28:33] is the sub-list for method input_type
|
||||
28, // [28:28] is the sub-list for extension type_name
|
||||
28, // [28:28] is the sub-list for extension extendee
|
||||
0, // [0:28] is the sub-list for field type_name
|
||||
17, // 20: management.RemotePeerConfig.sshConfig:type_name -> management.SSHConfig
|
||||
1, // 21: management.DeviceAuthorizationFlow.Provider:type_name -> management.DeviceAuthorizationFlow.provider
|
||||
20, // 22: management.DeviceAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig
|
||||
25, // 23: management.DNSConfig.NameServerGroups:type_name -> management.NameServerGroup
|
||||
23, // 24: management.DNSConfig.CustomZones:type_name -> management.CustomZone
|
||||
24, // 25: management.CustomZone.Records:type_name -> management.SimpleRecord
|
||||
26, // 26: management.NameServerGroup.NameServers:type_name -> management.NameServer
|
||||
2, // 27: management.ManagementService.Login:input_type -> management.EncryptedMessage
|
||||
2, // 28: management.ManagementService.Sync:input_type -> management.EncryptedMessage
|
||||
10, // 29: management.ManagementService.GetServerKey:input_type -> management.Empty
|
||||
10, // 30: management.ManagementService.isHealthy:input_type -> management.Empty
|
||||
2, // 31: management.ManagementService.GetDeviceAuthorizationFlow:input_type -> management.EncryptedMessage
|
||||
2, // 32: management.ManagementService.Login:output_type -> management.EncryptedMessage
|
||||
2, // 33: management.ManagementService.Sync:output_type -> management.EncryptedMessage
|
||||
9, // 34: management.ManagementService.GetServerKey:output_type -> management.ServerKeyResponse
|
||||
10, // 35: management.ManagementService.isHealthy:output_type -> management.Empty
|
||||
2, // 36: management.ManagementService.GetDeviceAuthorizationFlow:output_type -> management.EncryptedMessage
|
||||
32, // [32:37] is the sub-list for method output_type
|
||||
27, // [27:32] is the sub-list for method input_type
|
||||
27, // [27:27] is the sub-list for extension type_name
|
||||
27, // [27:27] is the sub-list for extension extendee
|
||||
0, // [0:27] is the sub-list for field type_name
|
||||
}
|
||||
|
||||
func init() { file_management_proto_init() }
|
||||
@@ -2598,18 +2506,6 @@ func file_management_proto_init() {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_management_proto_msgTypes[25].Exporter = func(v interface{}, i int) interface{} {
|
||||
switch v := v.(*FirewallRule); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
type x struct{}
|
||||
out := protoimpl.TypeBuilder{
|
||||
@@ -2617,7 +2513,7 @@ func file_management_proto_init() {
|
||||
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
||||
RawDescriptor: file_management_proto_rawDesc,
|
||||
NumEnums: 2,
|
||||
NumMessages: 26,
|
||||
NumMessages: 25,
|
||||
NumExtensions: 0,
|
||||
NumServices: 1,
|
||||
},
|
||||
|
||||
@@ -186,9 +186,6 @@ message NetworkMap {
|
||||
|
||||
// RemotePeerConfig represents a list of remote peers that the receiver can connect to
|
||||
repeated RemotePeerConfig offlinePeers = 7;
|
||||
|
||||
// FirewallRule represents a list of firewall rules to be applied to peer
|
||||
repeated FirewallRule FirewallRules = 8;
|
||||
}
|
||||
|
||||
// RemotePeerConfig represents a configuration of a remote peer.
|
||||
@@ -249,6 +246,10 @@ message ProviderConfig {
|
||||
string DeviceAuthEndpoint = 5;
|
||||
// TokenEndpoint is an endpoint to request auth token.
|
||||
string TokenEndpoint = 6;
|
||||
// Scopes provides the scopes to be included in the token request
|
||||
string Scope = 7;
|
||||
// UseIDToken indicates if the id token should be used for authentication
|
||||
bool UseIDToken = 8;
|
||||
}
|
||||
|
||||
// Route represents a route.Route object
|
||||
@@ -296,14 +297,4 @@ message NameServer {
|
||||
string IP = 1;
|
||||
int64 NSType = 2;
|
||||
int64 Port = 3;
|
||||
}
|
||||
|
||||
// FirewallRule represents a firewall rule
|
||||
message FirewallRule {
|
||||
string PeerID = 1;
|
||||
string PeerIP = 2;
|
||||
string Direction = 3;
|
||||
string Action = 4;
|
||||
string Protocol = 5;
|
||||
string Port = 6;
|
||||
}
|
||||
}
|
||||
@@ -2,7 +2,10 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
b64 "encoding/base64"
|
||||
"fmt"
|
||||
"hash/crc32"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/netip"
|
||||
@@ -12,17 +15,19 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"codeberg.org/ac/base62"
|
||||
"github.com/eko/gocache/v3/cache"
|
||||
cacheStore "github.com/eko/gocache/v3/store"
|
||||
gocache "github.com/patrickmn/go-cache"
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
gocache "github.com/patrickmn/go-cache"
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -50,6 +55,8 @@ type AccountManager interface {
|
||||
GetSetupKey(accountID, userID, keyID string) (*SetupKey, error)
|
||||
GetAccountByUserOrAccountID(userID, accountID, domain string) (*Account, error)
|
||||
GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, *User, error)
|
||||
GetAccountFromPAT(pat string) (*Account, *User, *PersonalAccessToken, error)
|
||||
MarkPATUsed(tokenID string) error
|
||||
IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error)
|
||||
AccountExists(accountId string) (*bool, error)
|
||||
GetPeerByKey(peerKey string) (*Peer, error)
|
||||
@@ -61,6 +68,10 @@ type AccountManager interface {
|
||||
GetNetworkMap(peerID string) (*NetworkMap, error)
|
||||
GetPeerNetwork(peerID string) (*Network, error)
|
||||
AddPeer(setupKey, userID string, peer *Peer) (*Peer, *NetworkMap, error)
|
||||
CreatePAT(accountID string, executingUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error)
|
||||
DeletePAT(accountID string, executingUserID string, targetUserID string, tokenID string) error
|
||||
GetPAT(accountID string, executingUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error)
|
||||
GetAllPATs(accountID string, executingUserID string, targetUserID string) ([]*PersonalAccessToken, error)
|
||||
UpdatePeerSSHKey(peerID string, sshKey string) error
|
||||
GetUsersFromAccount(accountID, userID string) ([]*UserInfo, error)
|
||||
GetGroup(accountId, groupID string) (*Group, error)
|
||||
@@ -272,7 +283,7 @@ func (a *Account) GetGroup(groupID string) *Group {
|
||||
|
||||
// GetPeerNetworkMap returns a group by ID if exists, nil otherwise
|
||||
func (a *Account) GetPeerNetworkMap(peerID, dnsDomain string) *NetworkMap {
|
||||
aclPeers, firewallRules := a.getPeersByPolicy(peerID)
|
||||
aclPeers := a.getPeersByACL(peerID)
|
||||
// exclude expired peers
|
||||
var peersToConnect []*Peer
|
||||
var expiredPeers []*Peer
|
||||
@@ -303,12 +314,11 @@ func (a *Account) GetPeerNetworkMap(peerID, dnsDomain string) *NetworkMap {
|
||||
}
|
||||
|
||||
return &NetworkMap{
|
||||
Peers: peersToConnect,
|
||||
Network: a.Network.Copy(),
|
||||
Routes: routesUpdate,
|
||||
DNSConfig: dnsUpdate,
|
||||
OfflinePeers: expiredPeers,
|
||||
FirewallRules: firewallRules,
|
||||
Peers: peersToConnect,
|
||||
Network: a.Network.Copy(),
|
||||
Routes: routesUpdate,
|
||||
DNSConfig: dnsUpdate,
|
||||
OfflinePeers: expiredPeers,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -352,11 +362,11 @@ func (a *Account) GetNextPeerExpiration() (time.Duration, bool) {
|
||||
return *nextExpiry, true
|
||||
}
|
||||
|
||||
// GetPeersWithExpiration returns a list of peers that have Peer.LoginExpirationEnabled set to true
|
||||
// GetPeersWithExpiration returns a list of peers that have Peer.LoginExpirationEnabled set to true and that were added by a user
|
||||
func (a *Account) GetPeersWithExpiration() []*Peer {
|
||||
peers := make([]*Peer, 0)
|
||||
for _, peer := range a.Peers {
|
||||
if peer.LoginExpirationEnabled {
|
||||
if peer.LoginExpirationEnabled && peer.AddedWithSSOLogin() {
|
||||
peers = append(peers, peer)
|
||||
}
|
||||
}
|
||||
@@ -1113,6 +1123,87 @@ func (am *DefaultAccountManager) redeemInvite(account *Account, userID string) e
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarkPATUsed marks a personal access token as used
|
||||
func (am *DefaultAccountManager) MarkPATUsed(tokenID string) error {
|
||||
unlock := am.Store.AcquireGlobalLock()
|
||||
|
||||
user, err := am.Store.GetUserByTokenID(tokenID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
account, err := am.Store.GetAccountByUser(user.Id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
unlock()
|
||||
unlock = am.Store.AcquireAccountLock(account.Id)
|
||||
defer unlock()
|
||||
|
||||
account, err = am.Store.GetAccountByUser(user.Id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
pat, ok := account.Users[user.Id].PATs[tokenID]
|
||||
if !ok {
|
||||
return fmt.Errorf("token not found")
|
||||
}
|
||||
|
||||
pat.LastUsed = time.Now().UTC()
|
||||
|
||||
return am.Store.SaveAccount(account)
|
||||
}
|
||||
|
||||
// GetAccountFromPAT returns Account and User associated with a personal access token
|
||||
func (am *DefaultAccountManager) GetAccountFromPAT(token string) (*Account, *User, *PersonalAccessToken, error) {
|
||||
if len(token) != PATLength {
|
||||
return nil, nil, nil, fmt.Errorf("token has wrong length")
|
||||
}
|
||||
|
||||
prefix := token[:len(PATPrefix)]
|
||||
if prefix != PATPrefix {
|
||||
return nil, nil, nil, fmt.Errorf("token has wrong prefix")
|
||||
}
|
||||
secret := token[len(PATPrefix) : len(PATPrefix)+PATSecretLength]
|
||||
encodedChecksum := token[len(PATPrefix)+PATSecretLength : len(PATPrefix)+PATSecretLength+PATChecksumLength]
|
||||
|
||||
verificationChecksum, err := base62.Decode(encodedChecksum)
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("token checksum decoding failed: %w", err)
|
||||
}
|
||||
|
||||
secretChecksum := crc32.ChecksumIEEE([]byte(secret))
|
||||
if secretChecksum != verificationChecksum {
|
||||
return nil, nil, nil, fmt.Errorf("token checksum does not match")
|
||||
}
|
||||
|
||||
hashedToken := sha256.Sum256([]byte(token))
|
||||
encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:])
|
||||
tokenID, err := am.Store.GetTokenIDByHashedToken(encodedHashedToken)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
user, err := am.Store.GetUserByTokenID(tokenID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
account, err := am.Store.GetAccountByUser(user.Id)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
pat := user.PATs[tokenID]
|
||||
if pat == nil {
|
||||
return nil, nil, nil, fmt.Errorf("personal access token not found")
|
||||
}
|
||||
|
||||
return account, user, pat, nil
|
||||
}
|
||||
|
||||
// GetAccountFromToken returns an account associated with this token
|
||||
func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, *User, error) {
|
||||
if claims.UserId == "" {
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
b64 "encoding/base64"
|
||||
"fmt"
|
||||
"net"
|
||||
"reflect"
|
||||
@@ -125,12 +127,12 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
|
||||
Name: peerID1,
|
||||
DNSLabel: peerID1,
|
||||
Status: &PeerStatus{
|
||||
LastSeen: time.Now(),
|
||||
LastSeen: time.Now().UTC(),
|
||||
Connected: false,
|
||||
LoginExpired: true,
|
||||
},
|
||||
UserID: userID,
|
||||
LastLogin: time.Now().Add(-time.Hour * 24 * 30 * 30),
|
||||
LastLogin: time.Now().UTC().Add(-time.Hour * 24 * 30 * 30),
|
||||
},
|
||||
"peer-2": {
|
||||
ID: peerID2,
|
||||
@@ -139,12 +141,12 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
|
||||
Name: peerID2,
|
||||
DNSLabel: peerID2,
|
||||
Status: &PeerStatus{
|
||||
LastSeen: time.Now(),
|
||||
LastSeen: time.Now().UTC(),
|
||||
Connected: false,
|
||||
LoginExpired: false,
|
||||
},
|
||||
UserID: userID,
|
||||
LastLogin: time.Now(),
|
||||
LastLogin: time.Now().UTC(),
|
||||
LoginExpirationEnabled: true,
|
||||
},
|
||||
},
|
||||
@@ -163,12 +165,12 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
|
||||
Name: peerID1,
|
||||
DNSLabel: peerID1,
|
||||
Status: &PeerStatus{
|
||||
LastSeen: time.Now(),
|
||||
LastSeen: time.Now().UTC(),
|
||||
Connected: false,
|
||||
LoginExpired: true,
|
||||
},
|
||||
UserID: userID,
|
||||
LastLogin: time.Now().Add(-time.Hour * 24 * 30 * 30),
|
||||
LastLogin: time.Now().UTC().Add(-time.Hour * 24 * 30 * 30),
|
||||
LoginExpirationEnabled: true,
|
||||
},
|
||||
"peer-2": {
|
||||
@@ -178,12 +180,12 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
|
||||
Name: peerID2,
|
||||
DNSLabel: peerID2,
|
||||
Status: &PeerStatus{
|
||||
LastSeen: time.Now(),
|
||||
LastSeen: time.Now().UTC(),
|
||||
Connected: false,
|
||||
LoginExpired: true,
|
||||
},
|
||||
UserID: userID,
|
||||
LastLogin: time.Now().Add(-time.Hour * 24 * 30 * 30),
|
||||
LastLogin: time.Now().UTC().Add(-time.Hour * 24 * 30 * 30),
|
||||
LoginExpirationEnabled: true,
|
||||
},
|
||||
},
|
||||
@@ -458,6 +460,79 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountManager_GetAccountFromPAT(t *testing.T) {
|
||||
store := newStore(t)
|
||||
account := newAccountWithId("account_id", "testuser", "")
|
||||
|
||||
token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W"
|
||||
hashedToken := sha256.Sum256([]byte(token))
|
||||
encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:])
|
||||
account.Users["someUser"] = &User{
|
||||
Id: "someUser",
|
||||
PATs: map[string]*PersonalAccessToken{
|
||||
"tokenId": {
|
||||
ID: "tokenId",
|
||||
HashedToken: encodedHashedToken,
|
||||
},
|
||||
},
|
||||
}
|
||||
err := store.SaveAccount(account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
}
|
||||
|
||||
account, user, pat, err := am.GetAccountFromPAT(token)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when getting Account from PAT: %s", err)
|
||||
}
|
||||
|
||||
assert.Equal(t, "account_id", account.Id)
|
||||
assert.Equal(t, "someUser", user.Id)
|
||||
assert.Equal(t, account.Users["someUser"].PATs["tokenId"], pat)
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_MarkPATUsed(t *testing.T) {
|
||||
store := newStore(t)
|
||||
account := newAccountWithId("account_id", "testuser", "")
|
||||
|
||||
token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W"
|
||||
hashedToken := sha256.Sum256([]byte(token))
|
||||
encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:])
|
||||
account.Users["someUser"] = &User{
|
||||
Id: "someUser",
|
||||
PATs: map[string]*PersonalAccessToken{
|
||||
"tokenId": {
|
||||
ID: "tokenId",
|
||||
HashedToken: encodedHashedToken,
|
||||
LastUsed: time.Time{},
|
||||
},
|
||||
},
|
||||
}
|
||||
err := store.SaveAccount(account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
}
|
||||
|
||||
err = am.MarkPATUsed("tokenId")
|
||||
if err != nil {
|
||||
t.Fatalf("Error when marking PAT used: %s", err)
|
||||
}
|
||||
|
||||
account, err = am.Store.GetAccount("account_id")
|
||||
if err != nil {
|
||||
t.Fatalf("Error when getting account: %s", err)
|
||||
}
|
||||
assert.True(t, !account.Users["someUser"].PATs["tokenId"].LastUsed.IsZero())
|
||||
}
|
||||
|
||||
func TestAccountManager_PrivateAccount(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
if err != nil {
|
||||
@@ -1208,15 +1283,15 @@ func TestAccount_Copy(t *testing.T) {
|
||||
Id: "user1",
|
||||
Role: UserRoleAdmin,
|
||||
AutoGroups: []string{"group1"},
|
||||
PATs: []PersonalAccessToken{
|
||||
{
|
||||
PATs: map[string]*PersonalAccessToken{
|
||||
"pat1": {
|
||||
ID: "pat1",
|
||||
Description: "First PAT",
|
||||
Name: "First PAT",
|
||||
HashedToken: "SoMeHaShEdToKeN",
|
||||
ExpirationDate: time.Now().AddDate(0, 0, 7),
|
||||
ExpirationDate: time.Now().UTC().AddDate(0, 0, 7),
|
||||
CreatedBy: "user1",
|
||||
CreatedAt: time.Now(),
|
||||
LastUsed: time.Now(),
|
||||
CreatedAt: time.Now().UTC(),
|
||||
LastUsed: time.Now().UTC(),
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -1494,22 +1569,22 @@ func TestAccount_GetExpiredPeers(t *testing.T) {
|
||||
ID: "peer-1",
|
||||
LoginExpirationEnabled: true,
|
||||
Status: &PeerStatus{
|
||||
LastSeen: time.Now(),
|
||||
LastSeen: time.Now().UTC(),
|
||||
Connected: true,
|
||||
LoginExpired: false,
|
||||
},
|
||||
LastLogin: time.Now().Add(-30 * time.Minute),
|
||||
LastLogin: time.Now().UTC().Add(-30 * time.Minute),
|
||||
UserID: userID,
|
||||
},
|
||||
"peer-2": {
|
||||
ID: "peer-2",
|
||||
LoginExpirationEnabled: true,
|
||||
Status: &PeerStatus{
|
||||
LastSeen: time.Now(),
|
||||
LastSeen: time.Now().UTC(),
|
||||
Connected: true,
|
||||
LoginExpired: false,
|
||||
},
|
||||
LastLogin: time.Now().Add(-2 * time.Hour),
|
||||
LastLogin: time.Now().UTC().Add(-2 * time.Hour),
|
||||
UserID: userID,
|
||||
},
|
||||
|
||||
@@ -1517,11 +1592,11 @@ func TestAccount_GetExpiredPeers(t *testing.T) {
|
||||
ID: "peer-3",
|
||||
LoginExpirationEnabled: true,
|
||||
Status: &PeerStatus{
|
||||
LastSeen: time.Now(),
|
||||
LastSeen: time.Now().UTC(),
|
||||
Connected: true,
|
||||
LoginExpired: false,
|
||||
},
|
||||
LastLogin: time.Now().Add(-1 * time.Hour),
|
||||
LastLogin: time.Now().UTC().Add(-1 * time.Hour),
|
||||
UserID: userID,
|
||||
},
|
||||
},
|
||||
@@ -1571,9 +1646,11 @@ func TestAccount_GetPeersWithExpiration(t *testing.T) {
|
||||
peers: map[string]*Peer{
|
||||
"peer-1": {
|
||||
LoginExpirationEnabled: false,
|
||||
UserID: userID,
|
||||
},
|
||||
"peer-2": {
|
||||
LoginExpirationEnabled: false,
|
||||
UserID: userID,
|
||||
},
|
||||
},
|
||||
expectedPeers: map[string]struct{}{},
|
||||
@@ -1584,9 +1661,11 @@ func TestAccount_GetPeersWithExpiration(t *testing.T) {
|
||||
"peer-1": {
|
||||
ID: "peer-1",
|
||||
LoginExpirationEnabled: true,
|
||||
UserID: userID,
|
||||
},
|
||||
"peer-2": {
|
||||
LoginExpirationEnabled: false,
|
||||
UserID: userID,
|
||||
},
|
||||
},
|
||||
expectedPeers: map[string]struct{}{
|
||||
@@ -1646,12 +1725,14 @@ func TestAccount_GetNextPeerExpiration(t *testing.T) {
|
||||
Connected: false,
|
||||
},
|
||||
LoginExpirationEnabled: true,
|
||||
UserID: userID,
|
||||
},
|
||||
"peer-2": {
|
||||
Status: &PeerStatus{
|
||||
Connected: true,
|
||||
},
|
||||
LoginExpirationEnabled: false,
|
||||
UserID: userID,
|
||||
},
|
||||
},
|
||||
expiration: time.Second,
|
||||
@@ -1667,12 +1748,14 @@ func TestAccount_GetNextPeerExpiration(t *testing.T) {
|
||||
Connected: true,
|
||||
},
|
||||
LoginExpirationEnabled: false,
|
||||
UserID: userID,
|
||||
},
|
||||
"peer-2": {
|
||||
Status: &PeerStatus{
|
||||
Connected: true,
|
||||
},
|
||||
LoginExpirationEnabled: false,
|
||||
UserID: userID,
|
||||
},
|
||||
},
|
||||
expiration: time.Second,
|
||||
@@ -1689,6 +1772,7 @@ func TestAccount_GetNextPeerExpiration(t *testing.T) {
|
||||
LoginExpired: true,
|
||||
},
|
||||
LoginExpirationEnabled: true,
|
||||
UserID: userID,
|
||||
},
|
||||
"peer-2": {
|
||||
Status: &PeerStatus{
|
||||
@@ -1696,6 +1780,7 @@ func TestAccount_GetNextPeerExpiration(t *testing.T) {
|
||||
LoginExpired: true,
|
||||
},
|
||||
LoginExpirationEnabled: true,
|
||||
UserID: userID,
|
||||
},
|
||||
},
|
||||
expiration: time.Second,
|
||||
@@ -1712,7 +1797,8 @@ func TestAccount_GetNextPeerExpiration(t *testing.T) {
|
||||
LoginExpired: false,
|
||||
},
|
||||
LoginExpirationEnabled: true,
|
||||
LastLogin: time.Now(),
|
||||
LastLogin: time.Now().UTC(),
|
||||
UserID: userID,
|
||||
},
|
||||
"peer-2": {
|
||||
Status: &PeerStatus{
|
||||
@@ -1720,6 +1806,7 @@ func TestAccount_GetNextPeerExpiration(t *testing.T) {
|
||||
LoginExpired: true,
|
||||
},
|
||||
LoginExpirationEnabled: true,
|
||||
UserID: userID,
|
||||
},
|
||||
},
|
||||
expiration: time.Minute,
|
||||
@@ -1727,6 +1814,31 @@ func TestAccount_GetNextPeerExpiration(t *testing.T) {
|
||||
expectedNextRun: true,
|
||||
expectedNextExpiration: expectedNextExpiration,
|
||||
},
|
||||
{
|
||||
name: "Peers added with setup keys, no expiration",
|
||||
peers: map[string]*Peer{
|
||||
"peer-1": {
|
||||
Status: &PeerStatus{
|
||||
Connected: true,
|
||||
LoginExpired: false,
|
||||
},
|
||||
LoginExpirationEnabled: true,
|
||||
SetupKey: "key",
|
||||
},
|
||||
"peer-2": {
|
||||
Status: &PeerStatus{
|
||||
Connected: true,
|
||||
LoginExpired: false,
|
||||
},
|
||||
LoginExpirationEnabled: true,
|
||||
SetupKey: "key",
|
||||
},
|
||||
},
|
||||
expiration: time.Second,
|
||||
expirationEnabled: false,
|
||||
expectedNextRun: false,
|
||||
expectedNextExpiration: time.Duration(0),
|
||||
},
|
||||
}
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
|
||||
@@ -83,6 +83,10 @@ const (
|
||||
AccountPeerLoginExpirationDisabled
|
||||
// AccountPeerLoginExpirationDurationUpdated indicates that a user updated peer login expiration duration for the account
|
||||
AccountPeerLoginExpirationDurationUpdated
|
||||
// PersonalAccessTokenCreated indicates that a user created a personal access token
|
||||
PersonalAccessTokenCreated
|
||||
// PersonalAccessTokenDeleted indicates that a user deleted a personal access token
|
||||
PersonalAccessTokenDeleted
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -168,6 +172,10 @@ const (
|
||||
AccountPeerLoginExpirationDisabledMessage string = "Peer login expiration disabled for the account"
|
||||
// AccountPeerLoginExpirationDurationUpdatedMessage is a human-readable text message of the AccountPeerLoginExpirationDurationUpdated activity
|
||||
AccountPeerLoginExpirationDurationUpdatedMessage string = "Peer login expiration duration updated"
|
||||
// PersonalAccessTokenCreatedMessage is a human-readable text message of the PersonalAccessTokenCreated activity
|
||||
PersonalAccessTokenCreatedMessage string = "Personal access token created"
|
||||
// PersonalAccessTokenDeletedMessage is a human-readable text message of the PersonalAccessTokenDeleted activity
|
||||
PersonalAccessTokenDeletedMessage string = "Personal access token deleted"
|
||||
)
|
||||
|
||||
// Activity that triggered an Event
|
||||
@@ -258,6 +266,10 @@ func (a Activity) Message() string {
|
||||
return AccountPeerLoginExpirationDisabledMessage
|
||||
case AccountPeerLoginExpirationDurationUpdated:
|
||||
return AccountPeerLoginExpirationDurationUpdatedMessage
|
||||
case PersonalAccessTokenCreated:
|
||||
return PersonalAccessTokenCreatedMessage
|
||||
case PersonalAccessTokenDeleted:
|
||||
return PersonalAccessTokenDeletedMessage
|
||||
default:
|
||||
return "UNKNOWN_ACTIVITY"
|
||||
}
|
||||
@@ -348,6 +360,10 @@ func (a Activity) StringCode() string {
|
||||
return "account.setting.peer.login.expiration.enable"
|
||||
case AccountPeerLoginExpirationDisabled:
|
||||
return "account.setting.peer.login.expiration.disable"
|
||||
case PersonalAccessTokenCreated:
|
||||
return "personal.access.token.create"
|
||||
case PersonalAccessTokenDeleted:
|
||||
return "personal.access.token.delete"
|
||||
default:
|
||||
return "UNKNOWN_ACTIVITY"
|
||||
}
|
||||
|
||||
@@ -2,10 +2,12 @@ package sqlite
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
)
|
||||
|
||||
func TestNewSQLiteStore(t *testing.T) {
|
||||
@@ -15,13 +17,13 @@ func TestNewSQLiteStore(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
defer store.Close() //nolint
|
||||
defer store.Close() //nolint
|
||||
|
||||
accountID := "account_1"
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
_, err = store.Save(&activity.Event{
|
||||
Timestamp: time.Now(),
|
||||
Timestamp: time.Now().UTC(),
|
||||
Activity: activity.PeerAddedByUser,
|
||||
InitiatorID: "user_" + fmt.Sprint(i),
|
||||
TargetID: "peer_" + fmt.Sprint(i),
|
||||
|
||||
@@ -24,6 +24,11 @@ const (
|
||||
NONE Provider = "none"
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultDeviceAuthFlowScope defines the bare minimum scope to request in the device authorization flow
|
||||
DefaultDeviceAuthFlowScope string = "openid"
|
||||
)
|
||||
|
||||
// Config of the Management service
|
||||
type Config struct {
|
||||
Stuns []*Host
|
||||
@@ -39,6 +44,17 @@ type Config struct {
|
||||
DeviceAuthorizationFlow *DeviceAuthorizationFlow
|
||||
}
|
||||
|
||||
// GetAuthAudiences returns the audience from the http config and device authorization flow config
|
||||
func (c Config) GetAuthAudiences() []string {
|
||||
audiences := []string{c.HttpConfig.AuthAudience}
|
||||
|
||||
if c.DeviceAuthorizationFlow != nil && c.DeviceAuthorizationFlow.ProviderConfig.Audience != "" {
|
||||
audiences = append(audiences, c.DeviceAuthorizationFlow.ProviderConfig.Audience)
|
||||
}
|
||||
|
||||
return audiences
|
||||
}
|
||||
|
||||
// TURNConfig is a config of the TURNCredentialsManager
|
||||
type TURNConfig struct {
|
||||
TimeBasedCredentials bool
|
||||
@@ -98,6 +114,10 @@ type ProviderConfig struct {
|
||||
TokenEndpoint string
|
||||
// DeviceAuthEndpoint is the endpoint of an IDP manager where clients can obtain device authorization code
|
||||
DeviceAuthEndpoint string
|
||||
// Scopes provides the scopes to be included in the token request
|
||||
Scope string
|
||||
// UseIDToken indicates if the id token should be used for authentication
|
||||
UseIDToken bool
|
||||
}
|
||||
|
||||
// validateURL validates input http url
|
||||
|
||||
@@ -2,9 +2,11 @@ package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
)
|
||||
|
||||
// GetEvents returns a list of activity events of an account
|
||||
@@ -39,7 +41,7 @@ func (am *DefaultAccountManager) storeEvent(initiatorID, targetID, accountID str
|
||||
|
||||
go func() {
|
||||
_, err := am.eventStore.Save(&activity.Event{
|
||||
Timestamp: time.Now(),
|
||||
Timestamp: time.Now().UTC(),
|
||||
Activity: activityID,
|
||||
InitiatorID: initiatorID,
|
||||
TargetID: targetID,
|
||||
@@ -47,7 +49,7 @@ func (am *DefaultAccountManager) storeEvent(initiatorID, targetID, accountID str
|
||||
Meta: meta,
|
||||
})
|
||||
if err != nil {
|
||||
//todo add metric
|
||||
// todo add metric
|
||||
log.Errorf("received an error while storing an activity event, error: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -1,17 +1,19 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
)
|
||||
|
||||
func generateAndStoreEvents(t *testing.T, manager *DefaultAccountManager, typ activity.Activity, initiatorID, targetID,
|
||||
accountID string, count int) {
|
||||
for i := 0; i < count; i++ {
|
||||
_, err := manager.eventStore.Save(&activity.Event{
|
||||
Timestamp: time.Now(),
|
||||
Timestamp: time.Now().UTC(),
|
||||
Activity: typ,
|
||||
InitiatorID: initiatorID,
|
||||
TargetID: targetID,
|
||||
|
||||
@@ -7,10 +7,11 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
@@ -25,6 +26,8 @@ type FileStore struct {
|
||||
PeerID2AccountID map[string]string `json:"-"`
|
||||
UserID2AccountID map[string]string `json:"-"`
|
||||
PrivateDomain2AccountID map[string]string `json:"-"`
|
||||
HashedPAT2TokenID map[string]string `json:"-"`
|
||||
TokenID2UserID map[string]string `json:"-"`
|
||||
InstallationID string
|
||||
|
||||
// mutex to synchronise Store read/write operations
|
||||
@@ -57,6 +60,8 @@ func restore(file string) (*FileStore, error) {
|
||||
UserID2AccountID: make(map[string]string),
|
||||
PrivateDomain2AccountID: make(map[string]string),
|
||||
PeerID2AccountID: make(map[string]string),
|
||||
HashedPAT2TokenID: make(map[string]string),
|
||||
TokenID2UserID: make(map[string]string),
|
||||
storeFile: file,
|
||||
}
|
||||
|
||||
@@ -80,6 +85,8 @@ func restore(file string) (*FileStore, error) {
|
||||
store.UserID2AccountID = make(map[string]string)
|
||||
store.PrivateDomain2AccountID = make(map[string]string)
|
||||
store.PeerID2AccountID = make(map[string]string)
|
||||
store.HashedPAT2TokenID = make(map[string]string)
|
||||
store.TokenID2UserID = make(map[string]string)
|
||||
|
||||
for accountID, account := range store.Accounts {
|
||||
if account.Settings == nil {
|
||||
@@ -103,9 +110,10 @@ func restore(file string) (*FileStore, error) {
|
||||
}
|
||||
for _, user := range account.Users {
|
||||
store.UserID2AccountID[user.Id] = accountID
|
||||
}
|
||||
for _, user := range account.Users {
|
||||
store.UserID2AccountID[user.Id] = accountID
|
||||
for _, pat := range user.PATs {
|
||||
store.TokenID2UserID[pat.ID] = user.Id
|
||||
store.HashedPAT2TokenID[pat.HashedToken] = pat.ID
|
||||
}
|
||||
}
|
||||
|
||||
if account.Domain != "" && account.DomainCategory == PrivateCategory &&
|
||||
@@ -113,15 +121,25 @@ func restore(file string) (*FileStore, error) {
|
||||
store.PrivateDomain2AccountID[account.Domain] = accountID
|
||||
}
|
||||
|
||||
// if no policies are defined, that means we need to migrate Rules to policies
|
||||
if len(account.Policies) == 0 {
|
||||
// TODO: policy query generated from the Go template and rule object.
|
||||
// We need to refactor this part to avoid using templating for policies queries building
|
||||
// and drop this migration part.
|
||||
policies := make(map[string]int, len(account.Policies))
|
||||
for i, policy := range account.Policies {
|
||||
policies[policy.ID] = i
|
||||
}
|
||||
if account.Policies == nil {
|
||||
account.Policies = make([]*Policy, 0)
|
||||
for _, rule := range account.Rules {
|
||||
policy, err := RuleToPolicy(rule)
|
||||
if err != nil {
|
||||
log.Errorf("unable to migrate rule to policy: %v", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
for _, rule := range account.Rules {
|
||||
policy, err := RuleToPolicy(rule)
|
||||
if err != nil {
|
||||
log.Errorf("unable to migrate rule to policy: %v", err)
|
||||
continue
|
||||
}
|
||||
if i, ok := policies[policy.ID]; ok {
|
||||
account.Policies[i] = policy
|
||||
} else {
|
||||
account.Policies = append(account.Policies, policy)
|
||||
}
|
||||
}
|
||||
@@ -155,7 +173,7 @@ func restore(file string) (*FileStore, error) {
|
||||
for key, peer := range account.Peers {
|
||||
// set LastLogin for the peers that were onboarded before the peer login expiration feature
|
||||
if peer.LastLogin.IsZero() {
|
||||
peer.LastLogin = time.Now()
|
||||
peer.LastLogin = time.Now().UTC()
|
||||
}
|
||||
if peer.ID != "" {
|
||||
continue
|
||||
@@ -258,15 +276,17 @@ func (s *FileStore) SaveAccount(account *Account) error {
|
||||
|
||||
for _, user := range accountCopy.Users {
|
||||
s.UserID2AccountID[user.Id] = accountCopy.Id
|
||||
for _, pat := range user.PATs {
|
||||
s.TokenID2UserID[pat.ID] = user.Id
|
||||
s.HashedPAT2TokenID[pat.HashedToken] = pat.ID
|
||||
}
|
||||
}
|
||||
|
||||
if accountCopy.DomainCategory == PrivateCategory && accountCopy.IsDomainPrimaryAccount {
|
||||
s.PrivateDomain2AccountID[accountCopy.Domain] = accountCopy.Id
|
||||
}
|
||||
|
||||
if accountCopy.Rules == nil {
|
||||
accountCopy.Rules = make(map[string]*Rule)
|
||||
}
|
||||
accountCopy.Rules = make(map[string]*Rule)
|
||||
for _, policy := range accountCopy.Policies {
|
||||
for _, rule := range policy.Rules {
|
||||
accountCopy.Rules[rule.ID] = rule.ToRule()
|
||||
@@ -276,13 +296,33 @@ func (s *FileStore) SaveAccount(account *Account) error {
|
||||
return s.persist(s.storeFile)
|
||||
}
|
||||
|
||||
// DeleteHashedPAT2TokenIDIndex removes an entry from the indexing map HashedPAT2TokenID
|
||||
func (s *FileStore) DeleteHashedPAT2TokenIDIndex(hashedToken string) error {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
delete(s.HashedPAT2TokenID, hashedToken)
|
||||
|
||||
return s.persist(s.storeFile)
|
||||
}
|
||||
|
||||
// DeleteTokenID2UserIDIndex removes an entry from the indexing map TokenID2UserID
|
||||
func (s *FileStore) DeleteTokenID2UserIDIndex(tokenID string) error {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
delete(s.TokenID2UserID, tokenID)
|
||||
|
||||
return s.persist(s.storeFile)
|
||||
}
|
||||
|
||||
// GetAccountByPrivateDomain returns account by private domain
|
||||
func (s *FileStore) GetAccountByPrivateDomain(domain string) (*Account, error) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
accountID, accountIDFound := s.PrivateDomain2AccountID[strings.ToLower(domain)]
|
||||
if !accountIDFound {
|
||||
accountID, ok := s.PrivateDomain2AccountID[strings.ToLower(domain)]
|
||||
if !ok {
|
||||
return nil, status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private")
|
||||
}
|
||||
|
||||
@@ -299,8 +339,8 @@ func (s *FileStore) GetAccountBySetupKey(setupKey string) (*Account, error) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
accountID, accountIDFound := s.SetupKeyID2AccountID[strings.ToUpper(setupKey)]
|
||||
if !accountIDFound {
|
||||
accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKey)]
|
||||
if !ok {
|
||||
return nil, status.Errorf(status.NotFound, "account not found: provided setup key doesn't exists")
|
||||
}
|
||||
|
||||
@@ -312,6 +352,42 @@ func (s *FileStore) GetAccountBySetupKey(setupKey string) (*Account, error) {
|
||||
return account.Copy(), nil
|
||||
}
|
||||
|
||||
// GetTokenIDByHashedToken returns the id of a personal access token by its hashed secret
|
||||
func (s *FileStore) GetTokenIDByHashedToken(token string) (string, error) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
tokenID, ok := s.HashedPAT2TokenID[token]
|
||||
if !ok {
|
||||
return "", status.Errorf(status.NotFound, "tokenID not found: provided token doesn't exists")
|
||||
}
|
||||
|
||||
return tokenID, nil
|
||||
}
|
||||
|
||||
// GetUserByTokenID returns a User object a tokenID belongs to
|
||||
func (s *FileStore) GetUserByTokenID(tokenID string) (*User, error) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
userID, ok := s.TokenID2UserID[tokenID]
|
||||
if !ok {
|
||||
return nil, status.Errorf(status.NotFound, "user not found: provided tokenID doesn't exists")
|
||||
}
|
||||
|
||||
accountID, ok := s.UserID2AccountID[userID]
|
||||
if !ok {
|
||||
return nil, status.Errorf(status.NotFound, "accountID not found: provided userID doesn't exists")
|
||||
}
|
||||
|
||||
account, err := s.getAccount(accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return account.Users[userID].Copy(), nil
|
||||
}
|
||||
|
||||
// GetAllAccounts returns all accounts
|
||||
func (s *FileStore) GetAllAccounts() (all []*Account) {
|
||||
s.mux.Lock()
|
||||
@@ -325,8 +401,8 @@ func (s *FileStore) GetAllAccounts() (all []*Account) {
|
||||
|
||||
// getAccount returns a reference to the Account. Should not return a copy.
|
||||
func (s *FileStore) getAccount(accountID string) (*Account, error) {
|
||||
account, accountFound := s.Accounts[accountID]
|
||||
if !accountFound {
|
||||
account, ok := s.Accounts[accountID]
|
||||
if !ok {
|
||||
return nil, status.Errorf(status.NotFound, "account not found")
|
||||
}
|
||||
|
||||
@@ -351,8 +427,8 @@ func (s *FileStore) GetAccountByUser(userID string) (*Account, error) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
accountID, accountIDFound := s.UserID2AccountID[userID]
|
||||
if !accountIDFound {
|
||||
accountID, ok := s.UserID2AccountID[userID]
|
||||
if !ok {
|
||||
return nil, status.Errorf(status.NotFound, "account not found")
|
||||
}
|
||||
|
||||
@@ -369,8 +445,8 @@ func (s *FileStore) GetAccountByPeerID(peerID string) (*Account, error) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
accountID, accountIDFound := s.PeerID2AccountID[peerID]
|
||||
if !accountIDFound {
|
||||
accountID, ok := s.PeerID2AccountID[peerID]
|
||||
if !ok {
|
||||
return nil, status.Errorf(status.NotFound, "provided peer ID doesn't exists %s", peerID)
|
||||
}
|
||||
|
||||
@@ -395,8 +471,8 @@ func (s *FileStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
accountID, accountIDFound := s.PeerKeyID2AccountID[peerKey]
|
||||
if !accountIDFound {
|
||||
accountID, ok := s.PeerKeyID2AccountID[peerKey]
|
||||
if !ok {
|
||||
return nil, status.Errorf(status.NotFound, "provided peer key doesn't exists %s", peerKey)
|
||||
}
|
||||
|
||||
|
||||
@@ -1,14 +1,16 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"net"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/util"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
type accounts struct {
|
||||
@@ -71,6 +73,14 @@ func TestNewStore(t *testing.T) {
|
||||
if store.UserID2AccountID == nil || len(store.UserID2AccountID) != 0 {
|
||||
t.Errorf("expected to create a new empty UserID2AccountID map when creating a new FileStore")
|
||||
}
|
||||
|
||||
if store.HashedPAT2TokenID == nil || len(store.HashedPAT2TokenID) != 0 {
|
||||
t.Errorf("expected to create a new empty HashedPAT2TokenID map when creating a new FileStore")
|
||||
}
|
||||
|
||||
if store.TokenID2UserID == nil || len(store.TokenID2UserID) != 0 {
|
||||
t.Errorf("expected to create a new empty TokenID2UserID map when creating a new FileStore")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveAccount(t *testing.T) {
|
||||
@@ -85,7 +95,7 @@ func TestSaveAccount(t *testing.T) {
|
||||
IP: net.IP{127, 0, 0, 1},
|
||||
Meta: PeerSystemMeta{},
|
||||
Name: "peer name",
|
||||
Status: &PeerStatus{Connected: true, LastSeen: time.Now()},
|
||||
Status: &PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
|
||||
}
|
||||
|
||||
// SaveAccount should trigger persist
|
||||
@@ -121,7 +131,7 @@ func TestStore(t *testing.T) {
|
||||
IP: net.IP{127, 0, 0, 1},
|
||||
Meta: PeerSystemMeta{},
|
||||
Name: "peer name",
|
||||
Status: &PeerStatus{Connected: true, LastSeen: time.Now()},
|
||||
Status: &PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
|
||||
}
|
||||
account.Groups["all"] = &Group{
|
||||
ID: "all",
|
||||
@@ -239,11 +249,17 @@ func TestRestore(t *testing.T) {
|
||||
|
||||
require.NotNil(t, account.SetupKeys["A2C8E62B-38F5-4553-B31E-DD66C696CEBB"], "failed to restore a FileStore file - missing Account SetupKey A2C8E62B-38F5-4553-B31E-DD66C696CEBB")
|
||||
|
||||
require.NotNil(t, account.Users["f4f6d672-63fb-11ec-90d6-0242ac120003"].PATs["9dj38s35-63fb-11ec-90d6-0242ac120003"], "failed to restore a FileStore wrong PATs length")
|
||||
|
||||
require.Len(t, store.UserID2AccountID, 2, "failed to restore a FileStore wrong UserID2AccountID mapping length")
|
||||
|
||||
require.Len(t, store.SetupKeyID2AccountID, 1, "failed to restore a FileStore wrong SetupKeyID2AccountID mapping length")
|
||||
|
||||
require.Len(t, store.PrivateDomain2AccountID, 1, "failed to restore a FileStore wrong PrivateDomain2AccountID mapping length")
|
||||
|
||||
require.Len(t, store.HashedPAT2TokenID, 1, "failed to restore a FileStore wrong HashedPAT2TokenID mapping length")
|
||||
|
||||
require.Len(t, store.TokenID2UserID, 1, "failed to restore a FileStore wrong TokenID2UserID mapping length")
|
||||
}
|
||||
|
||||
func TestRestorePolicies_Migration(t *testing.T) {
|
||||
@@ -348,6 +364,137 @@ func TestFileStore_GetAccount(t *testing.T) {
|
||||
assert.Len(t, account.NameServerGroups, len(expected.NameServerGroups))
|
||||
}
|
||||
|
||||
func TestFileStore_GetTokenIDByHashedToken(t *testing.T) {
|
||||
storeDir := t.TempDir()
|
||||
storeFile := filepath.Join(storeDir, "store.json")
|
||||
err := util.CopyFileContents("testdata/store.json", storeFile)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
accounts := &accounts{}
|
||||
_, err = util.ReadJson(storeFile, accounts)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
store, err := NewFileStore(storeDir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
hashedToken := accounts.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"].Users["f4f6d672-63fb-11ec-90d6-0242ac120003"].PATs["9dj38s35-63fb-11ec-90d6-0242ac120003"].HashedToken
|
||||
tokenID, err := store.GetTokenIDByHashedToken(hashedToken)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expectedTokenID := accounts.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"].Users["f4f6d672-63fb-11ec-90d6-0242ac120003"].PATs["9dj38s35-63fb-11ec-90d6-0242ac120003"].ID
|
||||
assert.Equal(t, expectedTokenID, tokenID)
|
||||
}
|
||||
|
||||
func TestFileStore_DeleteHashedPAT2TokenIDIndex(t *testing.T) {
|
||||
store := newStore(t)
|
||||
store.HashedPAT2TokenID["someHashedToken"] = "someTokenId"
|
||||
|
||||
err := store.DeleteHashedPAT2TokenIDIndex("someHashedToken")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
assert.Empty(t, store.HashedPAT2TokenID["someHashedToken"])
|
||||
}
|
||||
|
||||
func TestFileStore_DeleteTokenID2UserIDIndex(t *testing.T) {
|
||||
store := newStore(t)
|
||||
store.TokenID2UserID["someTokenId"] = "someUserId"
|
||||
|
||||
err := store.DeleteTokenID2UserIDIndex("someTokenId")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
assert.Empty(t, store.TokenID2UserID["someTokenId"])
|
||||
}
|
||||
|
||||
func TestFileStore_GetTokenIDByHashedToken_Failure(t *testing.T) {
|
||||
storeDir := t.TempDir()
|
||||
storeFile := filepath.Join(storeDir, "store.json")
|
||||
err := util.CopyFileContents("testdata/store.json", storeFile)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
accounts := &accounts{}
|
||||
_, err = util.ReadJson(storeFile, accounts)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
store, err := NewFileStore(storeDir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
wrongToken := sha256.Sum256([]byte("someNotValidTokenThatFails1234"))
|
||||
_, err = store.GetTokenIDByHashedToken(string(wrongToken[:]))
|
||||
|
||||
assert.Error(t, err, "GetTokenIDByHashedToken should throw error if token invalid")
|
||||
}
|
||||
|
||||
func TestFileStore_GetUserByTokenID(t *testing.T) {
|
||||
storeDir := t.TempDir()
|
||||
storeFile := filepath.Join(storeDir, "store.json")
|
||||
err := util.CopyFileContents("testdata/store.json", storeFile)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
accounts := &accounts{}
|
||||
_, err = util.ReadJson(storeFile, accounts)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
store, err := NewFileStore(storeDir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
tokenID := accounts.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"].Users["f4f6d672-63fb-11ec-90d6-0242ac120003"].PATs["9dj38s35-63fb-11ec-90d6-0242ac120003"].ID
|
||||
user, err := store.GetUserByTokenID(tokenID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
assert.Equal(t, "f4f6d672-63fb-11ec-90d6-0242ac120003", user.Id)
|
||||
}
|
||||
|
||||
func TestFileStore_GetUserByTokenID_Failure(t *testing.T) {
|
||||
storeDir := t.TempDir()
|
||||
storeFile := filepath.Join(storeDir, "store.json")
|
||||
err := util.CopyFileContents("testdata/store.json", storeFile)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
accounts := &accounts{}
|
||||
_, err = util.ReadJson(storeFile, accounts)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
store, err := NewFileStore(storeDir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
wrongTokenID := "someNonExistingTokenID"
|
||||
_, err = store.GetUserByTokenID(wrongTokenID)
|
||||
|
||||
assert.Error(t, err, "GetUserByTokenID should throw error if tokenID invalid")
|
||||
}
|
||||
|
||||
func TestFileStore_SavePeerStatus(t *testing.T) {
|
||||
storeDir := t.TempDir()
|
||||
|
||||
@@ -367,7 +514,7 @@ func TestFileStore_SavePeerStatus(t *testing.T) {
|
||||
}
|
||||
|
||||
// save status of non-existing peer
|
||||
newStatus := PeerStatus{Connected: true, LastSeen: time.Now()}
|
||||
newStatus := PeerStatus{Connected: true, LastSeen: time.Now().UTC()}
|
||||
err = store.SavePeerStatus(account.Id, "non-existing-peer", newStatus)
|
||||
assert.Error(t, err)
|
||||
|
||||
@@ -379,7 +526,7 @@ func TestFileStore_SavePeerStatus(t *testing.T) {
|
||||
IP: net.IP{127, 0, 0, 1},
|
||||
Meta: PeerSystemMeta{},
|
||||
Name: "peer name",
|
||||
Status: &PeerStatus{Connected: false, LastSeen: time.Now()},
|
||||
Status: &PeerStatus{Connected: false, LastSeen: time.Now().UTC()},
|
||||
}
|
||||
|
||||
err = store.SaveAccount(account)
|
||||
|
||||
@@ -3,24 +3,25 @@ package server
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
pb "github.com/golang/protobuf/proto" //nolint
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
pb "github.com/golang/protobuf/proto" // nolint
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/http/middleware"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
|
||||
"github.com/golang/protobuf/ptypes/timestamp"
|
||||
"github.com/netbirdio/netbird/encryption"
|
||||
"github.com/netbirdio/netbird/management/proto"
|
||||
internalStatus "github.com/netbirdio/netbird/management/server/status"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc/codes"
|
||||
gRPCPeer "google.golang.org/grpc/peer"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/encryption"
|
||||
"github.com/netbirdio/netbird/management/proto"
|
||||
internalStatus "github.com/netbirdio/netbird/management/server/status"
|
||||
)
|
||||
|
||||
// GRPCServer an instance of a Management gRPC API server
|
||||
@@ -31,7 +32,7 @@ type GRPCServer struct {
|
||||
peersUpdateManager *PeersUpdateManager
|
||||
config *Config
|
||||
turnCredentialsManager TURNCredentialsManager
|
||||
jwtMiddleware *middleware.JWTMiddleware
|
||||
jwtValidator *jwtclaims.JWTValidator
|
||||
jwtClaimsExtractor *jwtclaims.ClaimsExtractor
|
||||
appMetrics telemetry.AppMetrics
|
||||
}
|
||||
@@ -45,12 +46,12 @@ func NewServer(config *Config, accountManager AccountManager, peersUpdateManager
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var jwtMiddleware *middleware.JWTMiddleware
|
||||
var jwtValidator *jwtclaims.JWTValidator
|
||||
|
||||
if config.HttpConfig != nil && config.HttpConfig.AuthIssuer != "" && config.HttpConfig.AuthAudience != "" && validateURL(config.HttpConfig.AuthKeysLocation) {
|
||||
jwtMiddleware, err = middleware.NewJwtMiddleware(
|
||||
jwtValidator, err = jwtclaims.NewJWTValidator(
|
||||
config.HttpConfig.AuthIssuer,
|
||||
config.HttpConfig.AuthAudience,
|
||||
config.GetAuthAudiences(),
|
||||
config.HttpConfig.AuthKeysLocation)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "unable to create new jwt middleware, err: %v", err)
|
||||
@@ -86,7 +87,7 @@ func NewServer(config *Config, accountManager AccountManager, peersUpdateManager
|
||||
accountManager: accountManager,
|
||||
config: config,
|
||||
turnCredentialsManager: turnCredentialsManager,
|
||||
jwtMiddleware: jwtMiddleware,
|
||||
jwtValidator: jwtValidator,
|
||||
jwtClaimsExtractor: jwtClaimsExtractor,
|
||||
appMetrics: appMetrics,
|
||||
}, nil
|
||||
@@ -187,11 +188,11 @@ func (s *GRPCServer) cancelPeerRoutines(peer *Peer) {
|
||||
}
|
||||
|
||||
func (s *GRPCServer) validateToken(jwtToken string) (string, error) {
|
||||
if s.jwtMiddleware == nil {
|
||||
return "", status.Error(codes.Internal, "no jwt middleware set")
|
||||
if s.jwtValidator == nil {
|
||||
return "", status.Error(codes.Internal, "no jwt validator set")
|
||||
}
|
||||
|
||||
token, err := s.jwtMiddleware.ValidateAndParse(jwtToken)
|
||||
token, err := s.jwtValidator.ValidateAndParse(jwtToken)
|
||||
if err != nil {
|
||||
return "", status.Errorf(codes.InvalidArgument, "invalid jwt token, err: %v", err)
|
||||
}
|
||||
@@ -222,6 +223,7 @@ func mapError(err error) error {
|
||||
default:
|
||||
}
|
||||
}
|
||||
log.Errorf("got an unhandled error: %s", err)
|
||||
return status.Errorf(codes.Internal, "failed handling request")
|
||||
}
|
||||
|
||||
@@ -420,8 +422,6 @@ func toSyncResponse(config *Config, peer *Peer, turnCredentials *TURNCredentials
|
||||
|
||||
offlinePeers := toRemotePeerConfig(networkMap.OfflinePeers, dnsName)
|
||||
|
||||
firewallRules := toProtocolFirewallRules(networkMap.FirewallRules)
|
||||
|
||||
return &proto.SyncResponse{
|
||||
WiretrusteeConfig: wtConfig,
|
||||
PeerConfig: pConfig,
|
||||
@@ -435,7 +435,6 @@ func toSyncResponse(config *Config, peer *Peer, turnCredentials *TURNCredentials
|
||||
RemotePeersIsEmpty: len(remotePeers) == 0,
|
||||
Routes: routesUpdate,
|
||||
DNSConfig: dnsUpdate,
|
||||
FirewallRules: firewallRules,
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -511,6 +510,8 @@ func (s *GRPCServer) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.
|
||||
Audience: s.config.DeviceAuthorizationFlow.ProviderConfig.Audience,
|
||||
DeviceAuthEndpoint: s.config.DeviceAuthorizationFlow.ProviderConfig.DeviceAuthEndpoint,
|
||||
TokenEndpoint: s.config.DeviceAuthorizationFlow.ProviderConfig.TokenEndpoint,
|
||||
Scope: s.config.DeviceAuthorizationFlow.ProviderConfig.Scope,
|
||||
UseIDToken: s.config.DeviceAuthorizationFlow.ProviderConfig.UseIDToken,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -6,6 +6,8 @@ info:
|
||||
tags:
|
||||
- name: Users
|
||||
description: Interact with and view information about users.
|
||||
- name: Tokens
|
||||
description: Interact with and view information about tokens.
|
||||
- name: Peers
|
||||
description: Interact with and view information about peers.
|
||||
- name: Setup Keys
|
||||
@@ -284,6 +286,59 @@ components:
|
||||
- revoked
|
||||
- auto_groups
|
||||
- usage_limit
|
||||
PersonalAccessToken:
|
||||
type: object
|
||||
properties:
|
||||
id:
|
||||
description: ID of a token
|
||||
type: string
|
||||
name:
|
||||
description: Name of the token
|
||||
type: string
|
||||
expiration_date:
|
||||
description: Date the token expires
|
||||
type: string
|
||||
format: date-time
|
||||
created_by:
|
||||
description: User ID of the user who created the token
|
||||
type: string
|
||||
created_at:
|
||||
description: Date the token was created
|
||||
type: string
|
||||
format: date-time
|
||||
last_used:
|
||||
description: Date the token was last used
|
||||
type: string
|
||||
format: date-time
|
||||
required:
|
||||
- id
|
||||
- name
|
||||
- expiration_date
|
||||
- created_by
|
||||
- created_at
|
||||
PersonalAccessTokenGenerated:
|
||||
type: object
|
||||
properties:
|
||||
plain_token:
|
||||
description: Plain text representation of the generated token
|
||||
type: string
|
||||
personal_access_token:
|
||||
$ref: '#/components/schemas/PersonalAccessToken'
|
||||
required:
|
||||
- plain_token
|
||||
- personal_access_token
|
||||
PersonalAccessTokenRequest:
|
||||
type: object
|
||||
properties:
|
||||
name:
|
||||
description: Name of the token
|
||||
type: string
|
||||
expires_in:
|
||||
description: Expiration in days
|
||||
type: integer
|
||||
required:
|
||||
- name
|
||||
- expires_in
|
||||
GroupMinimum:
|
||||
type: object
|
||||
properties:
|
||||
@@ -848,6 +903,133 @@ paths:
|
||||
"$ref": "#/components/responses/forbidden"
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
/api/users/{userId}/tokens:
|
||||
get:
|
||||
summary: Returns a list of all tokens for a user
|
||||
tags: [ Tokens ]
|
||||
security:
|
||||
- BearerAuth: []
|
||||
parameters:
|
||||
- in: path
|
||||
name: userId
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
description: The User ID
|
||||
responses:
|
||||
'200':
|
||||
description: A JSON Array of PersonalAccessTokens
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/PersonalAccessToken'
|
||||
'400':
|
||||
"$ref": "#/components/responses/bad_request"
|
||||
'401':
|
||||
"$ref": "#/components/responses/requires_authentication"
|
||||
'403':
|
||||
"$ref": "#/components/responses/forbidden"
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
post:
|
||||
summary: Create a new token
|
||||
tags: [ Tokens ]
|
||||
security:
|
||||
- BearerAuth: [ ]
|
||||
parameters:
|
||||
- in: path
|
||||
name: userId
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
description: The User ID
|
||||
requestBody:
|
||||
description: PersonalAccessToken create parameters
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/PersonalAccessTokenRequest'
|
||||
responses:
|
||||
'200':
|
||||
description: The token in plain text
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/PersonalAccessTokenGenerated'
|
||||
'400':
|
||||
"$ref": "#/components/responses/bad_request"
|
||||
'401':
|
||||
"$ref": "#/components/responses/requires_authentication"
|
||||
'403':
|
||||
"$ref": "#/components/responses/forbidden"
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
/api/users/{userId}/tokens/{tokenId}:
|
||||
get:
|
||||
summary: Returns a specific token
|
||||
tags: [ Tokens ]
|
||||
security:
|
||||
- BearerAuth: [ ]
|
||||
parameters:
|
||||
- in: path
|
||||
name: userId
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
description: The User ID
|
||||
- in: path
|
||||
name: tokenId
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
description: The Token ID
|
||||
responses:
|
||||
'200':
|
||||
description: A PersonalAccessTokens Object
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/PersonalAccessToken'
|
||||
'400':
|
||||
"$ref": "#/components/responses/bad_request"
|
||||
'401':
|
||||
"$ref": "#/components/responses/requires_authentication"
|
||||
'403':
|
||||
"$ref": "#/components/responses/forbidden"
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
delete:
|
||||
summary: Delete a token
|
||||
tags: [ Tokens ]
|
||||
security:
|
||||
- BearerAuth: [ ]
|
||||
parameters:
|
||||
- in: path
|
||||
name: userId
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
description: The User ID
|
||||
- in: path
|
||||
name: tokenId
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
description: The Token ID
|
||||
responses:
|
||||
'200':
|
||||
description: Delete status code
|
||||
content: { }
|
||||
'400':
|
||||
"$ref": "#/components/responses/bad_request"
|
||||
'401':
|
||||
"$ref": "#/components/responses/requires_authentication"
|
||||
'403':
|
||||
"$ref": "#/components/responses/forbidden"
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
/api/peers:
|
||||
get:
|
||||
summary: Returns a list of all peers
|
||||
|
||||
@@ -379,6 +379,44 @@ type PeerMinimum struct {
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
// PersonalAccessToken defines model for PersonalAccessToken.
|
||||
type PersonalAccessToken struct {
|
||||
// CreatedAt Date the token was created
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
|
||||
// CreatedBy User ID of the user who created the token
|
||||
CreatedBy string `json:"created_by"`
|
||||
|
||||
// ExpirationDate Date the token expires
|
||||
ExpirationDate time.Time `json:"expiration_date"`
|
||||
|
||||
// Id ID of a token
|
||||
Id string `json:"id"`
|
||||
|
||||
// LastUsed Date the token was last used
|
||||
LastUsed *time.Time `json:"last_used,omitempty"`
|
||||
|
||||
// Name Name of the token
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
// PersonalAccessTokenGenerated defines model for PersonalAccessTokenGenerated.
|
||||
type PersonalAccessTokenGenerated struct {
|
||||
PersonalAccessToken PersonalAccessToken `json:"personal_access_token"`
|
||||
|
||||
// PlainToken Plain text representation of the generated token
|
||||
PlainToken string `json:"plain_token"`
|
||||
}
|
||||
|
||||
// PersonalAccessTokenRequest defines model for PersonalAccessTokenRequest.
|
||||
type PersonalAccessTokenRequest struct {
|
||||
// ExpiresIn Expiration in days
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
|
||||
// Name Name of the token
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
// Policy defines model for Policy.
|
||||
type Policy struct {
|
||||
// Description Policy friendly description
|
||||
@@ -808,3 +846,6 @@ type PostApiUsersJSONRequestBody = UserCreateRequest
|
||||
|
||||
// PutApiUsersIdJSONRequestBody defines body for PutApiUsersId for application/json ContentType.
|
||||
type PutApiUsersIdJSONRequestBody = UserRequest
|
||||
|
||||
// PostApiUsersUserIdTokensJSONRequestBody defines body for PostApiUsersUserIdTokens for application/json ContentType.
|
||||
type PostApiUsersUserIdTokensJSONRequestBody = PersonalAccessTokenRequest
|
||||
|
||||
@@ -54,7 +54,7 @@ func generateEvents(accountID, userID string) []*activity.Event {
|
||||
ID := uint64(1)
|
||||
events := make([]*activity.Event, 0)
|
||||
events = append(events, &activity.Event{
|
||||
Timestamp: time.Now(),
|
||||
Timestamp: time.Now().UTC(),
|
||||
Activity: activity.PeerAddedByUser,
|
||||
ID: ID,
|
||||
InitiatorID: userID,
|
||||
@@ -64,7 +64,7 @@ func generateEvents(accountID, userID string) []*activity.Event {
|
||||
})
|
||||
ID++
|
||||
events = append(events, &activity.Event{
|
||||
Timestamp: time.Now(),
|
||||
Timestamp: time.Now().UTC(),
|
||||
Activity: activity.UserJoined,
|
||||
ID: ID,
|
||||
InitiatorID: userID,
|
||||
@@ -74,7 +74,7 @@ func generateEvents(accountID, userID string) []*activity.Event {
|
||||
})
|
||||
ID++
|
||||
events = append(events, &activity.Event{
|
||||
Timestamp: time.Now(),
|
||||
Timestamp: time.Now().UTC(),
|
||||
Activity: activity.GroupCreated,
|
||||
ID: ID,
|
||||
InitiatorID: userID,
|
||||
@@ -84,7 +84,7 @@ func generateEvents(accountID, userID string) []*activity.Event {
|
||||
})
|
||||
ID++
|
||||
events = append(events, &activity.Event{
|
||||
Timestamp: time.Now(),
|
||||
Timestamp: time.Now().UTC(),
|
||||
Activity: activity.SetupKeyUpdated,
|
||||
ID: ID,
|
||||
InitiatorID: userID,
|
||||
@@ -94,7 +94,7 @@ func generateEvents(accountID, userID string) []*activity.Event {
|
||||
})
|
||||
ID++
|
||||
events = append(events, &activity.Event{
|
||||
Timestamp: time.Now(),
|
||||
Timestamp: time.Now().UTC(),
|
||||
Activity: activity.SetupKeyUpdated,
|
||||
ID: ID,
|
||||
InitiatorID: userID,
|
||||
@@ -104,7 +104,7 @@ func generateEvents(accountID, userID string) []*activity.Event {
|
||||
})
|
||||
ID++
|
||||
events = append(events, &activity.Event{
|
||||
Timestamp: time.Now(),
|
||||
Timestamp: time.Now().UTC(),
|
||||
Activity: activity.SetupKeyRevoked,
|
||||
ID: ID,
|
||||
InitiatorID: userID,
|
||||
@@ -114,7 +114,7 @@ func generateEvents(accountID, userID string) []*activity.Event {
|
||||
})
|
||||
ID++
|
||||
events = append(events, &activity.Event{
|
||||
Timestamp: time.Now(),
|
||||
Timestamp: time.Now().UTC(),
|
||||
Activity: activity.SetupKeyOverused,
|
||||
ID: ID,
|
||||
InitiatorID: userID,
|
||||
@@ -124,7 +124,7 @@ func generateEvents(accountID, userID string) []*activity.Event {
|
||||
})
|
||||
ID++
|
||||
events = append(events, &activity.Event{
|
||||
Timestamp: time.Now(),
|
||||
Timestamp: time.Now().UTC(),
|
||||
Activity: activity.SetupKeyCreated,
|
||||
ID: ID,
|
||||
InitiatorID: userID,
|
||||
@@ -134,7 +134,7 @@ func generateEvents(accountID, userID string) []*activity.Event {
|
||||
})
|
||||
ID++
|
||||
events = append(events, &activity.Event{
|
||||
Timestamp: time.Now(),
|
||||
Timestamp: time.Now().UTC(),
|
||||
Activity: activity.RuleAdded,
|
||||
ID: ID,
|
||||
InitiatorID: userID,
|
||||
@@ -144,7 +144,7 @@ func generateEvents(accountID, userID string) []*activity.Event {
|
||||
})
|
||||
ID++
|
||||
events = append(events, &activity.Event{
|
||||
Timestamp: time.Now(),
|
||||
Timestamp: time.Now().UTC(),
|
||||
Activity: activity.RuleRemoved,
|
||||
ID: ID,
|
||||
InitiatorID: userID,
|
||||
@@ -154,7 +154,7 @@ func generateEvents(accountID, userID string) []*activity.Event {
|
||||
})
|
||||
ID++
|
||||
events = append(events, &activity.Event{
|
||||
Timestamp: time.Now(),
|
||||
Timestamp: time.Now().UTC(),
|
||||
Activity: activity.RuleUpdated,
|
||||
ID: ID,
|
||||
InitiatorID: userID,
|
||||
@@ -164,7 +164,7 @@ func generateEvents(accountID, userID string) []*activity.Event {
|
||||
})
|
||||
ID++
|
||||
events = append(events, &activity.Event{
|
||||
Timestamp: time.Now(),
|
||||
Timestamp: time.Now().UTC(),
|
||||
Activity: activity.PeerAddedWithSetupKey,
|
||||
ID: ID,
|
||||
InitiatorID: userID,
|
||||
|
||||
@@ -300,7 +300,7 @@ func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, "")
|
||||
util.WriteJSONObject(w, emptyObject{})
|
||||
}
|
||||
|
||||
// GetGroup returns a group
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
|
||||
s "github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/http/middleware"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
|
||||
@@ -25,15 +26,17 @@ type apiHandler struct {
|
||||
AuthCfg AuthCfg
|
||||
}
|
||||
|
||||
// EmptyObject is an empty struct used to return empty JSON object
|
||||
type emptyObject struct {
|
||||
}
|
||||
|
||||
// APIHandler creates the Management service HTTP API handler registering all the available endpoints.
|
||||
func APIHandler(accountManager s.AccountManager, appMetrics telemetry.AppMetrics, authCfg AuthCfg) (http.Handler, error) {
|
||||
jwtMiddleware, err := middleware.NewJwtMiddleware(
|
||||
authCfg.Issuer,
|
||||
authCfg.Audience,
|
||||
authCfg.KeysLocation)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
func APIHandler(accountManager s.AccountManager, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg AuthCfg) (http.Handler, error) {
|
||||
authMiddleware := middleware.NewAuthMiddleware(
|
||||
accountManager.GetAccountFromPAT,
|
||||
jwtValidator.ValidateAndParse,
|
||||
accountManager.MarkPATUsed,
|
||||
authCfg.Audience)
|
||||
|
||||
corsMiddleware := cors.AllowAll()
|
||||
|
||||
@@ -46,7 +49,7 @@ func APIHandler(accountManager s.AccountManager, appMetrics telemetry.AppMetrics
|
||||
metricsMiddleware := appMetrics.HTTPMiddleware()
|
||||
|
||||
router := rootRouter.PathPrefix("/api").Subrouter()
|
||||
router.Use(metricsMiddleware.Handler, corsMiddleware.Handler, jwtMiddleware.Handler, acMiddleware.Handler)
|
||||
router.Use(metricsMiddleware.Handler, corsMiddleware.Handler, authMiddleware.Handler, acMiddleware.Handler)
|
||||
|
||||
api := apiHandler{
|
||||
Router: router,
|
||||
@@ -57,6 +60,7 @@ func APIHandler(accountManager s.AccountManager, appMetrics telemetry.AppMetrics
|
||||
api.addAccountsEndpoint()
|
||||
api.addPeersEndpoint()
|
||||
api.addUsersEndpoint()
|
||||
api.addUsersTokensEndpoint()
|
||||
api.addSetupKeysEndpoint()
|
||||
api.addRulesEndpoint()
|
||||
api.addPoliciesEndpoint()
|
||||
@@ -66,7 +70,7 @@ func APIHandler(accountManager s.AccountManager, appMetrics telemetry.AppMetrics
|
||||
api.addDNSSettingEndpoint()
|
||||
api.addEventsEndpoint()
|
||||
|
||||
err = api.Router.Walk(func(route *mux.Route, _ *mux.Router, _ []*mux.Route) error {
|
||||
err := api.Router.Walk(func(route *mux.Route, _ *mux.Router, _ []*mux.Route) error {
|
||||
methods, err := route.GetMethods()
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -110,6 +114,14 @@ func (apiHandler *apiHandler) addUsersEndpoint() {
|
||||
apiHandler.Router.HandleFunc("/users", userHandler.CreateUser).Methods("POST", "OPTIONS")
|
||||
}
|
||||
|
||||
func (apiHandler *apiHandler) addUsersTokensEndpoint() {
|
||||
tokenHandler := NewPATsHandler(apiHandler.AccountManager, apiHandler.AuthCfg)
|
||||
apiHandler.Router.HandleFunc("/users/{userId}/tokens", tokenHandler.GetAllTokens).Methods("GET", "OPTIONS")
|
||||
apiHandler.Router.HandleFunc("/users/{userId}/tokens", tokenHandler.CreateToken).Methods("POST", "OPTIONS")
|
||||
apiHandler.Router.HandleFunc("/users/{userId}/tokens/{tokenId}", tokenHandler.GetToken).Methods("GET", "OPTIONS")
|
||||
apiHandler.Router.HandleFunc("/users/{userId}/tokens/{tokenId}", tokenHandler.DeleteToken).Methods("DELETE", "OPTIONS")
|
||||
}
|
||||
|
||||
func (apiHandler *apiHandler) addSetupKeysEndpoint() {
|
||||
keysHandler := NewSetupKeysHandler(apiHandler.AccountManager, apiHandler.AuthCfg)
|
||||
apiHandler.Router.HandleFunc("/setup-keys", keysHandler.GetAllSetupKeys).Methods("GET", "OPTIONS")
|
||||
|
||||
@@ -2,6 +2,9 @@ package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"regexp"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
@@ -39,10 +42,22 @@ func (a *AccessControl) Handler(h http.Handler) http.Handler {
|
||||
util.WriteError(status.Errorf(status.Unauthorized, "invalid JWT"), w)
|
||||
return
|
||||
}
|
||||
|
||||
if !ok {
|
||||
switch r.Method {
|
||||
case http.MethodDelete, http.MethodPost, http.MethodPatch, http.MethodPut:
|
||||
|
||||
ok, err := regexp.MatchString(`^.*/api/users/.*/tokens.*$`, r.URL.Path)
|
||||
if err != nil {
|
||||
log.Debugf("Regex failed")
|
||||
util.WriteError(status.Errorf(status.Internal, ""), w)
|
||||
return
|
||||
}
|
||||
if ok {
|
||||
log.Debugf("Valid Path")
|
||||
h.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteError(status.Errorf(status.PermissionDenied, "only admin can perform this operation"), w)
|
||||
return
|
||||
}
|
||||
|
||||
173
management/server/http/middleware/auth_middleware.go
Normal file
173
management/server/http/middleware/auth_middleware.go
Normal file
@@ -0,0 +1,173 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
)
|
||||
|
||||
// GetAccountFromPATFunc function
|
||||
type GetAccountFromPATFunc func(token string) (*server.Account, *server.User, *server.PersonalAccessToken, error)
|
||||
|
||||
// ValidateAndParseTokenFunc function
|
||||
type ValidateAndParseTokenFunc func(token string) (*jwt.Token, error)
|
||||
|
||||
// MarkPATUsedFunc function
|
||||
type MarkPATUsedFunc func(token string) error
|
||||
|
||||
// AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens
|
||||
type AuthMiddleware struct {
|
||||
getAccountFromPAT GetAccountFromPATFunc
|
||||
validateAndParseToken ValidateAndParseTokenFunc
|
||||
markPATUsed MarkPATUsedFunc
|
||||
audience string
|
||||
}
|
||||
|
||||
const (
|
||||
userProperty = "user"
|
||||
)
|
||||
|
||||
// NewAuthMiddleware instance constructor
|
||||
func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParseToken ValidateAndParseTokenFunc, markPATUsed MarkPATUsedFunc, audience string) *AuthMiddleware {
|
||||
return &AuthMiddleware{
|
||||
getAccountFromPAT: getAccountFromPAT,
|
||||
validateAndParseToken: validateAndParseToken,
|
||||
markPATUsed: markPATUsed,
|
||||
audience: audience,
|
||||
}
|
||||
}
|
||||
|
||||
// Handler method of the middleware which authenticates a user either by JWT claims or by PAT
|
||||
func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
auth := strings.Split(r.Header.Get("Authorization"), " ")
|
||||
authType := auth[0]
|
||||
switch strings.ToLower(authType) {
|
||||
case "bearer":
|
||||
err := m.CheckJWTFromRequest(w, r)
|
||||
if err != nil {
|
||||
log.Debugf("Error when validating JWT claims: %s", err.Error())
|
||||
util.WriteError(status.Errorf(status.Unauthorized, "token invalid"), w)
|
||||
return
|
||||
}
|
||||
h.ServeHTTP(w, r)
|
||||
case "token":
|
||||
err := m.CheckPATFromRequest(w, r)
|
||||
if err != nil {
|
||||
log.Debugf("Error when validating PAT claims: %s", err.Error())
|
||||
util.WriteError(status.Errorf(status.Unauthorized, "token invalid"), w)
|
||||
return
|
||||
}
|
||||
h.ServeHTTP(w, r)
|
||||
default:
|
||||
util.WriteError(status.Errorf(status.Unauthorized, "no valid authentication provided"), w)
|
||||
return
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// CheckJWTFromRequest checks if the JWT is valid
|
||||
func (m *AuthMiddleware) CheckJWTFromRequest(w http.ResponseWriter, r *http.Request) error {
|
||||
|
||||
token, err := getTokenFromJWTRequest(r)
|
||||
|
||||
// If an error occurs, call the error handler and return an error
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error extracting token: %w", err)
|
||||
}
|
||||
|
||||
validatedToken, err := m.validateAndParseToken(token)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if validatedToken == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// If we get here, everything worked and we can set the
|
||||
// user property in context.
|
||||
newRequest := r.WithContext(context.WithValue(r.Context(), userProperty, validatedToken)) //nolint
|
||||
// Update the current request with the new context information.
|
||||
*r = *newRequest
|
||||
return nil
|
||||
}
|
||||
|
||||
// CheckPATFromRequest checks if the PAT is valid
|
||||
func (m *AuthMiddleware) CheckPATFromRequest(w http.ResponseWriter, r *http.Request) error {
|
||||
token, err := getTokenFromPATRequest(r)
|
||||
|
||||
// If an error occurs, call the error handler and return an error
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error extracting token: %w", err)
|
||||
}
|
||||
|
||||
account, user, pat, err := m.getAccountFromPAT(token)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid Token: %w", err)
|
||||
}
|
||||
if time.Now().After(pat.ExpirationDate) {
|
||||
return fmt.Errorf("token expired")
|
||||
}
|
||||
|
||||
err = m.markPATUsed(pat.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
claimMaps := jwt.MapClaims{}
|
||||
claimMaps[jwtclaims.UserIDClaim] = user.Id
|
||||
claimMaps[m.audience+jwtclaims.AccountIDSuffix] = account.Id
|
||||
claimMaps[m.audience+jwtclaims.DomainIDSuffix] = account.Domain
|
||||
claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = account.DomainCategory
|
||||
jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps)
|
||||
newRequest := r.WithContext(context.WithValue(r.Context(), jwtclaims.TokenUserProperty, jwtToken)) //nolint
|
||||
// Update the current request with the new context information.
|
||||
*r = *newRequest
|
||||
return nil
|
||||
}
|
||||
|
||||
// getTokenFromJWTRequest is a "TokenExtractor" that takes a give request and extracts
|
||||
// the JWT token from the Authorization header.
|
||||
func getTokenFromJWTRequest(r *http.Request) (string, error) {
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
return "", nil // No error, just no token
|
||||
}
|
||||
|
||||
// TODO: Make this a bit more robust, parsing-wise
|
||||
authHeaderParts := strings.Fields(authHeader)
|
||||
if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" {
|
||||
return "", errors.New("Authorization header format must be Bearer {token}")
|
||||
}
|
||||
|
||||
return authHeaderParts[1], nil
|
||||
}
|
||||
|
||||
// getTokenFromPATRequest is a "TokenExtractor" that takes a give request and extracts
|
||||
// the PAT token from the Authorization header.
|
||||
func getTokenFromPATRequest(r *http.Request) (string, error) {
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
return "", nil // No error, just no token
|
||||
}
|
||||
|
||||
// TODO: Make this a bit more robust, parsing-wise
|
||||
authHeaderParts := strings.Fields(authHeader)
|
||||
if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "token" {
|
||||
return "", errors.New("Authorization header format must be Token {token}")
|
||||
}
|
||||
|
||||
return authHeaderParts[1], nil
|
||||
}
|
||||
123
management/server/http/middleware/auth_middleware_test.go
Normal file
123
management/server/http/middleware/auth_middleware_test.go
Normal file
@@ -0,0 +1,123 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
)
|
||||
|
||||
const (
|
||||
audience = "audience"
|
||||
accountID = "accountID"
|
||||
domain = "domain"
|
||||
userID = "userID"
|
||||
tokenID = "tokenID"
|
||||
PAT = "PAT"
|
||||
JWT = "JWT"
|
||||
wrongToken = "wrongToken"
|
||||
)
|
||||
|
||||
var testAccount = &server.Account{
|
||||
Id: accountID,
|
||||
Domain: domain,
|
||||
Users: map[string]*server.User{
|
||||
userID: {
|
||||
Id: userID,
|
||||
PATs: map[string]*server.PersonalAccessToken{
|
||||
tokenID: {
|
||||
ID: tokenID,
|
||||
Name: "My first token",
|
||||
HashedToken: "someHash",
|
||||
ExpirationDate: time.Now().UTC().AddDate(0, 0, 7),
|
||||
CreatedBy: userID,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
LastUsed: time.Now().UTC(),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
func mockGetAccountFromPAT(token string) (*server.Account, *server.User, *server.PersonalAccessToken, error) {
|
||||
if token == PAT {
|
||||
return testAccount, testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID], nil
|
||||
}
|
||||
return nil, nil, nil, fmt.Errorf("PAT invalid")
|
||||
}
|
||||
|
||||
func mockValidateAndParseToken(token string) (*jwt.Token, error) {
|
||||
if token == JWT {
|
||||
return &jwt.Token{}, nil
|
||||
}
|
||||
return nil, fmt.Errorf("JWT invalid")
|
||||
}
|
||||
|
||||
func mockMarkPATUsed(token string) error {
|
||||
if token == tokenID {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("Should never get reached")
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_Handler(t *testing.T) {
|
||||
tt := []struct {
|
||||
name string
|
||||
authHeader string
|
||||
expectedStatusCode int
|
||||
}{
|
||||
{
|
||||
name: "Valid PAT Token",
|
||||
authHeader: "Token " + PAT,
|
||||
expectedStatusCode: 200,
|
||||
},
|
||||
{
|
||||
name: "Invalid PAT Token",
|
||||
authHeader: "Token " + wrongToken,
|
||||
expectedStatusCode: 401,
|
||||
},
|
||||
{
|
||||
name: "Valid JWT Token",
|
||||
authHeader: "Bearer " + JWT,
|
||||
expectedStatusCode: 200,
|
||||
},
|
||||
{
|
||||
name: "Invalid JWT Token",
|
||||
authHeader: "Bearer " + wrongToken,
|
||||
expectedStatusCode: 401,
|
||||
},
|
||||
{
|
||||
name: "Basic Auth",
|
||||
authHeader: "Basic " + PAT,
|
||||
expectedStatusCode: 401,
|
||||
},
|
||||
}
|
||||
|
||||
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// do nothing
|
||||
})
|
||||
|
||||
authMiddleware := NewAuthMiddleware(mockGetAccountFromPAT, mockValidateAndParseToken, mockMarkPATUsed, audience)
|
||||
|
||||
handlerToTest := authMiddleware.Handler(nextHandler)
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "http://testing", nil)
|
||||
req.Header.Set("Authorization", tc.authHeader)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handlerToTest.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Result().StatusCode != tc.expectedStatusCode {
|
||||
t.Errorf("expected status code %d, got %d", tc.expectedStatusCode, rec.Result().StatusCode)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,249 +0,0 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// A function called whenever an error is encountered
|
||||
type errorHandler func(w http.ResponseWriter, r *http.Request, err string)
|
||||
|
||||
// TokenExtractor is a function that takes a request as input and returns
|
||||
// either a token or an error. An error should only be returned if an attempt
|
||||
// to specify a token was found, but the information was somehow incorrectly
|
||||
// formed. In the case where a token is simply not present, this should not
|
||||
// be treated as an error. An empty string should be returned in that case.
|
||||
type TokenExtractor func(r *http.Request) (string, error)
|
||||
|
||||
// Options is a struct for specifying configuration options for the middleware.
|
||||
type Options struct {
|
||||
// The function that will return the Key to validate the JWT.
|
||||
// It can be either a shared secret or a public key.
|
||||
// Default value: nil
|
||||
ValidationKeyGetter jwt.Keyfunc
|
||||
// The name of the property in the request where the user information
|
||||
// from the JWT will be stored.
|
||||
// Default value: "user"
|
||||
UserProperty string
|
||||
// The function that will be called when there's an error validating the token
|
||||
// Default value:
|
||||
ErrorHandler errorHandler
|
||||
// A boolean indicating if the credentials are required or not
|
||||
// Default value: false
|
||||
CredentialsOptional bool
|
||||
// A function that extracts the token from the request
|
||||
// Default: FromAuthHeader (i.e., from Authorization header as bearer token)
|
||||
Extractor TokenExtractor
|
||||
// Debug flag turns on debugging output
|
||||
// Default: false
|
||||
Debug bool
|
||||
// When set, all requests with the OPTIONS method will use authentication
|
||||
// Default: false
|
||||
EnableAuthOnOptions bool
|
||||
// When set, the middelware verifies that tokens are signed with the specific signing algorithm
|
||||
// If the signing method is not constant the ValidationKeyGetter callback can be used to implement additional checks
|
||||
// Important to avoid security issues described here: https://auth0.com/blog/critical-vulnerabilities-in-json-web-token-libraries/
|
||||
// Default: nil
|
||||
SigningMethod jwt.SigningMethod
|
||||
}
|
||||
|
||||
type JWTMiddleware struct {
|
||||
Options Options
|
||||
}
|
||||
|
||||
func OnError(w http.ResponseWriter, r *http.Request, err string) {
|
||||
util.WriteError(status.Errorf(status.Unauthorized, ""), w)
|
||||
}
|
||||
|
||||
// New constructs a new Secure instance with supplied options.
|
||||
func New(options ...Options) *JWTMiddleware {
|
||||
|
||||
var opts Options
|
||||
if len(options) == 0 {
|
||||
opts = Options{}
|
||||
} else {
|
||||
opts = options[0]
|
||||
}
|
||||
|
||||
if opts.UserProperty == "" {
|
||||
opts.UserProperty = "user"
|
||||
}
|
||||
|
||||
if opts.ErrorHandler == nil {
|
||||
opts.ErrorHandler = OnError
|
||||
}
|
||||
|
||||
if opts.Extractor == nil {
|
||||
opts.Extractor = FromAuthHeader
|
||||
}
|
||||
|
||||
return &JWTMiddleware{
|
||||
Options: opts,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *JWTMiddleware) logf(format string, args ...interface{}) {
|
||||
if m.Options.Debug {
|
||||
log.Printf(format, args...)
|
||||
}
|
||||
}
|
||||
|
||||
// HandlerWithNext is a special implementation for Negroni, but could be used elsewhere.
|
||||
func (m *JWTMiddleware) HandlerWithNext(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
|
||||
err := m.CheckJWTFromRequest(w, r)
|
||||
|
||||
// If there was an error, do not call next.
|
||||
if err == nil && next != nil {
|
||||
next(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *JWTMiddleware) Handler(h http.Handler) http.Handler {
|
||||
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Let secure process the request. If it returns an error,
|
||||
// that indicates the request should not continue.
|
||||
err := m.CheckJWTFromRequest(w, r)
|
||||
|
||||
// If there was an error, do not continue.
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
h.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// FromAuthHeader is a "TokenExtractor" that takes a give request and extracts
|
||||
// the JWT token from the Authorization header.
|
||||
func FromAuthHeader(r *http.Request) (string, error) {
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
return "", nil // No error, just no token
|
||||
}
|
||||
|
||||
// TODO: Make this a bit more robust, parsing-wise
|
||||
authHeaderParts := strings.Fields(authHeader)
|
||||
if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" {
|
||||
return "", errors.New("Authorization header format must be Bearer {token}")
|
||||
}
|
||||
|
||||
return authHeaderParts[1], nil
|
||||
}
|
||||
|
||||
// FromParameter returns a function that extracts the token from the specified
|
||||
// query string parameter
|
||||
func FromParameter(param string) TokenExtractor {
|
||||
return func(r *http.Request) (string, error) {
|
||||
return r.URL.Query().Get(param), nil
|
||||
}
|
||||
}
|
||||
|
||||
// FromFirst returns a function that runs multiple token extractors and takes the
|
||||
// first token it finds
|
||||
func FromFirst(extractors ...TokenExtractor) TokenExtractor {
|
||||
return func(r *http.Request) (string, error) {
|
||||
for _, ex := range extractors {
|
||||
token, err := ex(r)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if token != "" {
|
||||
return token, nil
|
||||
}
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
}
|
||||
|
||||
func (m *JWTMiddleware) CheckJWTFromRequest(w http.ResponseWriter, r *http.Request) error {
|
||||
if !m.Options.EnableAuthOnOptions {
|
||||
if r.Method == "OPTIONS" {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Use the specified token extractor to extract a token from the request
|
||||
token, err := m.Options.Extractor(r)
|
||||
|
||||
// If debugging is turned on, log the outcome
|
||||
if err != nil {
|
||||
m.logf("Error extracting JWT: %v", err)
|
||||
} else {
|
||||
m.logf("Token extracted: %s", token)
|
||||
}
|
||||
|
||||
// If an error occurs, call the error handler and return an error
|
||||
if err != nil {
|
||||
m.Options.ErrorHandler(w, r, err.Error())
|
||||
return fmt.Errorf("Error extracting token: %w", err)
|
||||
}
|
||||
|
||||
validatedToken, err := m.ValidateAndParse(token)
|
||||
if err != nil {
|
||||
m.Options.ErrorHandler(w, r, err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
if validatedToken == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// If we get here, everything worked and we can set the
|
||||
// user property in context.
|
||||
newRequest := r.WithContext(context.WithValue(r.Context(), m.Options.UserProperty, validatedToken)) //nolint
|
||||
// Update the current request with the new context information.
|
||||
*r = *newRequest
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateAndParse validates and parses a given access token against jwt standards and signing methods
|
||||
func (m *JWTMiddleware) ValidateAndParse(token string) (*jwt.Token, error) {
|
||||
// If the token is empty...
|
||||
if token == "" {
|
||||
// Check if it was required
|
||||
if m.Options.CredentialsOptional {
|
||||
m.logf("no credentials found (CredentialsOptional=true)")
|
||||
// No error, just no token (and that is ok given that CredentialsOptional is true)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// If we get here, the required token is missing
|
||||
errorMsg := "required authorization token not found"
|
||||
m.logf(" Error: No credentials found (CredentialsOptional=false)")
|
||||
return nil, fmt.Errorf(errorMsg)
|
||||
}
|
||||
|
||||
// Now parse the token
|
||||
parsedToken, err := jwt.Parse(token, m.Options.ValidationKeyGetter)
|
||||
|
||||
// Check if there was an error in parsing...
|
||||
if err != nil {
|
||||
m.logf("error parsing token: %v", err)
|
||||
return nil, fmt.Errorf("Error parsing token: %w", err)
|
||||
}
|
||||
|
||||
if m.Options.SigningMethod != nil && m.Options.SigningMethod.Alg() != parsedToken.Header["alg"] {
|
||||
errorMsg := fmt.Sprintf("Expected %s signing method but token specified %s",
|
||||
m.Options.SigningMethod.Alg(),
|
||||
parsedToken.Header["alg"])
|
||||
m.logf("error validating token algorithm: %s", errorMsg)
|
||||
return nil, fmt.Errorf("error validating token algorithm: %s", errorMsg)
|
||||
}
|
||||
|
||||
// Check if the parsed token is valid...
|
||||
if !parsedToken.Valid {
|
||||
errorMsg := "token is invalid"
|
||||
m.logf(errorMsg)
|
||||
return nil, errors.New(errorMsg)
|
||||
}
|
||||
|
||||
return parsedToken, nil
|
||||
}
|
||||
@@ -243,7 +243,7 @@ func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *htt
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, "")
|
||||
util.WriteJSONObject(w, emptyObject{})
|
||||
}
|
||||
|
||||
// GetNameserverGroup handles a nameserver group Get request identified by ID
|
||||
|
||||
178
management/server/http/pat_handler.go
Normal file
178
management/server/http/pat_handler.go
Normal file
@@ -0,0 +1,178 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
)
|
||||
|
||||
// PATHandler is the nameserver group handler of the account
|
||||
type PATHandler struct {
|
||||
accountManager server.AccountManager
|
||||
claimsExtractor *jwtclaims.ClaimsExtractor
|
||||
}
|
||||
|
||||
// NewPATsHandler creates a new PATHandler HTTP handler
|
||||
func NewPATsHandler(accountManager server.AccountManager, authCfg AuthCfg) *PATHandler {
|
||||
return &PATHandler{
|
||||
accountManager: accountManager,
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithAudience(authCfg.Audience),
|
||||
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
// GetAllTokens is HTTP GET handler that returns a list of all personal access tokens for the given user
|
||||
func (h *PATHandler) GetAllTokens(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
userID := vars["userId"]
|
||||
if len(userID) == 0 {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
pats, err := h.accountManager.GetAllPATs(account.Id, user.Id, userID)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
|
||||
var patResponse []*api.PersonalAccessToken
|
||||
for _, pat := range pats {
|
||||
patResponse = append(patResponse, toPATResponse(pat))
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, patResponse)
|
||||
}
|
||||
|
||||
// GetToken is HTTP GET handler that returns a personal access token for the given user
|
||||
func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
targetUserID := vars["userId"]
|
||||
if len(targetUserID) == 0 {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
tokenID := vars["tokenId"]
|
||||
if len(tokenID) == 0 {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid token ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
pat, err := h.accountManager.GetPAT(account.Id, user.Id, targetUserID, tokenID)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, toPATResponse(pat))
|
||||
}
|
||||
|
||||
// CreateToken is HTTP POST handler that creates a personal access token for the given user
|
||||
func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
targetUserID := vars["userId"]
|
||||
if len(targetUserID) == 0 {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.PostApiUsersUserIdTokensJSONRequestBody
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
pat, err := h.accountManager.CreatePAT(account.Id, user.Id, targetUserID, req.Name, req.ExpiresIn)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, toPATGeneratedResponse(pat))
|
||||
}
|
||||
|
||||
// DeleteToken is HTTP DELETE handler that deletes a personal access token for the given user
|
||||
func (h *PATHandler) DeleteToken(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
targetUserID := vars["userId"]
|
||||
if len(targetUserID) == 0 {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
tokenID := vars["tokenId"]
|
||||
if len(tokenID) == 0 {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid token ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
err = h.accountManager.DeletePAT(account.Id, user.Id, targetUserID, tokenID)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, emptyObject{})
|
||||
}
|
||||
|
||||
func toPATResponse(pat *server.PersonalAccessToken) *api.PersonalAccessToken {
|
||||
var lastUsed *time.Time
|
||||
if !pat.LastUsed.IsZero() {
|
||||
lastUsed = &pat.LastUsed
|
||||
}
|
||||
return &api.PersonalAccessToken{
|
||||
CreatedAt: pat.CreatedAt,
|
||||
CreatedBy: pat.CreatedBy,
|
||||
Name: pat.Name,
|
||||
ExpirationDate: pat.ExpirationDate,
|
||||
Id: pat.ID,
|
||||
LastUsed: lastUsed,
|
||||
}
|
||||
}
|
||||
|
||||
func toPATGeneratedResponse(pat *server.PersonalAccessTokenGenerated) *api.PersonalAccessTokenGenerated {
|
||||
return &api.PersonalAccessTokenGenerated{
|
||||
PlainToken: pat.PlainToken,
|
||||
PersonalAccessToken: *toPATResponse(&pat.PersonalAccessToken),
|
||||
}
|
||||
}
|
||||
254
management/server/http/pat_handler_test.go
Normal file
254
management/server/http/pat_handler_test.go
Normal file
@@ -0,0 +1,254 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
)
|
||||
|
||||
const (
|
||||
existingAccountID = "existingAccountID"
|
||||
notFoundAccountID = "notFoundAccountID"
|
||||
existingUserID = "existingUserID"
|
||||
notFoundUserID = "notFoundUserID"
|
||||
existingTokenID = "existingTokenID"
|
||||
notFoundTokenID = "notFoundTokenID"
|
||||
domain = "hotmail.com"
|
||||
)
|
||||
|
||||
var testAccount = &server.Account{
|
||||
Id: existingAccountID,
|
||||
Domain: domain,
|
||||
Users: map[string]*server.User{
|
||||
existingUserID: {
|
||||
Id: existingUserID,
|
||||
PATs: map[string]*server.PersonalAccessToken{
|
||||
existingTokenID: {
|
||||
ID: existingTokenID,
|
||||
Name: "My first token",
|
||||
HashedToken: "someHash",
|
||||
ExpirationDate: time.Now().UTC().AddDate(0, 0, 7),
|
||||
CreatedBy: existingUserID,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
LastUsed: time.Now().UTC(),
|
||||
},
|
||||
"token2": {
|
||||
ID: "token2",
|
||||
Name: "My second token",
|
||||
HashedToken: "someOtherHash",
|
||||
ExpirationDate: time.Now().UTC().AddDate(0, 0, 7),
|
||||
CreatedBy: existingUserID,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
LastUsed: time.Now().UTC(),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
func initPATTestData() *PATHandler {
|
||||
return &PATHandler{
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
CreatePATFunc: func(accountID string, executingUserID string, targetUserID string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) {
|
||||
if accountID != existingAccountID {
|
||||
return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID)
|
||||
}
|
||||
if targetUserID != existingUserID {
|
||||
return nil, status.Errorf(status.NotFound, "user with ID %s not found", targetUserID)
|
||||
}
|
||||
return &server.PersonalAccessTokenGenerated{
|
||||
PlainToken: "nbp_z1pvsg2wP3EzmEou4S679KyTNhov632eyrXe",
|
||||
PersonalAccessToken: server.PersonalAccessToken{},
|
||||
}, nil
|
||||
},
|
||||
|
||||
GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
return testAccount, testAccount.Users[existingUserID], nil
|
||||
},
|
||||
DeletePATFunc: func(accountID string, executingUserID string, targetUserID string, tokenID string) error {
|
||||
if accountID != existingAccountID {
|
||||
return status.Errorf(status.NotFound, "account with ID %s not found", accountID)
|
||||
}
|
||||
if targetUserID != existingUserID {
|
||||
return status.Errorf(status.NotFound, "user with ID %s not found", targetUserID)
|
||||
}
|
||||
if tokenID != existingTokenID {
|
||||
return status.Errorf(status.NotFound, "token with ID %s not found", tokenID)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
GetPATFunc: func(accountID string, executingUserID string, targetUserID string, tokenID string) (*server.PersonalAccessToken, error) {
|
||||
if accountID != existingAccountID {
|
||||
return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID)
|
||||
}
|
||||
if targetUserID != existingUserID {
|
||||
return nil, status.Errorf(status.NotFound, "user with ID %s not found", targetUserID)
|
||||
}
|
||||
if tokenID != existingTokenID {
|
||||
return nil, status.Errorf(status.NotFound, "token with ID %s not found", tokenID)
|
||||
}
|
||||
return testAccount.Users[existingUserID].PATs[existingTokenID], nil
|
||||
},
|
||||
GetAllPATsFunc: func(accountID string, executingUserID string, targetUserID string) ([]*server.PersonalAccessToken, error) {
|
||||
if accountID != existingAccountID {
|
||||
return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID)
|
||||
}
|
||||
if targetUserID != existingUserID {
|
||||
return nil, status.Errorf(status.NotFound, "user with ID %s not found", targetUserID)
|
||||
}
|
||||
return []*server.PersonalAccessToken{testAccount.Users[existingUserID].PATs[existingTokenID], testAccount.Users[existingUserID].PATs["token2"]}, nil
|
||||
},
|
||||
},
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
|
||||
return jwtclaims.AuthorizationClaims{
|
||||
UserId: existingUserID,
|
||||
Domain: domain,
|
||||
AccountId: testNSGroupAccountID,
|
||||
}
|
||||
}),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenHandlers(t *testing.T) {
|
||||
tt := []struct {
|
||||
name string
|
||||
expectedStatus int
|
||||
expectedBody bool
|
||||
requestType string
|
||||
requestPath string
|
||||
requestBody io.Reader
|
||||
}{
|
||||
{
|
||||
name: "Get All Tokens",
|
||||
requestType: http.MethodGet,
|
||||
requestPath: "/api/users/" + existingUserID + "/tokens",
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: true,
|
||||
},
|
||||
{
|
||||
name: "Get Existing Token",
|
||||
requestType: http.MethodGet,
|
||||
requestPath: "/api/users/" + existingUserID + "/tokens/" + existingTokenID,
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: true,
|
||||
},
|
||||
{
|
||||
name: "Get Not Existing Token",
|
||||
requestType: http.MethodGet,
|
||||
requestPath: "/api/users/" + existingUserID + "/tokens/" + notFoundTokenID,
|
||||
expectedStatus: http.StatusNotFound,
|
||||
},
|
||||
{
|
||||
name: "Delete Existing Token",
|
||||
requestType: http.MethodDelete,
|
||||
requestPath: "/api/users/" + existingUserID + "/tokens/" + existingTokenID,
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "Delete Not Existing Token",
|
||||
requestType: http.MethodDelete,
|
||||
requestPath: "/api/users/" + existingUserID + "/tokens/" + notFoundTokenID,
|
||||
expectedStatus: http.StatusNotFound,
|
||||
},
|
||||
{
|
||||
name: "POST OK",
|
||||
requestType: http.MethodPost,
|
||||
requestPath: "/api/users/" + existingUserID + "/tokens",
|
||||
requestBody: bytes.NewBuffer(
|
||||
[]byte("{\"name\":\"name\",\"expires_in\":7}")),
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: true,
|
||||
},
|
||||
}
|
||||
|
||||
p := initPATTestData()
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/users/{userId}/tokens", p.GetAllTokens).Methods("GET")
|
||||
router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", p.GetToken).Methods("GET")
|
||||
router.HandleFunc("/api/users/{userId}/tokens", p.CreateToken).Methods("POST")
|
||||
router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", p.DeleteToken).Methods("DELETE")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
defer res.Body.Close()
|
||||
|
||||
content, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("I don't know what I expected; %v", err)
|
||||
}
|
||||
|
||||
if status := recorder.Code; status != tc.expectedStatus {
|
||||
t.Errorf("handler returned wrong status code: got %v want %v, content: %s",
|
||||
status, tc.expectedStatus, string(content))
|
||||
return
|
||||
}
|
||||
|
||||
if !tc.expectedBody {
|
||||
return
|
||||
}
|
||||
|
||||
switch tc.name {
|
||||
case "POST OK":
|
||||
got := &api.PersonalAccessTokenGenerated{}
|
||||
if err = json.Unmarshal(content, &got); err != nil {
|
||||
t.Fatalf("Sent content is not in correct json format; %v", err)
|
||||
}
|
||||
assert.NotEmpty(t, got.PlainToken)
|
||||
assert.Equal(t, server.PATLength, len(got.PlainToken))
|
||||
case "Get All Tokens":
|
||||
expectedTokens := []api.PersonalAccessToken{
|
||||
toTokenResponse(*testAccount.Users[existingUserID].PATs[existingTokenID]),
|
||||
toTokenResponse(*testAccount.Users[existingUserID].PATs["token2"]),
|
||||
}
|
||||
|
||||
var got []api.PersonalAccessToken
|
||||
if err = json.Unmarshal(content, &got); err != nil {
|
||||
t.Fatalf("Sent content is not in correct json format; %v", err)
|
||||
}
|
||||
assert.True(t, cmp.Equal(got, expectedTokens))
|
||||
case "Get Existing Token":
|
||||
expectedToken := toTokenResponse(*testAccount.Users[existingUserID].PATs[existingTokenID])
|
||||
got := &api.PersonalAccessToken{}
|
||||
if err = json.Unmarshal(content, &got); err != nil {
|
||||
t.Fatalf("Sent content is not in correct json format; %v", err)
|
||||
}
|
||||
|
||||
assert.True(t, cmp.Equal(*got, expectedToken))
|
||||
}
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func toTokenResponse(serverToken server.PersonalAccessToken) api.PersonalAccessToken {
|
||||
return api.PersonalAccessToken{
|
||||
Id: serverToken.ID,
|
||||
Name: serverToken.Name,
|
||||
CreatedAt: serverToken.CreatedAt,
|
||||
LastUsed: &serverToken.LastUsed,
|
||||
CreatedBy: serverToken.CreatedBy,
|
||||
ExpirationDate: serverToken.ExpirationDate,
|
||||
}
|
||||
}
|
||||
@@ -66,7 +66,7 @@ func (h *PeersHandler) deletePeer(accountID, userID string, peerID string, w htt
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
util.WriteJSONObject(w, "")
|
||||
util.WriteJSONObject(w, emptyObject{})
|
||||
}
|
||||
|
||||
// HandlePeer handles all peer requests for GET, PUT and DELETE operations
|
||||
|
||||
@@ -225,7 +225,7 @@ func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, "")
|
||||
util.WriteJSONObject(w, emptyObject{})
|
||||
}
|
||||
|
||||
// GetPolicy handles a group Get request identified by ID
|
||||
|
||||
@@ -321,7 +321,7 @@ func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, "")
|
||||
util.WriteJSONObject(w, emptyObject{})
|
||||
}
|
||||
|
||||
// GetRoute handles a route Get request identified by ID
|
||||
|
||||
@@ -222,7 +222,7 @@ func (h *RulesHandler) DeleteRule(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, "")
|
||||
util.WriteJSONObject(w, emptyObject{})
|
||||
}
|
||||
|
||||
// GetRule handles a group Get request identified by ID
|
||||
|
||||
@@ -4,10 +4,13 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
)
|
||||
|
||||
// WriteJSONObject simply writes object to the HTTP reponse in JSON format
|
||||
@@ -93,9 +96,11 @@ func WriteError(err error, w http.ResponseWriter) {
|
||||
httpStatus = http.StatusInternalServerError
|
||||
case status.InvalidArgument:
|
||||
httpStatus = http.StatusUnprocessableEntity
|
||||
case status.Unauthorized:
|
||||
httpStatus = http.StatusUnauthorized
|
||||
default:
|
||||
}
|
||||
msg = err.Error()
|
||||
msg = strings.ToLower(err.Error())
|
||||
} else {
|
||||
unhandledMSG := fmt.Sprintf("got unhandled error code, error: %s", err.Error())
|
||||
log.Error(unhandledMSG)
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
@@ -15,6 +14,8 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
@@ -3,14 +3,16 @@ package idp
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/stretchr/testify/require"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
@@ -251,7 +253,7 @@ func TestAuth0_Authenticate(t *testing.T) {
|
||||
name: "Get Cached token",
|
||||
inputExpireToken: time.Now().Add(30 * time.Second),
|
||||
helper: JsonParser{},
|
||||
//expectedFuncExitErrDiff: fmt.Errorf("unable to get token, statusCode 400"),
|
||||
// expectedFuncExitErrDiff: fmt.Errorf("unable to get token, statusCode 400"),
|
||||
expectedCode: 200,
|
||||
expectedToken: "",
|
||||
}
|
||||
|
||||
@@ -2,10 +2,11 @@ package idp
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
|
||||
// Manager idp manager interface
|
||||
@@ -20,8 +21,9 @@ type Manager interface {
|
||||
|
||||
// Config an idp configuration struct to be loaded from management server's config file
|
||||
type Config struct {
|
||||
ManagerType string
|
||||
Auth0ClientCredentials Auth0ClientConfig
|
||||
ManagerType string
|
||||
Auth0ClientCredentials Auth0ClientConfig
|
||||
KeycloakClientCredentials KeycloakClientConfig
|
||||
}
|
||||
|
||||
// ManagerCredentials interface that authenticates using the credential of each type of idp
|
||||
@@ -71,6 +73,8 @@ func NewManager(config Config, appMetrics telemetry.AppMetrics) (Manager, error)
|
||||
return nil, nil
|
||||
case "auth0":
|
||||
return NewAuth0Manager(config.Auth0ClientCredentials, appMetrics)
|
||||
case "keycloak":
|
||||
return NewKeycloakManager(config.KeycloakClientCredentials, appMetrics)
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid manager type: %s", config.ManagerType)
|
||||
}
|
||||
|
||||
582
management/server/idp/keycloak.go
Normal file
582
management/server/idp/keycloak.go
Normal file
@@ -0,0 +1,582 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
|
||||
const (
|
||||
wtAccountID = "wt_account_id"
|
||||
wtPendingInvite = "wt_pending_invite"
|
||||
)
|
||||
|
||||
// KeycloakManager keycloak manager client instance.
|
||||
type KeycloakManager struct {
|
||||
adminEndpoint string
|
||||
httpClient ManagerHTTPClient
|
||||
credentials ManagerCredentials
|
||||
helper ManagerHelper
|
||||
appMetrics telemetry.AppMetrics
|
||||
}
|
||||
|
||||
// KeycloakClientConfig keycloak manager client configurations.
|
||||
type KeycloakClientConfig struct {
|
||||
ClientID string
|
||||
ClientSecret string
|
||||
AdminEndpoint string
|
||||
TokenEndpoint string
|
||||
GrantType string
|
||||
}
|
||||
|
||||
// KeycloakCredentials keycloak authentication information.
|
||||
type KeycloakCredentials struct {
|
||||
clientConfig KeycloakClientConfig
|
||||
helper ManagerHelper
|
||||
httpClient ManagerHTTPClient
|
||||
jwtToken JWTToken
|
||||
mux sync.Mutex
|
||||
appMetrics telemetry.AppMetrics
|
||||
}
|
||||
|
||||
// keycloakUserCredential describe the authentication method for,
|
||||
// newly created user profile.
|
||||
type keycloakUserCredential struct {
|
||||
Type string `json:"type"`
|
||||
Value string `json:"value"`
|
||||
Temporary bool `json:"temporary"`
|
||||
}
|
||||
|
||||
// keycloakUserAttributes holds additional user data fields.
|
||||
type keycloakUserAttributes map[string][]string
|
||||
|
||||
// createUserRequest is a user create request.
|
||||
type keycloakCreateUserRequest struct {
|
||||
Email string `json:"email"`
|
||||
Username string `json:"username"`
|
||||
Enabled bool `json:"enabled"`
|
||||
EmailVerified bool `json:"emailVerified"`
|
||||
Credentials []keycloakUserCredential `json:"credentials"`
|
||||
Attributes keycloakUserAttributes `json:"attributes"`
|
||||
}
|
||||
|
||||
// keycloakProfile represents an keycloak user profile response.
|
||||
type keycloakProfile struct {
|
||||
ID string `json:"id"`
|
||||
CreatedTimestamp int64 `json:"createdTimestamp"`
|
||||
Username string `json:"username"`
|
||||
Email string `json:"email"`
|
||||
Attributes keycloakUserAttributes `json:"attributes"`
|
||||
}
|
||||
|
||||
// NewKeycloakManager creates a new instance of the KeycloakManager.
|
||||
func NewKeycloakManager(config KeycloakClientConfig, appMetrics telemetry.AppMetrics) (*KeycloakManager, error) {
|
||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
httpTransport.MaxIdleConns = 5
|
||||
|
||||
httpClient := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
Transport: httpTransport,
|
||||
}
|
||||
|
||||
helper := JsonParser{}
|
||||
|
||||
if config.ClientID == "" || config.ClientSecret == "" || config.GrantType == "" || config.AdminEndpoint == "" || config.TokenEndpoint == "" {
|
||||
return nil, fmt.Errorf("keycloak idp configuration is not complete")
|
||||
}
|
||||
|
||||
if config.GrantType != "client_credentials" {
|
||||
return nil, fmt.Errorf("keycloak idp configuration failed. Grant Type should be client_credentials")
|
||||
}
|
||||
|
||||
credentials := &KeycloakCredentials{
|
||||
clientConfig: config,
|
||||
httpClient: httpClient,
|
||||
helper: helper,
|
||||
appMetrics: appMetrics,
|
||||
}
|
||||
|
||||
return &KeycloakManager{
|
||||
adminEndpoint: config.AdminEndpoint,
|
||||
httpClient: httpClient,
|
||||
credentials: credentials,
|
||||
helper: helper,
|
||||
appMetrics: appMetrics,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// jwtStillValid returns true if the token still valid and have enough time to be used and get a response from keycloak.
|
||||
func (kc *KeycloakCredentials) jwtStillValid() bool {
|
||||
return !kc.jwtToken.expiresInTime.IsZero() && time.Now().Add(5*time.Second).Before(kc.jwtToken.expiresInTime)
|
||||
}
|
||||
|
||||
// requestJWTToken performs request to get jwt token.
|
||||
func (kc *KeycloakCredentials) requestJWTToken() (*http.Response, error) {
|
||||
data := url.Values{}
|
||||
data.Set("client_id", kc.clientConfig.ClientID)
|
||||
data.Set("client_secret", kc.clientConfig.ClientSecret)
|
||||
data.Set("grant_type", kc.clientConfig.GrantType)
|
||||
|
||||
payload := strings.NewReader(data.Encode())
|
||||
req, err := http.NewRequest(http.MethodPost, kc.clientConfig.TokenEndpoint, payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Add("content-type", "application/x-www-form-urlencoded")
|
||||
|
||||
log.Debug("requesting new jwt token for keycloak idp manager")
|
||||
|
||||
resp, err := kc.httpClient.Do(req)
|
||||
if err != nil {
|
||||
if kc.appMetrics != nil {
|
||||
kc.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("unable to get keycloak token, statusCode %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// parseRequestJWTResponse parses jwt raw response body and extracts token and expires in seconds
|
||||
func (kc *KeycloakCredentials) parseRequestJWTResponse(rawBody io.ReadCloser) (JWTToken, error) {
|
||||
jwtToken := JWTToken{}
|
||||
body, err := io.ReadAll(rawBody)
|
||||
if err != nil {
|
||||
return jwtToken, err
|
||||
}
|
||||
|
||||
err = kc.helper.Unmarshal(body, &jwtToken)
|
||||
if err != nil {
|
||||
return jwtToken, err
|
||||
}
|
||||
|
||||
if jwtToken.ExpiresIn == 0 && jwtToken.AccessToken == "" {
|
||||
return jwtToken, fmt.Errorf("error while reading response body, expires_in: %d and access_token: %s", jwtToken.ExpiresIn, jwtToken.AccessToken)
|
||||
}
|
||||
|
||||
data, err := jwt.DecodeSegment(strings.Split(jwtToken.AccessToken, ".")[1])
|
||||
if err != nil {
|
||||
return jwtToken, err
|
||||
}
|
||||
|
||||
// Exp maps into exp from jwt token
|
||||
var IssuedAt struct{ Exp int64 }
|
||||
err = kc.helper.Unmarshal(data, &IssuedAt)
|
||||
if err != nil {
|
||||
return jwtToken, err
|
||||
}
|
||||
jwtToken.expiresInTime = time.Unix(IssuedAt.Exp, 0)
|
||||
|
||||
return jwtToken, nil
|
||||
}
|
||||
|
||||
// Authenticate retrieves access token to use the keycloak Management API.
|
||||
func (kc *KeycloakCredentials) Authenticate() (JWTToken, error) {
|
||||
kc.mux.Lock()
|
||||
defer kc.mux.Unlock()
|
||||
|
||||
if kc.appMetrics != nil {
|
||||
kc.appMetrics.IDPMetrics().CountAuthenticate()
|
||||
}
|
||||
|
||||
// reuse the token without requesting a new one if it is not expired,
|
||||
// and if expiry time is sufficient time available to make a request.
|
||||
if kc.jwtStillValid() {
|
||||
return kc.jwtToken, nil
|
||||
}
|
||||
|
||||
resp, err := kc.requestJWTToken()
|
||||
if err != nil {
|
||||
return kc.jwtToken, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
jwtToken, err := kc.parseRequestJWTResponse(resp.Body)
|
||||
if err != nil {
|
||||
return kc.jwtToken, err
|
||||
}
|
||||
|
||||
kc.jwtToken = jwtToken
|
||||
|
||||
return kc.jwtToken, nil
|
||||
}
|
||||
|
||||
// CreateUser creates a new user in keycloak Idp and sends an invite.
|
||||
func (km *KeycloakManager) CreateUser(email string, name string, accountID string) (*UserData, error) {
|
||||
jwtToken, err := km.credentials.Authenticate()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
invite := true
|
||||
appMetadata := AppMetadata{
|
||||
WTAccountID: accountID,
|
||||
WTPendingInvite: &invite,
|
||||
}
|
||||
|
||||
payloadString, err := buildKeycloakCreateUserRequestPayload(email, name, appMetadata)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
reqURL := fmt.Sprintf("%s/users", km.adminEndpoint)
|
||||
payload := strings.NewReader(payloadString)
|
||||
|
||||
req, err := http.NewRequest(http.MethodPost, reqURL, payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
|
||||
req.Header.Add("content-type", "application/json")
|
||||
|
||||
if km.appMetrics != nil {
|
||||
km.appMetrics.IDPMetrics().CountCreateUser()
|
||||
}
|
||||
|
||||
resp, err := km.httpClient.Do(req)
|
||||
if err != nil {
|
||||
if km.appMetrics != nil {
|
||||
km.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusCreated {
|
||||
if km.appMetrics != nil {
|
||||
km.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unable to create user, statusCode %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
locationHeader := resp.Header.Get("location")
|
||||
userID, err := extractUserIDFromLocationHeader(locationHeader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return km.GetUserDataByID(userID, appMetadata)
|
||||
}
|
||||
|
||||
// GetUserByEmail searches users with a given email.
|
||||
// If no users have been found, this function returns an empty list.
|
||||
func (km *KeycloakManager) GetUserByEmail(email string) ([]*UserData, error) {
|
||||
q := url.Values{}
|
||||
q.Add("email", email)
|
||||
q.Add("exact", "true")
|
||||
|
||||
body, err := km.get("users", q)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if km.appMetrics != nil {
|
||||
km.appMetrics.IDPMetrics().CountGetUserByEmail()
|
||||
}
|
||||
|
||||
profiles := make([]keycloakProfile, 0)
|
||||
err = km.helper.Unmarshal(body, &profiles)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
users := make([]*UserData, 0)
|
||||
for _, profile := range profiles {
|
||||
users = append(users, profile.userData())
|
||||
}
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// GetUserDataByID requests user data from keycloak via ID.
|
||||
func (km *KeycloakManager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) {
|
||||
body, err := km.get("users/"+userID, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if km.appMetrics != nil {
|
||||
km.appMetrics.IDPMetrics().CountGetUserDataByID()
|
||||
}
|
||||
|
||||
var profile keycloakProfile
|
||||
err = km.helper.Unmarshal(body, &profile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return profile.userData(), nil
|
||||
}
|
||||
|
||||
// GetAccount returns all the users for a given profile.
|
||||
func (km *KeycloakManager) GetAccount(accountID string) ([]*UserData, error) {
|
||||
q := url.Values{}
|
||||
q.Add("q", wtAccountID+":"+accountID)
|
||||
|
||||
body, err := km.get("users", q)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if km.appMetrics != nil {
|
||||
km.appMetrics.IDPMetrics().CountGetAccount()
|
||||
}
|
||||
|
||||
profiles := make([]keycloakProfile, 0)
|
||||
err = km.helper.Unmarshal(body, &profiles)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
users := make([]*UserData, 0)
|
||||
for _, profile := range profiles {
|
||||
users = append(users, profile.userData())
|
||||
}
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// GetAllAccounts gets all registered accounts with corresponding user data.
|
||||
// It returns a list of users indexed by accountID.
|
||||
func (km *KeycloakManager) GetAllAccounts() (map[string][]*UserData, error) {
|
||||
totalUsers, err := km.totalUsersCount()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
q := url.Values{}
|
||||
q.Add("max", fmt.Sprint(*totalUsers))
|
||||
|
||||
body, err := km.get("users", q)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if km.appMetrics != nil {
|
||||
km.appMetrics.IDPMetrics().CountGetAllAccounts()
|
||||
}
|
||||
|
||||
profiles := make([]keycloakProfile, 0)
|
||||
err = km.helper.Unmarshal(body, &profiles)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
indexedUsers := make(map[string][]*UserData)
|
||||
for _, profile := range profiles {
|
||||
userData := profile.userData()
|
||||
|
||||
accountID := userData.AppMetadata.WTAccountID
|
||||
if accountID != "" {
|
||||
if _, ok := indexedUsers[accountID]; !ok {
|
||||
indexedUsers[accountID] = make([]*UserData, 0)
|
||||
}
|
||||
indexedUsers[accountID] = append(indexedUsers[accountID], userData)
|
||||
}
|
||||
}
|
||||
|
||||
return indexedUsers, nil
|
||||
}
|
||||
|
||||
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
|
||||
func (km *KeycloakManager) UpdateUserAppMetadata(userID string, appMetadata AppMetadata) error {
|
||||
jwtToken, err := km.credentials.Authenticate()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
attrs := keycloakUserAttributes{}
|
||||
attrs.Set(wtAccountID, appMetadata.WTAccountID)
|
||||
if appMetadata.WTPendingInvite != nil {
|
||||
attrs.Set(wtPendingInvite, strconv.FormatBool(*appMetadata.WTPendingInvite))
|
||||
} else {
|
||||
attrs.Set(wtPendingInvite, "false")
|
||||
}
|
||||
|
||||
reqURL := fmt.Sprintf("%s/users/%s", km.adminEndpoint, userID)
|
||||
data, err := km.helper.Marshal(map[string]any{
|
||||
"attributes": attrs,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
payload := strings.NewReader(string(data))
|
||||
|
||||
req, err := http.NewRequest(http.MethodPut, reqURL, payload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
|
||||
req.Header.Add("content-type", "application/json")
|
||||
|
||||
log.Debugf("updating IdP metadata for user %s", userID)
|
||||
|
||||
resp, err := km.httpClient.Do(req)
|
||||
if err != nil {
|
||||
if km.appMetrics != nil {
|
||||
km.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if km.appMetrics != nil {
|
||||
km.appMetrics.IDPMetrics().CountUpdateUserAppMetadata()
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusNoContent {
|
||||
return fmt.Errorf("unable to update the appMetadata, statusCode %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func buildKeycloakCreateUserRequestPayload(email string, name string, appMetadata AppMetadata) (string, error) {
|
||||
attrs := keycloakUserAttributes{}
|
||||
attrs.Set(wtAccountID, appMetadata.WTAccountID)
|
||||
attrs.Set(wtPendingInvite, strconv.FormatBool(*appMetadata.WTPendingInvite))
|
||||
|
||||
req := &keycloakCreateUserRequest{
|
||||
Email: email,
|
||||
Username: name,
|
||||
Enabled: true,
|
||||
EmailVerified: true,
|
||||
Credentials: []keycloakUserCredential{
|
||||
{
|
||||
Type: "password",
|
||||
Value: GeneratePassword(8, 1, 1, 1),
|
||||
Temporary: false,
|
||||
},
|
||||
},
|
||||
Attributes: attrs,
|
||||
}
|
||||
|
||||
str, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return string(str), nil
|
||||
}
|
||||
|
||||
// get perform Get requests.
|
||||
func (km *KeycloakManager) get(resource string, q url.Values) ([]byte, error) {
|
||||
jwtToken, err := km.credentials.Authenticate()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
reqURL := fmt.Sprintf("%s/%s?%s", km.adminEndpoint, resource, q.Encode())
|
||||
req, err := http.NewRequest(http.MethodGet, reqURL, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
|
||||
req.Header.Add("content-type", "application/json")
|
||||
|
||||
resp, err := km.httpClient.Do(req)
|
||||
if err != nil {
|
||||
if km.appMetrics != nil {
|
||||
km.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if km.appMetrics != nil {
|
||||
km.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unable to get %s, statusCode %d", reqURL, resp.StatusCode)
|
||||
}
|
||||
|
||||
return io.ReadAll(resp.Body)
|
||||
}
|
||||
|
||||
// totalUsersCount returns the total count of all user created.
|
||||
// Used when fetching all registered accounts with pagination.
|
||||
func (km *KeycloakManager) totalUsersCount() (*int, error) {
|
||||
body, err := km.get("users/count", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
count, err := strconv.Atoi(string(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &count, nil
|
||||
}
|
||||
|
||||
// extractUserIDFromLocationHeader extracts the user ID from the location,
|
||||
// header once the user is created successfully
|
||||
func extractUserIDFromLocationHeader(locationHeader string) (string, error) {
|
||||
userURL, err := url.Parse(locationHeader)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return path.Base(userURL.Path), nil
|
||||
}
|
||||
|
||||
// userData construct user data from keycloak profile.
|
||||
func (kp keycloakProfile) userData() *UserData {
|
||||
accountID := kp.Attributes.Get(wtAccountID)
|
||||
pendingInvite, err := strconv.ParseBool(kp.Attributes.Get(wtPendingInvite))
|
||||
if err != nil {
|
||||
pendingInvite = false
|
||||
}
|
||||
|
||||
return &UserData{
|
||||
Email: kp.Email,
|
||||
Name: kp.Username,
|
||||
ID: kp.ID,
|
||||
AppMetadata: AppMetadata{
|
||||
WTAccountID: accountID,
|
||||
WTPendingInvite: &pendingInvite,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Set sets the key to value. It replaces any existing
|
||||
// values.
|
||||
func (ka keycloakUserAttributes) Set(key, value string) {
|
||||
ka[key] = []string{value}
|
||||
}
|
||||
|
||||
// Get returns the first value associated with the given key.
|
||||
// If there are no values associated with the key, Get returns
|
||||
// the empty string.
|
||||
func (ka keycloakUserAttributes) Get(key string) string {
|
||||
if ka == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
values := ka[key]
|
||||
if len(values) == 0 {
|
||||
return ""
|
||||
}
|
||||
return values[0]
|
||||
}
|
||||
402
management/server/idp/keycloak_test.go
Normal file
402
management/server/idp/keycloak_test.go
Normal file
@@ -0,0 +1,402 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
|
||||
func TestNewKeycloakManager(t *testing.T) {
|
||||
type test struct {
|
||||
name string
|
||||
inputConfig KeycloakClientConfig
|
||||
assertErrFunc require.ErrorAssertionFunc
|
||||
assertErrFuncMessage string
|
||||
}
|
||||
|
||||
defaultTestConfig := KeycloakClientConfig{
|
||||
ClientID: "client_id",
|
||||
ClientSecret: "client_secret",
|
||||
AdminEndpoint: "https://localhost:8080/auth/admin/realms/test123",
|
||||
TokenEndpoint: "https://localhost:8080/auth/realms/test123/protocol/openid-connect/token",
|
||||
GrantType: "client_credentials",
|
||||
}
|
||||
|
||||
testCase1 := test{
|
||||
name: "Good Configuration",
|
||||
inputConfig: defaultTestConfig,
|
||||
assertErrFunc: require.NoError,
|
||||
assertErrFuncMessage: "shouldn't return error",
|
||||
}
|
||||
|
||||
testCase2Config := defaultTestConfig
|
||||
testCase2Config.ClientID = ""
|
||||
|
||||
testCase2 := test{
|
||||
name: "Missing ClientID Configuration",
|
||||
inputConfig: testCase2Config,
|
||||
assertErrFunc: require.Error,
|
||||
assertErrFuncMessage: "should return error when field empty",
|
||||
}
|
||||
|
||||
testCase5Config := defaultTestConfig
|
||||
testCase5Config.GrantType = "authorization_code"
|
||||
|
||||
testCase5 := test{
|
||||
name: "Wrong GrantType",
|
||||
inputConfig: testCase5Config,
|
||||
assertErrFunc: require.Error,
|
||||
assertErrFuncMessage: "should return error when wrong grant type",
|
||||
}
|
||||
|
||||
for _, testCase := range []test{testCase1, testCase2, testCase5} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
_, err := NewKeycloakManager(testCase.inputConfig, &telemetry.MockAppMetrics{})
|
||||
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type mockKeycloakCredentials struct {
|
||||
jwtToken JWTToken
|
||||
err error
|
||||
}
|
||||
|
||||
func (mc *mockKeycloakCredentials) Authenticate() (JWTToken, error) {
|
||||
return mc.jwtToken, mc.err
|
||||
}
|
||||
|
||||
func TestKeycloakRequestJWTToken(t *testing.T) {
|
||||
|
||||
type requestJWTTokenTest struct {
|
||||
name string
|
||||
inputCode int
|
||||
inputRespBody string
|
||||
helper ManagerHelper
|
||||
expectedFuncExitErrDiff error
|
||||
expectedToken string
|
||||
}
|
||||
exp := 5
|
||||
token := newTestJWT(t, exp)
|
||||
|
||||
requestJWTTokenTesttCase1 := requestJWTTokenTest{
|
||||
name: "Good JWT Response",
|
||||
inputCode: 200,
|
||||
inputRespBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
|
||||
helper: JsonParser{},
|
||||
expectedToken: token,
|
||||
}
|
||||
requestJWTTokenTestCase2 := requestJWTTokenTest{
|
||||
name: "Request Bad Status Code",
|
||||
inputCode: 400,
|
||||
inputRespBody: "{}",
|
||||
helper: JsonParser{},
|
||||
expectedFuncExitErrDiff: fmt.Errorf("unable to get keycloak token, statusCode 400"),
|
||||
expectedToken: "",
|
||||
}
|
||||
|
||||
for _, testCase := range []requestJWTTokenTest{requestJWTTokenTesttCase1, requestJWTTokenTestCase2} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
|
||||
jwtReqClient := mockHTTPClient{
|
||||
resBody: testCase.inputRespBody,
|
||||
code: testCase.inputCode,
|
||||
}
|
||||
config := KeycloakClientConfig{}
|
||||
|
||||
creds := KeycloakCredentials{
|
||||
clientConfig: config,
|
||||
httpClient: &jwtReqClient,
|
||||
helper: testCase.helper,
|
||||
}
|
||||
|
||||
resp, err := creds.requestJWTToken()
|
||||
if err != nil {
|
||||
if testCase.expectedFuncExitErrDiff != nil {
|
||||
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
|
||||
} else {
|
||||
t.Fatal(err)
|
||||
}
|
||||
} else {
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
assert.NoError(t, err, "unable to read the response body")
|
||||
|
||||
jwtToken := JWTToken{}
|
||||
err = testCase.helper.Unmarshal(body, &jwtToken)
|
||||
assert.NoError(t, err, "unable to parse the json input")
|
||||
|
||||
assert.Equalf(t, testCase.expectedToken, jwtToken.AccessToken, "two tokens should be the same")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestKeycloakParseRequestJWTResponse(t *testing.T) {
|
||||
type parseRequestJWTResponseTest struct {
|
||||
name string
|
||||
inputRespBody string
|
||||
helper ManagerHelper
|
||||
expectedToken string
|
||||
expectedExpiresIn int
|
||||
assertErrFunc assert.ErrorAssertionFunc
|
||||
assertErrFuncMessage string
|
||||
}
|
||||
|
||||
exp := 100
|
||||
token := newTestJWT(t, exp)
|
||||
|
||||
parseRequestJWTResponseTestCase1 := parseRequestJWTResponseTest{
|
||||
name: "Parse Good JWT Body",
|
||||
inputRespBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
|
||||
helper: JsonParser{},
|
||||
expectedToken: token,
|
||||
expectedExpiresIn: exp,
|
||||
assertErrFunc: assert.NoError,
|
||||
assertErrFuncMessage: "no error was expected",
|
||||
}
|
||||
parseRequestJWTResponseTestCase2 := parseRequestJWTResponseTest{
|
||||
name: "Parse Bad json JWT Body",
|
||||
inputRespBody: "",
|
||||
helper: JsonParser{},
|
||||
expectedToken: "",
|
||||
expectedExpiresIn: 0,
|
||||
assertErrFunc: assert.Error,
|
||||
assertErrFuncMessage: "json error was expected",
|
||||
}
|
||||
|
||||
for _, testCase := range []parseRequestJWTResponseTest{parseRequestJWTResponseTestCase1, parseRequestJWTResponseTestCase2} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
rawBody := io.NopCloser(strings.NewReader(testCase.inputRespBody))
|
||||
config := KeycloakClientConfig{}
|
||||
|
||||
creds := KeycloakCredentials{
|
||||
clientConfig: config,
|
||||
helper: testCase.helper,
|
||||
}
|
||||
jwtToken, err := creds.parseRequestJWTResponse(rawBody)
|
||||
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
|
||||
|
||||
assert.Equalf(t, testCase.expectedToken, jwtToken.AccessToken, "two tokens should be the same")
|
||||
assert.Equalf(t, testCase.expectedExpiresIn, jwtToken.ExpiresIn, "the two expire times should be the same")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestKeycloakJwtStillValid(t *testing.T) {
|
||||
type jwtStillValidTest struct {
|
||||
name string
|
||||
inputTime time.Time
|
||||
expectedResult bool
|
||||
message string
|
||||
}
|
||||
|
||||
jwtStillValidTestCase1 := jwtStillValidTest{
|
||||
name: "JWT still valid",
|
||||
inputTime: time.Now().Add(10 * time.Second),
|
||||
expectedResult: true,
|
||||
message: "should be true",
|
||||
}
|
||||
jwtStillValidTestCase2 := jwtStillValidTest{
|
||||
name: "JWT is invalid",
|
||||
inputTime: time.Now(),
|
||||
expectedResult: false,
|
||||
message: "should be false",
|
||||
}
|
||||
|
||||
for _, testCase := range []jwtStillValidTest{jwtStillValidTestCase1, jwtStillValidTestCase2} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
config := KeycloakClientConfig{}
|
||||
|
||||
creds := KeycloakCredentials{
|
||||
clientConfig: config,
|
||||
}
|
||||
creds.jwtToken.expiresInTime = testCase.inputTime
|
||||
|
||||
assert.Equalf(t, testCase.expectedResult, creds.jwtStillValid(), testCase.message)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestKeycloakAuthenticate(t *testing.T) {
|
||||
type authenticateTest struct {
|
||||
name string
|
||||
inputCode int
|
||||
inputResBody string
|
||||
inputExpireToken time.Time
|
||||
helper ManagerHelper
|
||||
expectedFuncExitErrDiff error
|
||||
expectedCode int
|
||||
expectedToken string
|
||||
}
|
||||
exp := 5
|
||||
token := newTestJWT(t, exp)
|
||||
|
||||
authenticateTestCase1 := authenticateTest{
|
||||
name: "Get Cached token",
|
||||
inputExpireToken: time.Now().Add(30 * time.Second),
|
||||
helper: JsonParser{},
|
||||
expectedFuncExitErrDiff: nil,
|
||||
expectedCode: 200,
|
||||
expectedToken: "",
|
||||
}
|
||||
|
||||
authenticateTestCase2 := authenticateTest{
|
||||
name: "Get Good JWT Response",
|
||||
inputCode: 200,
|
||||
inputResBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
|
||||
helper: JsonParser{},
|
||||
expectedCode: 200,
|
||||
expectedToken: token,
|
||||
}
|
||||
|
||||
authenticateTestCase3 := authenticateTest{
|
||||
name: "Get Bad Status Code",
|
||||
inputCode: 400,
|
||||
inputResBody: "{}",
|
||||
helper: JsonParser{},
|
||||
expectedFuncExitErrDiff: fmt.Errorf("unable to get keycloak token, statusCode 400"),
|
||||
expectedCode: 200,
|
||||
expectedToken: "",
|
||||
}
|
||||
|
||||
for _, testCase := range []authenticateTest{authenticateTestCase1, authenticateTestCase2, authenticateTestCase3} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
|
||||
jwtReqClient := mockHTTPClient{
|
||||
resBody: testCase.inputResBody,
|
||||
code: testCase.inputCode,
|
||||
}
|
||||
config := KeycloakClientConfig{}
|
||||
|
||||
creds := KeycloakCredentials{
|
||||
clientConfig: config,
|
||||
httpClient: &jwtReqClient,
|
||||
helper: testCase.helper,
|
||||
}
|
||||
creds.jwtToken.expiresInTime = testCase.inputExpireToken
|
||||
|
||||
_, err := creds.Authenticate()
|
||||
if err != nil {
|
||||
if testCase.expectedFuncExitErrDiff != nil {
|
||||
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
|
||||
} else {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
assert.Equalf(t, testCase.expectedToken, creds.jwtToken.AccessToken, "two tokens should be the same")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestKeycloakUpdateUserAppMetadata(t *testing.T) {
|
||||
type updateUserAppMetadataTest struct {
|
||||
name string
|
||||
inputReqBody string
|
||||
expectedReqBody string
|
||||
appMetadata AppMetadata
|
||||
statusCode int
|
||||
helper ManagerHelper
|
||||
managerCreds ManagerCredentials
|
||||
assertErrFunc assert.ErrorAssertionFunc
|
||||
assertErrFuncMessage string
|
||||
}
|
||||
|
||||
appMetadata := AppMetadata{WTAccountID: "ok"}
|
||||
|
||||
updateUserAppMetadataTestCase1 := updateUserAppMetadataTest{
|
||||
name: "Bad Authentication",
|
||||
expectedReqBody: "",
|
||||
appMetadata: appMetadata,
|
||||
statusCode: 400,
|
||||
helper: JsonParser{},
|
||||
managerCreds: &mockKeycloakCredentials{
|
||||
jwtToken: JWTToken{},
|
||||
err: fmt.Errorf("error"),
|
||||
},
|
||||
assertErrFunc: assert.Error,
|
||||
assertErrFuncMessage: "should return error",
|
||||
}
|
||||
|
||||
updateUserAppMetadataTestCase2 := updateUserAppMetadataTest{
|
||||
name: "Bad Status Code",
|
||||
expectedReqBody: fmt.Sprintf("{\"attributes\":{\"wt_account_id\":[\"%s\"],\"wt_pending_invite\":[\"false\"]}}", appMetadata.WTAccountID),
|
||||
appMetadata: appMetadata,
|
||||
statusCode: 400,
|
||||
helper: JsonParser{},
|
||||
managerCreds: &mockKeycloakCredentials{
|
||||
jwtToken: JWTToken{},
|
||||
},
|
||||
assertErrFunc: assert.Error,
|
||||
assertErrFuncMessage: "should return error",
|
||||
}
|
||||
|
||||
updateUserAppMetadataTestCase3 := updateUserAppMetadataTest{
|
||||
name: "Bad Response Parsing",
|
||||
statusCode: 400,
|
||||
helper: &mockJsonParser{marshalErrorString: "error"},
|
||||
managerCreds: &mockKeycloakCredentials{
|
||||
jwtToken: JWTToken{},
|
||||
},
|
||||
assertErrFunc: assert.Error,
|
||||
assertErrFuncMessage: "should return error",
|
||||
}
|
||||
|
||||
updateUserAppMetadataTestCase4 := updateUserAppMetadataTest{
|
||||
name: "Good request",
|
||||
expectedReqBody: fmt.Sprintf("{\"attributes\":{\"wt_account_id\":[\"%s\"],\"wt_pending_invite\":[\"false\"]}}", appMetadata.WTAccountID),
|
||||
appMetadata: appMetadata,
|
||||
statusCode: 204,
|
||||
helper: JsonParser{},
|
||||
managerCreds: &mockKeycloakCredentials{
|
||||
jwtToken: JWTToken{},
|
||||
},
|
||||
assertErrFunc: assert.NoError,
|
||||
assertErrFuncMessage: "shouldn't return error",
|
||||
}
|
||||
|
||||
invite := true
|
||||
updateUserAppMetadataTestCase5 := updateUserAppMetadataTest{
|
||||
name: "Update Pending Invite",
|
||||
expectedReqBody: fmt.Sprintf("{\"attributes\":{\"wt_account_id\":[\"%s\"],\"wt_pending_invite\":[\"true\"]}}", appMetadata.WTAccountID),
|
||||
appMetadata: AppMetadata{
|
||||
WTAccountID: "ok",
|
||||
WTPendingInvite: &invite,
|
||||
},
|
||||
statusCode: 204,
|
||||
helper: JsonParser{},
|
||||
managerCreds: &mockKeycloakCredentials{
|
||||
jwtToken: JWTToken{},
|
||||
},
|
||||
assertErrFunc: assert.NoError,
|
||||
assertErrFuncMessage: "shouldn't return error",
|
||||
}
|
||||
|
||||
for _, testCase := range []updateUserAppMetadataTest{updateUserAppMetadataTestCase1, updateUserAppMetadataTestCase2,
|
||||
updateUserAppMetadataTestCase3, updateUserAppMetadataTestCase4, updateUserAppMetadataTestCase5} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
reqClient := mockHTTPClient{
|
||||
resBody: testCase.inputReqBody,
|
||||
code: testCase.statusCode,
|
||||
}
|
||||
|
||||
manager := &KeycloakManager{
|
||||
httpClient: &reqClient,
|
||||
credentials: testCase.managerCreds,
|
||||
helper: testCase.helper,
|
||||
}
|
||||
|
||||
err := manager.UpdateUserAppMetadata("1", testCase.appMetadata)
|
||||
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
|
||||
|
||||
assert.Equal(t, testCase.expectedReqBody, reqClient.reqBody, "request body should match")
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -7,14 +7,19 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
TokenUserProperty = "user"
|
||||
AccountIDSuffix = "wt_account_id"
|
||||
DomainIDSuffix = "wt_account_domain"
|
||||
// TokenUserProperty key for the user property in the request context
|
||||
TokenUserProperty = "user"
|
||||
// AccountIDSuffix suffix for the account id claim
|
||||
AccountIDSuffix = "wt_account_id"
|
||||
// DomainIDSuffix suffix for the domain id claim
|
||||
DomainIDSuffix = "wt_account_domain"
|
||||
// DomainCategorySuffix suffix for the domain category claim
|
||||
DomainCategorySuffix = "wt_account_domain_category"
|
||||
UserIDClaim = "sub"
|
||||
// UserIDClaim claim for the user id
|
||||
UserIDClaim = "sub"
|
||||
)
|
||||
|
||||
// Extract function type
|
||||
// ExtractClaims Extract function type
|
||||
type ExtractClaims func(r *http.Request) AuthorizationClaims
|
||||
|
||||
// ClaimsExtractor struct that holds the extract function
|
||||
|
||||
@@ -26,7 +26,7 @@ func newTestRequestWithJWT(t *testing.T, claims AuthorizationClaims, audiance st
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps)
|
||||
r, err := http.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
require.NoError(t, err, "creating testing request failed")
|
||||
testRequest := r.WithContext(context.WithValue(r.Context(), TokenUserProperty, token)) //nolint
|
||||
testRequest := r.WithContext(context.WithValue(r.Context(), TokenUserProperty, token)) // nolint
|
||||
|
||||
return testRequest
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package middleware
|
||||
package jwtclaims
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
@@ -17,6 +17,32 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// Options is a struct for specifying configuration options for the middleware.
|
||||
type Options struct {
|
||||
// The function that will return the Key to validate the JWT.
|
||||
// It can be either a shared secret or a public key.
|
||||
// Default value: nil
|
||||
ValidationKeyGetter jwt.Keyfunc
|
||||
// The name of the property in the request where the user information
|
||||
// from the JWT will be stored.
|
||||
// Default value: "user"
|
||||
UserProperty string
|
||||
// The function that will be called when there's an error validating the token
|
||||
// Default value:
|
||||
CredentialsOptional bool
|
||||
// A function that extracts the token from the request
|
||||
// Default: FromAuthHeader (i.e., from Authorization header as bearer token)
|
||||
Debug bool
|
||||
// When set, all requests with the OPTIONS method will use authentication
|
||||
// Default: false
|
||||
EnableAuthOnOptions bool
|
||||
// When set, the middelware verifies that tokens are signed with the specific signing algorithm
|
||||
// If the signing method is not constant the ValidationKeyGetter callback can be used to implement additional checks
|
||||
// Important to avoid security issues described here: https://auth0.com/blog/critical-vulnerabilities-in-json-web-token-libraries/
|
||||
// Default: nil
|
||||
SigningMethod jwt.SigningMethod
|
||||
}
|
||||
|
||||
// Jwks is a collection of JSONWebKey obtained from Config.HttpServerConfig.AuthKeysLocation
|
||||
type Jwks struct {
|
||||
Keys []JSONWebKey `json:"keys"`
|
||||
@@ -32,17 +58,28 @@ type JSONWebKey struct {
|
||||
X5c []string `json:"x5c"`
|
||||
}
|
||||
|
||||
// NewJwtMiddleware creates new middleware to verify the JWT token sent via Authorization header
|
||||
func NewJwtMiddleware(issuer string, audience string, keysLocation string) (*JWTMiddleware, error) {
|
||||
// JWTValidator struct to handle token validation and parsing
|
||||
type JWTValidator struct {
|
||||
options Options
|
||||
}
|
||||
|
||||
// NewJWTValidator constructor
|
||||
func NewJWTValidator(issuer string, audienceList []string, keysLocation string) (*JWTValidator, error) {
|
||||
keys, err := getPemKeys(keysLocation)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return New(Options{
|
||||
options := Options{
|
||||
ValidationKeyGetter: func(token *jwt.Token) (interface{}, error) {
|
||||
// Verify 'aud' claim
|
||||
checkAud := token.Claims.(jwt.MapClaims).VerifyAudience(audience, false)
|
||||
var checkAud bool
|
||||
for _, audience := range audienceList {
|
||||
checkAud = token.Claims.(jwt.MapClaims).VerifyAudience(audience, false)
|
||||
if checkAud {
|
||||
break
|
||||
}
|
||||
}
|
||||
if !checkAud {
|
||||
return token, errors.New("invalid audience")
|
||||
}
|
||||
@@ -62,7 +99,59 @@ func NewJwtMiddleware(issuer string, audience string, keysLocation string) (*JWT
|
||||
},
|
||||
SigningMethod: jwt.SigningMethodRS256,
|
||||
EnableAuthOnOptions: false,
|
||||
}), nil
|
||||
}
|
||||
|
||||
if options.UserProperty == "" {
|
||||
options.UserProperty = "user"
|
||||
}
|
||||
|
||||
return &JWTValidator{
|
||||
options: options,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ValidateAndParse validates the token and returns the parsed token
|
||||
func (m *JWTValidator) ValidateAndParse(token string) (*jwt.Token, error) {
|
||||
// If the token is empty...
|
||||
if token == "" {
|
||||
// Check if it was required
|
||||
if m.options.CredentialsOptional {
|
||||
log.Debugf("no credentials found (CredentialsOptional=true)")
|
||||
// No error, just no token (and that is ok given that CredentialsOptional is true)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// If we get here, the required token is missing
|
||||
errorMsg := "required authorization token not found"
|
||||
log.Debugf(" Error: No credentials found (CredentialsOptional=false)")
|
||||
return nil, fmt.Errorf(errorMsg)
|
||||
}
|
||||
|
||||
// Now parse the token
|
||||
parsedToken, err := jwt.Parse(token, m.options.ValidationKeyGetter)
|
||||
|
||||
// Check if there was an error in parsing...
|
||||
if err != nil {
|
||||
log.Debugf("error parsing token: %v", err)
|
||||
return nil, fmt.Errorf("Error parsing token: %w", err)
|
||||
}
|
||||
|
||||
if m.options.SigningMethod != nil && m.options.SigningMethod.Alg() != parsedToken.Header["alg"] {
|
||||
errorMsg := fmt.Sprintf("Expected %s signing method but token specified %s",
|
||||
m.options.SigningMethod.Alg(),
|
||||
parsedToken.Header["alg"])
|
||||
log.Debugf("error validating token algorithm: %s", errorMsg)
|
||||
return nil, fmt.Errorf("error validating token algorithm: %s", errorMsg)
|
||||
}
|
||||
|
||||
// Check if the parsed token is valid...
|
||||
if !parsedToken.Valid {
|
||||
errorMsg := "token is invalid"
|
||||
log.Debugf(errorMsg)
|
||||
return nil, errors.New(errorMsg)
|
||||
}
|
||||
|
||||
return parsedToken, nil
|
||||
}
|
||||
|
||||
func getPemKeys(keysLocation string) (*Jwks, error) {
|
||||
@@ -12,20 +12,23 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
|
||||
server "github.com/netbirdio/netbird/management/server"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
|
||||
pb "github.com/golang/protobuf/proto" //nolint
|
||||
"github.com/netbirdio/netbird/encryption"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
mgmtProto "github.com/netbirdio/netbird/management/proto"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
"github.com/netbirdio/netbird/encryption"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
|
||||
mgmtProto "github.com/netbirdio/netbird/management/proto"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
const (
|
||||
|
||||
@@ -47,6 +47,8 @@ type MockAccountManager struct {
|
||||
DeletePolicyFunc func(accountID, policyID, userID string) error
|
||||
ListPoliciesFunc func(accountID, userID string) ([]*server.Policy, error)
|
||||
GetUsersFromAccountFunc func(accountID, userID string) ([]*server.UserInfo, error)
|
||||
GetAccountFromPATFunc func(pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error)
|
||||
MarkPATUsedFunc func(pat string) error
|
||||
UpdatePeerMetaFunc func(peerID string, meta server.PeerSystemMeta) error
|
||||
UpdatePeerSSHKeyFunc func(peerID string, sshKey string) error
|
||||
UpdatePeerFunc func(accountID, userID string, peer *server.Peer) (*server.Peer, error)
|
||||
@@ -59,6 +61,10 @@ type MockAccountManager struct {
|
||||
SaveSetupKeyFunc func(accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error)
|
||||
ListSetupKeysFunc func(accountID, userID string) ([]*server.SetupKey, error)
|
||||
SaveUserFunc func(accountID, userID string, user *server.User) (*server.UserInfo, error)
|
||||
CreatePATFunc func(accountID string, executingUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error)
|
||||
DeletePATFunc func(accountID string, executingUserID string, targetUserId string, tokenID string) error
|
||||
GetPATFunc func(accountID string, executingUserID string, targetUserId string, tokenID string) (*server.PersonalAccessToken, error)
|
||||
GetAllPATsFunc func(accountID string, executingUserID string, targetUserId string) ([]*server.PersonalAccessToken, error)
|
||||
GetNameServerGroupFunc func(accountID, nsGroupID string) (*nbdns.NameServerGroup, error)
|
||||
CreateNameServerGroupFunc func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string) (*nbdns.NameServerGroup, error)
|
||||
SaveNameServerGroupFunc func(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error
|
||||
@@ -175,6 +181,54 @@ func (am *MockAccountManager) GetPeerByIP(accountId string, peerIP string) (*ser
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetPeerByIP is not implemented")
|
||||
}
|
||||
|
||||
// GetAccountFromPAT mock implementation of GetAccountFromPAT from server.AccountManager interface
|
||||
func (am *MockAccountManager) GetAccountFromPAT(pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error) {
|
||||
if am.GetAccountFromPATFunc != nil {
|
||||
return am.GetAccountFromPATFunc(pat)
|
||||
}
|
||||
return nil, nil, nil, status.Errorf(codes.Unimplemented, "method GetAccountFromPAT is not implemented")
|
||||
}
|
||||
|
||||
// MarkPATUsed mock implementation of MarkPATUsed from server.AccountManager interface
|
||||
func (am *MockAccountManager) MarkPATUsed(pat string) error {
|
||||
if am.MarkPATUsedFunc != nil {
|
||||
return am.MarkPATUsedFunc(pat)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method MarkPATUsed is not implemented")
|
||||
}
|
||||
|
||||
// CreatePAT mock implementation of GetPAT from server.AccountManager interface
|
||||
func (am *MockAccountManager) CreatePAT(accountID string, executingUserID string, targetUserID string, name string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) {
|
||||
if am.CreatePATFunc != nil {
|
||||
return am.CreatePATFunc(accountID, executingUserID, targetUserID, name, expiresIn)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method CreatePAT is not implemented")
|
||||
}
|
||||
|
||||
// DeletePAT mock implementation of DeletePAT from server.AccountManager interface
|
||||
func (am *MockAccountManager) DeletePAT(accountID string, executingUserID string, targetUserID string, tokenID string) error {
|
||||
if am.DeletePATFunc != nil {
|
||||
return am.DeletePATFunc(accountID, executingUserID, targetUserID, tokenID)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method DeletePAT is not implemented")
|
||||
}
|
||||
|
||||
// GetPAT mock implementation of GetPAT from server.AccountManager interface
|
||||
func (am *MockAccountManager) GetPAT(accountID string, executingUserID string, targetUserID string, tokenID string) (*server.PersonalAccessToken, error) {
|
||||
if am.GetPATFunc != nil {
|
||||
return am.GetPATFunc(accountID, executingUserID, targetUserID, tokenID)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetPAT is not implemented")
|
||||
}
|
||||
|
||||
// GetAllPATs mock implementation of GetAllPATs from server.AccountManager interface
|
||||
func (am *MockAccountManager) GetAllPATs(accountID string, executingUserID string, targetUserID string) ([]*server.PersonalAccessToken, error) {
|
||||
if am.GetAllPATsFunc != nil {
|
||||
return am.GetAllPATsFunc(accountID, executingUserID, targetUserID)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetAllPATs is not implemented")
|
||||
}
|
||||
|
||||
// GetNetworkMap mock implementation of GetNetworkMap from server.AccountManager interface
|
||||
func (am *MockAccountManager) GetNetworkMap(peerKey string) (*server.NetworkMap, error) {
|
||||
if am.GetNetworkMapFunc != nil {
|
||||
|
||||
@@ -1,15 +1,17 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"github.com/c-robinson/iplib"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/rs/xid"
|
||||
"math/rand"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/c-robinson/iplib"
|
||||
"github.com/rs/xid"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -23,12 +25,11 @@ const (
|
||||
)
|
||||
|
||||
type NetworkMap struct {
|
||||
Peers []*Peer
|
||||
Network *Network
|
||||
Routes []*route.Route
|
||||
DNSConfig nbdns.Config
|
||||
OfflinePeers []*Peer
|
||||
FirewallRules []*FirewallRule
|
||||
Peers []*Peer
|
||||
Network *Network
|
||||
Routes []*route.Route
|
||||
DNSConfig nbdns.Config
|
||||
OfflinePeers []*Peer
|
||||
}
|
||||
|
||||
type Network struct {
|
||||
|
||||
@@ -6,9 +6,10 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/rs/xid"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/rs/xid"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
@@ -208,7 +209,7 @@ func (am *DefaultAccountManager) GetPeers(accountID, userID string) ([]*Peer, er
|
||||
// fetch all the peers that have access to the user's peers
|
||||
for _, peer := range peers {
|
||||
// TODO: use firewall rules
|
||||
aclPeers, _ := account.getPeersByPolicy(peer.ID)
|
||||
aclPeers := account.getPeersByACL(peer.ID)
|
||||
for _, p := range aclPeers {
|
||||
peersMap[p.ID] = p
|
||||
}
|
||||
@@ -245,7 +246,7 @@ func (am *DefaultAccountManager) MarkPeerConnected(peerPubKey string, connected
|
||||
|
||||
oldStatus := peer.Status.Copy()
|
||||
newStatus := oldStatus
|
||||
newStatus.LastSeen = time.Now()
|
||||
newStatus.LastSeen = time.Now().UTC()
|
||||
newStatus.Connected = connected
|
||||
// whenever peer got connected that means that it logged in successfully
|
||||
if newStatus.Connected {
|
||||
@@ -477,7 +478,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *Peer) (*
|
||||
}
|
||||
|
||||
opEvent := &activity.Event{
|
||||
Timestamp: time.Now(),
|
||||
Timestamp: time.Now().UTC(),
|
||||
AccountID: account.Id,
|
||||
}
|
||||
|
||||
@@ -524,11 +525,11 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *Peer) (*
|
||||
Name: peer.Meta.Hostname,
|
||||
DNSLabel: newLabel,
|
||||
UserID: userID,
|
||||
Status: &PeerStatus{Connected: false, LastSeen: time.Now()},
|
||||
Status: &PeerStatus{Connected: false, LastSeen: time.Now().UTC()},
|
||||
SSHEnabled: false,
|
||||
SSHKey: peer.SSHKey,
|
||||
LastLogin: time.Now(),
|
||||
LoginExpirationEnabled: true,
|
||||
LastLogin: time.Now().UTC(),
|
||||
LoginExpirationEnabled: addedByUser,
|
||||
}
|
||||
|
||||
// add peer to 'All' group
|
||||
@@ -575,7 +576,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *Peer) (*
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
networkMap := account.GetPeerNetworkMap(peer.ID, am.dnsDomain)
|
||||
networkMap := account.GetPeerNetworkMap(newPeer.ID, am.dnsDomain)
|
||||
return newPeer, networkMap, nil
|
||||
}
|
||||
|
||||
@@ -704,7 +705,7 @@ func updatePeerLastLogin(peer *Peer, account *Account) {
|
||||
|
||||
// UpdateLastLogin and set login expired false
|
||||
func (p *Peer) UpdateLastLogin() *Peer {
|
||||
p.LastLogin = time.Now()
|
||||
p.LastLogin = time.Now().UTC()
|
||||
newStatus := p.Status.Copy()
|
||||
newStatus.LoginExpired = false
|
||||
p.Status = newStatus
|
||||
@@ -815,7 +816,7 @@ func (am *DefaultAccountManager) GetPeer(accountID, peerID, userID string) (*Pee
|
||||
}
|
||||
|
||||
for _, p := range userPeers {
|
||||
aclPeers, _ := account.getPeersByPolicy(p.ID)
|
||||
aclPeers := account.getPeersByACL(p.ID)
|
||||
for _, aclPeer := range aclPeers {
|
||||
if aclPeer.ID == peerID {
|
||||
return peer, nil
|
||||
@@ -832,6 +833,98 @@ func updatePeerMeta(peer *Peer, meta PeerSystemMeta, account *Account) *Peer {
|
||||
return peer
|
||||
}
|
||||
|
||||
// GetPeerRules returns a list of source or destination rules of a given peer.
|
||||
func (a *Account) GetPeerRules(peerID string) (srcRules []*Rule, dstRules []*Rule) {
|
||||
// Rules are group based so there is no direct access to peers.
|
||||
// First, find all groups that the given peer belongs to
|
||||
peerGroups := make(map[string]struct{})
|
||||
|
||||
for s, group := range a.Groups {
|
||||
for _, peer := range group.Peers {
|
||||
if peerID == peer {
|
||||
peerGroups[s] = struct{}{}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Second, find all rules that have discovered source and destination groups
|
||||
srcRulesMap := make(map[string]*Rule)
|
||||
dstRulesMap := make(map[string]*Rule)
|
||||
for _, rule := range a.Rules {
|
||||
for _, g := range rule.Source {
|
||||
if _, ok := peerGroups[g]; ok && srcRulesMap[rule.ID] == nil {
|
||||
srcRules = append(srcRules, rule)
|
||||
srcRulesMap[rule.ID] = rule
|
||||
}
|
||||
}
|
||||
for _, g := range rule.Destination {
|
||||
if _, ok := peerGroups[g]; ok && dstRulesMap[rule.ID] == nil {
|
||||
dstRules = append(dstRules, rule)
|
||||
dstRulesMap[rule.ID] = rule
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return srcRules, dstRules
|
||||
}
|
||||
|
||||
// getPeersByACL returns all peers that given peer has access to.
|
||||
func (a *Account) getPeersByACL(peerID string) []*Peer {
|
||||
var peers []*Peer
|
||||
srcRules, dstRules := a.GetPeerRules(peerID)
|
||||
|
||||
groups := map[string]*Group{}
|
||||
for _, r := range srcRules {
|
||||
if r.Disabled {
|
||||
continue
|
||||
}
|
||||
if r.Flow == TrafficFlowBidirect {
|
||||
for _, gid := range r.Destination {
|
||||
if group, ok := a.Groups[gid]; ok {
|
||||
groups[gid] = group
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, r := range dstRules {
|
||||
if r.Disabled {
|
||||
continue
|
||||
}
|
||||
if r.Flow == TrafficFlowBidirect {
|
||||
for _, gid := range r.Source {
|
||||
if group, ok := a.Groups[gid]; ok {
|
||||
groups[gid] = group
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
peersSet := make(map[string]struct{})
|
||||
for _, g := range groups {
|
||||
for _, pid := range g.Peers {
|
||||
peer, ok := a.Peers[pid]
|
||||
if !ok {
|
||||
log.Warnf(
|
||||
"peer %s found in group %s but doesn't belong to account %s",
|
||||
pid,
|
||||
g.ID,
|
||||
a.Id,
|
||||
)
|
||||
continue
|
||||
}
|
||||
// exclude original peer
|
||||
if _, ok := peersSet[peer.ID]; peer.ID != peerID && !ok {
|
||||
peersSet[peer.ID] = struct{}{}
|
||||
peers = append(peers, peer.Copy())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return peers
|
||||
}
|
||||
|
||||
// updateAccountPeers updates all peers that belong to an account.
|
||||
// Should be called when changes have to be synced to peers.
|
||||
func (am *DefaultAccountManager) updateAccountPeers(account *Account) error {
|
||||
|
||||
@@ -21,7 +21,7 @@ func TestPeer_LoginExpired(t *testing.T) {
|
||||
{
|
||||
name: "Peer Login Expiration Disabled. Peer Login Should Not Expire",
|
||||
expirationEnabled: false,
|
||||
lastLogin: time.Now().Add(-25 * time.Hour),
|
||||
lastLogin: time.Now().UTC().Add(-25 * time.Hour),
|
||||
accountSettings: &Settings{
|
||||
PeerLoginExpirationEnabled: true,
|
||||
PeerLoginExpiration: time.Hour,
|
||||
@@ -31,7 +31,7 @@ func TestPeer_LoginExpired(t *testing.T) {
|
||||
{
|
||||
name: "Peer Login Should Expire",
|
||||
expirationEnabled: true,
|
||||
lastLogin: time.Now().Add(-25 * time.Hour),
|
||||
lastLogin: time.Now().UTC().Add(-25 * time.Hour),
|
||||
accountSettings: &Settings{
|
||||
PeerLoginExpirationEnabled: true,
|
||||
PeerLoginExpiration: time.Hour,
|
||||
@@ -41,7 +41,7 @@ func TestPeer_LoginExpired(t *testing.T) {
|
||||
{
|
||||
name: "Peer Login Should Not Expire",
|
||||
expirationEnabled: true,
|
||||
lastLogin: time.Now(),
|
||||
lastLogin: time.Now().UTC(),
|
||||
accountSettings: &Settings{
|
||||
PeerLoginExpirationEnabled: true,
|
||||
PeerLoginExpiration: time.Hour,
|
||||
@@ -136,6 +136,8 @@ func TestAccountManager_GetNetworkMap(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
|
||||
// TODO: disable until we start use policy again
|
||||
t.Skip()
|
||||
manager, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
||||
@@ -2,6 +2,7 @@ package server
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
b64 "encoding/base64"
|
||||
"fmt"
|
||||
"hash/crc32"
|
||||
"time"
|
||||
@@ -13,14 +14,19 @@ import (
|
||||
|
||||
const (
|
||||
// PATPrefix is the globally used, 4 char prefix for personal access tokens
|
||||
PATPrefix = "nbp_"
|
||||
secretLength = 30
|
||||
PATPrefix = "nbp_"
|
||||
// PATSecretLength number of characters used for the secret inside the token
|
||||
PATSecretLength = 30
|
||||
// PATChecksumLength number of characters used for the encoded checksum of the secret inside the token
|
||||
PATChecksumLength = 6
|
||||
// PATLength total number of characters used for the token
|
||||
PATLength = 40
|
||||
)
|
||||
|
||||
// PersonalAccessToken holds all information about a PAT including a hashed version of it for verification
|
||||
type PersonalAccessToken struct {
|
||||
ID string
|
||||
Description string
|
||||
Name string
|
||||
HashedToken string
|
||||
ExpirationDate time.Time
|
||||
// scope could be added in future
|
||||
@@ -29,27 +35,37 @@ type PersonalAccessToken struct {
|
||||
LastUsed time.Time
|
||||
}
|
||||
|
||||
// PersonalAccessTokenGenerated holds the new PersonalAccessToken and the plain text version of it
|
||||
type PersonalAccessTokenGenerated struct {
|
||||
PlainToken string
|
||||
PersonalAccessToken
|
||||
}
|
||||
|
||||
// CreateNewPAT will generate a new PersonalAccessToken that can be assigned to a User.
|
||||
// Additionally, it will return the token in plain text once, to give to the user and only save a hashed version
|
||||
func CreateNewPAT(description string, expirationInDays int, createdBy string) (*PersonalAccessToken, string, error) {
|
||||
func CreateNewPAT(name string, expirationInDays int, createdBy string) (*PersonalAccessTokenGenerated, error) {
|
||||
hashedToken, plainToken, err := generateNewToken()
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
return nil, err
|
||||
}
|
||||
currentTime := time.Now().UTC()
|
||||
return &PersonalAccessToken{
|
||||
ID: xid.New().String(),
|
||||
Description: description,
|
||||
HashedToken: hashedToken,
|
||||
ExpirationDate: currentTime.AddDate(0, 0, expirationInDays),
|
||||
CreatedBy: createdBy,
|
||||
CreatedAt: currentTime,
|
||||
LastUsed: currentTime,
|
||||
}, plainToken, nil
|
||||
currentTime := time.Now()
|
||||
return &PersonalAccessTokenGenerated{
|
||||
PersonalAccessToken: PersonalAccessToken{
|
||||
ID: xid.New().String(),
|
||||
Name: name,
|
||||
HashedToken: hashedToken,
|
||||
ExpirationDate: currentTime.AddDate(0, 0, expirationInDays),
|
||||
CreatedBy: createdBy,
|
||||
CreatedAt: currentTime,
|
||||
LastUsed: time.Time{},
|
||||
},
|
||||
PlainToken: plainToken,
|
||||
}, nil
|
||||
|
||||
}
|
||||
|
||||
func generateNewToken() (string, string, error) {
|
||||
secret, err := b.Random(secretLength)
|
||||
secret, err := b.Random(PATSecretLength)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
@@ -59,5 +75,6 @@ func generateNewToken() (string, string, error) {
|
||||
paddedChecksum := fmt.Sprintf("%06s", encodedChecksum)
|
||||
plainToken := PATPrefix + secret + paddedChecksum
|
||||
hashedToken := sha256.Sum256([]byte(plainToken))
|
||||
return string(hashedToken[:]), plainToken, nil
|
||||
encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:])
|
||||
return encodedHashedToken, plainToken, nil
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package server
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
b64 "encoding/base64"
|
||||
"hash/crc32"
|
||||
"strings"
|
||||
"testing"
|
||||
@@ -13,7 +14,8 @@ import (
|
||||
func TestPAT_GenerateToken_Hashing(t *testing.T) {
|
||||
hashedToken, plainToken, _ := generateNewToken()
|
||||
expectedToken := sha256.Sum256([]byte(plainToken))
|
||||
assert.Equal(t, hashedToken, string(expectedToken[:]))
|
||||
encodedExpectedToken := b64.StdEncoding.EncodeToString(expectedToken[:])
|
||||
assert.Equal(t, hashedToken, encodedExpectedToken)
|
||||
}
|
||||
|
||||
func TestPAT_GenerateToken_Prefix(t *testing.T) {
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"html/template"
|
||||
"strings"
|
||||
|
||||
"github.com/netbirdio/netbird/management/proto"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
|
||||
@@ -177,11 +176,11 @@ type FirewallRule struct {
|
||||
// Action of the traffic
|
||||
Action string
|
||||
|
||||
// Protocol of the traffic
|
||||
Protocol string
|
||||
|
||||
// Port of the traffic
|
||||
Port string
|
||||
|
||||
// id for internal purposes
|
||||
id string
|
||||
}
|
||||
|
||||
// parseFromRegoResult parses the Rego result to a FirewallRule.
|
||||
@@ -211,56 +210,46 @@ func (f *FirewallRule) parseFromRegoResult(value interface{}) error {
|
||||
return fmt.Errorf("invalid Rego query eval result peer action type")
|
||||
}
|
||||
|
||||
if v, ok := object["Protocol"]; ok {
|
||||
if protocol, ok := v.(string); ok {
|
||||
f.Protocol = protocol
|
||||
}
|
||||
}
|
||||
|
||||
if v, ok := object["Port"]; ok {
|
||||
if port, ok := v.(string); ok {
|
||||
f.Port = port
|
||||
}
|
||||
port, ok := object["Port"].(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid Rego query eval result peer port type")
|
||||
}
|
||||
|
||||
f.PeerID = peerID
|
||||
f.PeerIP = peerIP
|
||||
f.Direction = direction
|
||||
f.Action = action
|
||||
f.Port = port
|
||||
|
||||
// NOTE: update this id each time when new field added
|
||||
f.id = peerID + peerIP + direction + action + port
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getRegoQuery returns a initialized Rego object with default rule.
|
||||
func (a *Account) getRegoQuery() (rego.PreparedEvalQuery, error) {
|
||||
queries := []func(*rego.Rego){
|
||||
rego.Query("data.netbird.all"),
|
||||
rego.Module("netbird", defaultPolicyModule),
|
||||
}
|
||||
for i, p := range a.Policies {
|
||||
if !p.Enabled {
|
||||
continue
|
||||
}
|
||||
queries = append(queries, rego.Module(fmt.Sprintf("netbird-%d", i), p.Query))
|
||||
}
|
||||
return rego.New(queries...).PrepareForEval(context.TODO())
|
||||
}
|
||||
|
||||
// getPeersByPolicy returns all peers that given peer has access to.
|
||||
func (a *Account) getPeersByPolicy(peerID string) ([]*Peer, []*FirewallRule) {
|
||||
// queryPeersAndFwRulesByRego returns a list associated Peers and firewall rules list for this peer.
|
||||
func (a *Account) queryPeersAndFwRulesByRego(
|
||||
peerID string,
|
||||
queryNumber int,
|
||||
query string,
|
||||
) ([]*Peer, []*FirewallRule) {
|
||||
input := map[string]interface{}{
|
||||
"peer_id": peerID,
|
||||
"peers": a.Peers,
|
||||
"groups": a.Groups,
|
||||
}
|
||||
|
||||
query, err := a.getRegoQuery()
|
||||
stmt, err := rego.New(
|
||||
rego.Query("data.netbird.all"),
|
||||
rego.Module("netbird", defaultPolicyModule),
|
||||
rego.Module(fmt.Sprintf("netbird-%d", queryNumber), query),
|
||||
).PrepareForEval(context.TODO())
|
||||
if err != nil {
|
||||
log.WithError(err).Error("get Rego query")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
evalResult, err := query.Eval(
|
||||
evalResult, err := stmt.Eval(
|
||||
context.TODO(),
|
||||
rego.EvalInput(input),
|
||||
)
|
||||
@@ -328,6 +317,33 @@ func (a *Account) getPeersByPolicy(peerID string) ([]*Peer, []*FirewallRule) {
|
||||
return peers, rules
|
||||
}
|
||||
|
||||
// getPeersByPolicy returns all peers that given peer has access to.
|
||||
func (a *Account) getPeersByPolicy(peerID string) (peers []*Peer, rules []*FirewallRule) {
|
||||
peersSeen := make(map[string]struct{})
|
||||
ruleSeen := make(map[string]struct{})
|
||||
for i, policy := range a.Policies {
|
||||
if !policy.Enabled {
|
||||
continue
|
||||
}
|
||||
p, r := a.queryPeersAndFwRulesByRego(peerID, i, policy.Query)
|
||||
for _, peer := range p {
|
||||
if _, ok := peersSeen[peer.ID]; ok {
|
||||
continue
|
||||
}
|
||||
peers = append(peers, peer)
|
||||
peersSeen[peer.ID] = struct{}{}
|
||||
}
|
||||
for _, rule := range r {
|
||||
if _, ok := ruleSeen[rule.id]; ok {
|
||||
continue
|
||||
}
|
||||
rules = append(rules, rule)
|
||||
ruleSeen[rule.id] = struct{}{}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// GetPolicy from the store
|
||||
func (am *DefaultAccountManager) GetPolicy(accountID, policyID, userID string) (*Policy, error) {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
@@ -459,18 +475,3 @@ func (am *DefaultAccountManager) savePolicy(account *Account, policy *Policy) (e
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func toProtocolFirewallRules(update []*FirewallRule) []*proto.FirewallRule {
|
||||
result := make([]*proto.FirewallRule, len(update))
|
||||
for i := range update {
|
||||
result[i] = &proto.FirewallRule{
|
||||
PeerID: update[i].PeerID,
|
||||
PeerIP: update[i].PeerIP,
|
||||
Direction: update[i].Direction,
|
||||
Action: update[i].Action,
|
||||
Protocol: update[i].Protocol,
|
||||
Port: update[i].Port,
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -5,63 +5,268 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
func TestAccount_getPeersByPolicy(t *testing.T) {
|
||||
account := &Account{
|
||||
Peers: map[string]*Peer{
|
||||
"peer1": {
|
||||
ID: "peer1",
|
||||
IP: net.IPv4(10, 20, 0, 1),
|
||||
"cfif97at2r9s73au3q00": {
|
||||
ID: "cfif97at2r9s73au3q00",
|
||||
IP: net.ParseIP("100.65.14.88"),
|
||||
},
|
||||
"peer2": {
|
||||
ID: "peer2",
|
||||
IP: net.IPv4(10, 20, 0, 2),
|
||||
"cfif97at2r9s73au3q0g": {
|
||||
ID: "cfif97at2r9s73au3q0g",
|
||||
IP: net.ParseIP("100.65.80.39"),
|
||||
},
|
||||
"peer3": {
|
||||
ID: "peer3",
|
||||
IP: net.IPv4(10, 20, 0, 3),
|
||||
"cfif97at2r9s73au3q10": {
|
||||
ID: "cfif97at2r9s73au3q10",
|
||||
IP: net.ParseIP("100.65.254.139"),
|
||||
},
|
||||
"cfif97at2r9s73au3q20": {
|
||||
ID: "cfif97at2r9s73au3q20",
|
||||
IP: net.ParseIP("100.65.62.5"),
|
||||
},
|
||||
"cfj4tiqt2r9s73dmeun0": {
|
||||
ID: "cfj4tiqt2r9s73dmeun0",
|
||||
IP: net.ParseIP("100.65.32.206"),
|
||||
},
|
||||
"cg7h032t2r9s73cg5fk0": {
|
||||
ID: "cg7h032t2r9s73cg5fk0",
|
||||
IP: net.ParseIP("100.65.250.202"),
|
||||
},
|
||||
"cgcnkj2t2r9s73cg5vv0": {
|
||||
ID: "cgcnkj2t2r9s73cg5vv0",
|
||||
IP: net.ParseIP("100.65.13.186"),
|
||||
},
|
||||
"cgcol4qt2r9s73cg601g": {
|
||||
ID: "cgcol4qt2r9s73cg601g",
|
||||
IP: net.ParseIP("100.65.29.55"),
|
||||
},
|
||||
},
|
||||
Groups: map[string]*Group{
|
||||
"gid1": {
|
||||
ID: "gid1",
|
||||
Name: "all",
|
||||
Peers: []string{"peer1", "peer2", "peer3"},
|
||||
"cet9e92t2r9s7383ns20": {
|
||||
ID: "cet9e92t2r9s7383ns20",
|
||||
Name: "All",
|
||||
Peers: []string{
|
||||
"cfif97at2r9s73au3q0g",
|
||||
"cfif97at2r9s73au3q00",
|
||||
"cfif97at2r9s73au3q20",
|
||||
"cfif97at2r9s73au3q10",
|
||||
"cfj4tiqt2r9s73dmeun0",
|
||||
"cg7h032t2r9s73cg5fk0",
|
||||
"cgcnkj2t2r9s73cg5vv0",
|
||||
"cgcol4qt2r9s73cg601g",
|
||||
},
|
||||
},
|
||||
"cev90bat2r9s7383o150": {
|
||||
ID: "cev90bat2r9s7383o150",
|
||||
Name: "swarm",
|
||||
Peers: []string{
|
||||
"cfif97at2r9s73au3q0g",
|
||||
"cfif97at2r9s73au3q00",
|
||||
"cfif97at2r9s73au3q20",
|
||||
"cfj4tiqt2r9s73dmeun0",
|
||||
"cgcnkj2t2r9s73cg5vv0",
|
||||
"cgcol4qt2r9s73cg601g",
|
||||
},
|
||||
},
|
||||
},
|
||||
Rules: map[string]*Rule{
|
||||
"default": {
|
||||
ID: "default",
|
||||
Name: "default",
|
||||
Description: "default",
|
||||
Disabled: false,
|
||||
Source: []string{"gid1"},
|
||||
Destination: []string{"gid1"},
|
||||
"cet9e92t2r9s7383ns2g": {
|
||||
ID: "cet9e92t2r9s7383ns2g",
|
||||
Name: "Default",
|
||||
Description: "This is a default rule that allows connections between all the resources",
|
||||
Source: []string{
|
||||
"cet9e92t2r9s7383ns20",
|
||||
},
|
||||
Destination: []string{
|
||||
"cet9e92t2r9s7383ns20",
|
||||
},
|
||||
},
|
||||
"cev90bat2r9s7383o15g": {
|
||||
ID: "cev90bat2r9s7383o15g",
|
||||
Name: "Swarm",
|
||||
Description: "",
|
||||
Source: []string{
|
||||
"cev90bat2r9s7383o150",
|
||||
"cet9e92t2r9s7383ns20",
|
||||
},
|
||||
Destination: []string{
|
||||
"cev90bat2r9s7383o150",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
rule, err := RuleToPolicy(account.Rules["default"])
|
||||
rule1, err := RuleToPolicy(account.Rules["cet9e92t2r9s7383ns2g"])
|
||||
assert.NoError(t, err)
|
||||
|
||||
account.Policies = append(account.Policies, rule)
|
||||
rule2, err := RuleToPolicy(account.Rules["cev90bat2r9s7383o15g"])
|
||||
assert.NoError(t, err)
|
||||
|
||||
peers, firewallRules := account.getPeersByPolicy("peer1")
|
||||
assert.Len(t, peers, 2)
|
||||
assert.Contains(t, peers, account.Peers["peer2"])
|
||||
assert.Contains(t, peers, account.Peers["peer3"])
|
||||
account.Policies = append(account.Policies, rule1, rule2)
|
||||
|
||||
epectedFirewallRules := []*FirewallRule{
|
||||
{PeerID: "peer1", PeerIP: "10.20.0.1", Direction: "dst", Action: "accept", Port: ""},
|
||||
{PeerID: "peer2", PeerIP: "10.20.0.2", Direction: "dst", Action: "accept", Port: ""},
|
||||
{PeerID: "peer3", PeerIP: "10.20.0.3", Direction: "dst", Action: "accept", Port: ""},
|
||||
{PeerID: "peer1", PeerIP: "10.20.0.1", Direction: "src", Action: "accept", Port: ""},
|
||||
{PeerID: "peer2", PeerIP: "10.20.0.2", Direction: "src", Action: "accept", Port: ""},
|
||||
{PeerID: "peer3", PeerIP: "10.20.0.3", Direction: "src", Action: "accept", Port: ""},
|
||||
}
|
||||
assert.Len(t, firewallRules, len(epectedFirewallRules))
|
||||
for i := range firewallRules {
|
||||
assert.Equal(t, firewallRules[i], epectedFirewallRules[i])
|
||||
}
|
||||
t.Run("check that all peers get map", func(t *testing.T) {
|
||||
for _, p := range account.Peers {
|
||||
peers, firewallRules := account.getPeersByPolicy(p.ID)
|
||||
assert.GreaterOrEqual(t, len(peers), 2, "mininum number peers should present")
|
||||
assert.GreaterOrEqual(t, len(firewallRules), 2, "mininum number of firewall rules should present")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("check first peer map details", func(t *testing.T) {
|
||||
peers, firewallRules := account.getPeersByPolicy("cfif97at2r9s73au3q0g")
|
||||
assert.Len(t, peers, 7)
|
||||
assert.Contains(t, peers, account.Peers["cfif97at2r9s73au3q00"])
|
||||
assert.Contains(t, peers, account.Peers["cfif97at2r9s73au3q10"])
|
||||
assert.Contains(t, peers, account.Peers["cfif97at2r9s73au3q20"])
|
||||
assert.Contains(t, peers, account.Peers["cfj4tiqt2r9s73dmeun0"])
|
||||
assert.Contains(t, peers, account.Peers["cg7h032t2r9s73cg5fk0"])
|
||||
|
||||
epectedFirewallRules := []*FirewallRule{
|
||||
{
|
||||
PeerID: "cfif97at2r9s73au3q00",
|
||||
PeerIP: "100.65.14.88",
|
||||
Direction: "src",
|
||||
Action: "accept",
|
||||
Port: "",
|
||||
id: "cfif97at2r9s73au3q00100.65.14.88srcaccept",
|
||||
},
|
||||
{
|
||||
PeerID: "cfif97at2r9s73au3q00",
|
||||
PeerIP: "100.65.14.88",
|
||||
Direction: "dst",
|
||||
Action: "accept",
|
||||
Port: "",
|
||||
id: "cfif97at2r9s73au3q00100.65.14.88dstaccept",
|
||||
},
|
||||
|
||||
{
|
||||
PeerID: "cfif97at2r9s73au3q0g",
|
||||
PeerIP: "100.65.80.39",
|
||||
Direction: "dst",
|
||||
Action: "accept",
|
||||
Port: "",
|
||||
id: "cfif97at2r9s73au3q0g100.65.80.39dstaccept",
|
||||
},
|
||||
{
|
||||
PeerID: "cfif97at2r9s73au3q0g",
|
||||
PeerIP: "100.65.80.39",
|
||||
Direction: "src",
|
||||
Action: "accept",
|
||||
Port: "",
|
||||
id: "cfif97at2r9s73au3q0g100.65.80.39srcaccept",
|
||||
},
|
||||
|
||||
{
|
||||
PeerID: "cfif97at2r9s73au3q10",
|
||||
PeerIP: "100.65.254.139",
|
||||
Direction: "dst",
|
||||
Action: "accept",
|
||||
Port: "",
|
||||
id: "cfif97at2r9s73au3q10100.65.254.139dstaccept",
|
||||
},
|
||||
{
|
||||
PeerID: "cfif97at2r9s73au3q10",
|
||||
PeerIP: "100.65.254.139",
|
||||
Direction: "src",
|
||||
Action: "accept",
|
||||
Port: "",
|
||||
id: "cfif97at2r9s73au3q10100.65.254.139srcaccept",
|
||||
},
|
||||
|
||||
{
|
||||
PeerID: "cfif97at2r9s73au3q20",
|
||||
PeerIP: "100.65.62.5",
|
||||
Direction: "dst",
|
||||
Action: "accept",
|
||||
Port: "",
|
||||
id: "cfif97at2r9s73au3q20100.65.62.5dstaccept",
|
||||
},
|
||||
{
|
||||
PeerID: "cfif97at2r9s73au3q20",
|
||||
PeerIP: "100.65.62.5",
|
||||
Direction: "src",
|
||||
Action: "accept",
|
||||
Port: "",
|
||||
id: "cfif97at2r9s73au3q20100.65.62.5srcaccept",
|
||||
},
|
||||
|
||||
{
|
||||
PeerID: "cfj4tiqt2r9s73dmeun0",
|
||||
PeerIP: "100.65.32.206",
|
||||
Direction: "dst",
|
||||
Action: "accept",
|
||||
Port: "",
|
||||
id: "cfj4tiqt2r9s73dmeun0100.65.32.206dstaccept",
|
||||
},
|
||||
{
|
||||
PeerID: "cfj4tiqt2r9s73dmeun0",
|
||||
PeerIP: "100.65.32.206",
|
||||
Direction: "src",
|
||||
Action: "accept",
|
||||
Port: "",
|
||||
id: "cfj4tiqt2r9s73dmeun0100.65.32.206srcaccept",
|
||||
},
|
||||
|
||||
{
|
||||
PeerID: "cg7h032t2r9s73cg5fk0",
|
||||
PeerIP: "100.65.250.202",
|
||||
Direction: "dst",
|
||||
Action: "accept",
|
||||
Port: "",
|
||||
id: "cg7h032t2r9s73cg5fk0100.65.250.202dstaccept",
|
||||
},
|
||||
{
|
||||
PeerID: "cg7h032t2r9s73cg5fk0",
|
||||
PeerIP: "100.65.250.202",
|
||||
Direction: "src",
|
||||
Action: "accept",
|
||||
Port: "",
|
||||
id: "cg7h032t2r9s73cg5fk0100.65.250.202srcaccept",
|
||||
},
|
||||
|
||||
{
|
||||
PeerID: "cgcnkj2t2r9s73cg5vv0",
|
||||
PeerIP: "100.65.13.186",
|
||||
Direction: "dst",
|
||||
Action: "accept",
|
||||
Port: "",
|
||||
id: "cgcnkj2t2r9s73cg5vv0100.65.13.186dstaccept",
|
||||
},
|
||||
{
|
||||
PeerID: "cgcnkj2t2r9s73cg5vv0",
|
||||
PeerIP: "100.65.13.186",
|
||||
Direction: "src",
|
||||
Action: "accept",
|
||||
Port: "",
|
||||
id: "cgcnkj2t2r9s73cg5vv0100.65.13.186srcaccept",
|
||||
},
|
||||
|
||||
{
|
||||
PeerID: "cgcol4qt2r9s73cg601g",
|
||||
PeerIP: "100.65.29.55",
|
||||
Direction: "dst",
|
||||
Action: "accept",
|
||||
Port: "",
|
||||
id: "cgcol4qt2r9s73cg601g100.65.29.55dstaccept",
|
||||
},
|
||||
{
|
||||
PeerID: "cgcol4qt2r9s73cg601g",
|
||||
PeerIP: "100.65.29.55",
|
||||
Direction: "src",
|
||||
Action: "accept",
|
||||
Port: "",
|
||||
id: "cgcol4qt2r9s73cg601g100.65.29.55srcaccept",
|
||||
},
|
||||
}
|
||||
assert.Len(t, firewallRules, len(epectedFirewallRules))
|
||||
slices.SortFunc(firewallRules, func(a, b *FirewallRule) bool {
|
||||
return a.PeerID < b.PeerID
|
||||
})
|
||||
for i := range firewallRules {
|
||||
assert.Equal(t, epectedFirewallRules[i], firewallRules[i])
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
package netbird
|
||||
|
||||
all[rule] {
|
||||
is_peer_in_any_group([{{range $i, $e := .All}}{{if $i}},{{end}}"{{$e}}"{{end}}])
|
||||
rule := array.concat(
|
||||
rules_from_groups([{{range $i, $e := .Destination}}{{if $i}},{{end}}"{{$e}}"{{end}}], "dst", "accept", ""),
|
||||
rules_from_groups([{{range $i, $e := .Source}}{{if $i}},{{end}}"{{$e}}"{{end}}], "src", "accept", ""),
|
||||
)[_]
|
||||
is_peer_in_any_group([{{range $i, $e := .All}}{{if $i}},{{end}}"{{$e}}"{{end}}])
|
||||
rule := {
|
||||
{{range $i, $e := .Destination}}rules_from_group("{{$e}}", "dst", "accept", ""),{{end}}
|
||||
{{range $i, $e := .Source}}rules_from_group("{{$e}}", "src", "accept", ""),{{end}}
|
||||
}[_][_]
|
||||
}
|
||||
|
||||
@@ -17,17 +17,11 @@ get_rule(peer_id, direction, action, port) := rule if {
|
||||
}
|
||||
}
|
||||
|
||||
# peers_from_group returns a list of peer ids for a given group id
|
||||
peers_from_group(group_id) := peers if {
|
||||
# netbird_rules_from_group returns a list of netbird rules for a given group_id
|
||||
rules_from_group(group_id, direction, action, port) := rules if {
|
||||
group := input.groups[_]
|
||||
group.ID == group_id
|
||||
peers := [peer | peer := group.Peers[_]]
|
||||
}
|
||||
|
||||
# netbird_rules_from_groups returns a list of netbird rules for a given list of group names
|
||||
rules_from_groups(groups, direction, action, port) := rules if {
|
||||
group_id := groups[_]
|
||||
rules := [get_rule(peer, direction, action, port) | peer := peers_from_group(group_id)[_]]
|
||||
rules := [get_rule(peer, direction, action, port) | peer := group.Peers[_]]
|
||||
}
|
||||
|
||||
# is_peer_in_any_group checks that input peer present at least in one group
|
||||
|
||||
@@ -1,15 +1,17 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"github.com/google/uuid"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"hash/fnv"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/google/uuid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -130,7 +132,7 @@ func (key *SetupKey) HiddenCopy(length int) *SetupKey {
|
||||
func (key *SetupKey) IncrementUsage() *SetupKey {
|
||||
c := key.Copy()
|
||||
c.UsedTimes = c.UsedTimes + 1
|
||||
c.LastUsed = time.Now()
|
||||
c.LastUsed = time.Now().UTC()
|
||||
return c
|
||||
}
|
||||
|
||||
@@ -171,9 +173,9 @@ func GenerateSetupKey(name string, t SetupKeyType, validFor time.Duration, autoG
|
||||
Key: key,
|
||||
Name: name,
|
||||
Type: t,
|
||||
CreatedAt: time.Now(),
|
||||
ExpiresAt: time.Now().Add(validFor),
|
||||
UpdatedAt: time.Now(),
|
||||
CreatedAt: time.Now().UTC(),
|
||||
ExpiresAt: time.Now().UTC().Add(validFor),
|
||||
UpdatedAt: time.Now().UTC(),
|
||||
Revoked: false,
|
||||
UsedTimes: 0,
|
||||
AutoGroups: autoGroups,
|
||||
@@ -274,7 +276,7 @@ func (am *DefaultAccountManager) SaveSetupKey(accountID string, keyToSave *Setup
|
||||
newKey.Name = keyToSave.Name
|
||||
newKey.AutoGroups = keyToSave.AutoGroups
|
||||
newKey.Revoked = keyToSave.Revoked
|
||||
newKey.UpdatedAt = time.Now()
|
||||
newKey.UpdatedAt = time.Now().UTC()
|
||||
|
||||
account.SetupKeys[newKey.Key] = newKey
|
||||
|
||||
|
||||
@@ -2,12 +2,14 @@ package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/google/uuid"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
)
|
||||
|
||||
func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
|
||||
@@ -54,7 +56,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
|
||||
}
|
||||
|
||||
assertKey(t, newKey, newKeyName, revoked, "reusable", 0, key.CreatedAt, key.ExpiresAt,
|
||||
key.Id, time.Now(), autoGroups)
|
||||
key.Id, time.Now().UTC(), autoGroups)
|
||||
|
||||
// check the corresponding events that should have been generated
|
||||
ev := getEvent(t, account.Id, manager, activity.SetupKeyRevoked)
|
||||
@@ -108,10 +110,10 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) {
|
||||
expectedCreatedAt time.Time
|
||||
expectedUpdatedAt time.Time
|
||||
expectedExpiresAt time.Time
|
||||
expectedFailure bool //indicates whether key creation should fail
|
||||
expectedFailure bool // indicates whether key creation should fail
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
now := time.Now().UTC()
|
||||
expiresIn := time.Hour
|
||||
testCase1 := testCase{
|
||||
name: "Should Create Setup Key successfully",
|
||||
@@ -169,9 +171,9 @@ func TestGenerateDefaultSetupKey(t *testing.T) {
|
||||
expectedRevoke := false
|
||||
expectedType := "reusable"
|
||||
expectedUsedTimes := 0
|
||||
expectedCreatedAt := time.Now()
|
||||
expectedUpdatedAt := time.Now()
|
||||
expectedExpiresAt := time.Now().Add(24 * 30 * time.Hour)
|
||||
expectedCreatedAt := time.Now().UTC()
|
||||
expectedUpdatedAt := time.Now().UTC()
|
||||
expectedExpiresAt := time.Now().UTC().Add(24 * 30 * time.Hour)
|
||||
var expectedAutoGroups []string
|
||||
|
||||
key := GenerateDefaultSetupKey()
|
||||
@@ -186,9 +188,9 @@ func TestGenerateSetupKey(t *testing.T) {
|
||||
expectedRevoke := false
|
||||
expectedType := "one-off"
|
||||
expectedUsedTimes := 0
|
||||
expectedCreatedAt := time.Now()
|
||||
expectedExpiresAt := time.Now().Add(time.Hour)
|
||||
expectedUpdatedAt := time.Now()
|
||||
expectedCreatedAt := time.Now().UTC()
|
||||
expectedExpiresAt := time.Now().UTC().Add(time.Hour)
|
||||
expectedUpdatedAt := time.Now().UTC()
|
||||
var expectedAutoGroups []string
|
||||
|
||||
key := GenerateSetupKey(expectedName, SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage)
|
||||
|
||||
@@ -6,9 +6,13 @@ type Store interface {
|
||||
GetAccountByUser(userID string) (*Account, error)
|
||||
GetAccountByPeerPubKey(peerKey string) (*Account, error)
|
||||
GetAccountByPeerID(peerID string) (*Account, error)
|
||||
GetAccountBySetupKey(setupKey string) (*Account, error) //todo use key hash later
|
||||
GetAccountBySetupKey(setupKey string) (*Account, error) // todo use key hash later
|
||||
GetAccountByPrivateDomain(domain string) (*Account, error)
|
||||
GetTokenIDByHashedToken(secret string) (string, error)
|
||||
GetUserByTokenID(tokenID string) (*User, error)
|
||||
SaveAccount(account *Account) error
|
||||
DeleteHashedPAT2TokenIDIndex(hashedToken string) error
|
||||
DeleteTokenID2UserIDIndex(tokenID string) error
|
||||
GetInstallationID() string
|
||||
SaveInstallationID(ID string) error
|
||||
// AcquireAccountLock should attempt to acquire account lock and return a function that releases the lock
|
||||
|
||||
16
management/server/testdata/store.json
vendored
16
management/server/testdata/store.json
vendored
@@ -29,11 +29,23 @@
|
||||
"Users": {
|
||||
"edafee4e-63fb-11ec-90d6-0242ac120003": {
|
||||
"Id": "edafee4e-63fb-11ec-90d6-0242ac120003",
|
||||
"Role": "admin"
|
||||
"Role": "admin",
|
||||
"PATs": {}
|
||||
},
|
||||
"f4f6d672-63fb-11ec-90d6-0242ac120003": {
|
||||
"Id": "f4f6d672-63fb-11ec-90d6-0242ac120003",
|
||||
"Role": "user"
|
||||
"Role": "user",
|
||||
"PATs": {
|
||||
"9dj38s35-63fb-11ec-90d6-0242ac120003": {
|
||||
"ID":"9dj38s35-63fb-11ec-90d6-0242ac120003",
|
||||
"Description":"some Description",
|
||||
"HashedToken":"SoMeHaShEdToKeN",
|
||||
"ExpirationDate":"2023-02-27T00:00:00Z",
|
||||
"CreatedBy":"user",
|
||||
"CreatedAt":"2023-01-01T00:00:00Z",
|
||||
"LastUsed":"2023-02-01T00:00:00Z"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,10 +5,12 @@ import (
|
||||
"crypto/sha1"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"github.com/netbirdio/netbird/management/proto"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/proto"
|
||||
)
|
||||
|
||||
// TURNCredentialsManager used to manage TURN credentials
|
||||
@@ -88,7 +90,7 @@ func (m *TimeBasedAuthSecretsManager) SetupRefresh(peerID string) {
|
||||
log.Debugf("starting turn refresh for %s", peerID)
|
||||
|
||||
go func() {
|
||||
//we don't want to regenerate credentials right on expiration, so we do it slightly before (at 3/4 of TTL)
|
||||
// we don't want to regenerate credentials right on expiration, so we do it slightly before (at 3/4 of TTL)
|
||||
ticker := time.NewTicker(m.config.CredentialsTTL.Duration / 4 * 3)
|
||||
|
||||
for {
|
||||
|
||||
@@ -46,7 +46,7 @@ type User struct {
|
||||
Role UserRole
|
||||
// AutoGroups is a list of Group IDs to auto-assign to peers registered by this user
|
||||
AutoGroups []string
|
||||
PATs []PersonalAccessToken
|
||||
PATs map[string]*PersonalAccessToken
|
||||
}
|
||||
|
||||
// IsAdmin returns true if user is an admin, false otherwise
|
||||
@@ -94,8 +94,12 @@ func (u *User) toUserInfo(userData *idp.UserData) (*UserInfo, error) {
|
||||
func (u *User) Copy() *User {
|
||||
autoGroups := make([]string, len(u.AutoGroups))
|
||||
copy(autoGroups, u.AutoGroups)
|
||||
pats := make([]PersonalAccessToken, len(u.PATs))
|
||||
copy(pats, u.PATs)
|
||||
pats := make(map[string]*PersonalAccessToken, len(u.PATs))
|
||||
for k, v := range u.PATs {
|
||||
patCopy := new(PersonalAccessToken)
|
||||
*patCopy = *v
|
||||
pats[k] = patCopy
|
||||
}
|
||||
return &User{
|
||||
Id: u.Id,
|
||||
Role: u.Role,
|
||||
@@ -189,6 +193,150 @@ func (am *DefaultAccountManager) CreateUser(accountID, userID string, invite *Us
|
||||
|
||||
}
|
||||
|
||||
// CreatePAT creates a new PAT for the given user
|
||||
func (am *DefaultAccountManager) CreatePAT(accountID string, executingUserID string, targetUserId string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
if tokenName == "" {
|
||||
return nil, status.Errorf(status.InvalidArgument, "token name can't be empty")
|
||||
}
|
||||
|
||||
if expiresIn < 1 || expiresIn > 365 {
|
||||
return nil, status.Errorf(status.InvalidArgument, "expiration has to be between 1 and 365")
|
||||
}
|
||||
|
||||
if executingUserID != targetUserId {
|
||||
return nil, status.Errorf(status.PermissionDenied, "no permission to create PAT for this user")
|
||||
}
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
targetUser := account.Users[targetUserId]
|
||||
if targetUser == nil {
|
||||
return nil, status.Errorf(status.NotFound, "targetUser not found")
|
||||
}
|
||||
|
||||
pat, err := CreateNewPAT(tokenName, expiresIn, targetUser.Id)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(status.Internal, "failed to create PAT: %v", err)
|
||||
}
|
||||
|
||||
targetUser.PATs[pat.ID] = &pat.PersonalAccessToken
|
||||
|
||||
err = am.Store.SaveAccount(account)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(status.Internal, "failed to save account: %v", err)
|
||||
}
|
||||
|
||||
meta := map[string]any{"name": pat.Name}
|
||||
am.storeEvent(executingUserID, targetUserId, accountID, activity.PersonalAccessTokenCreated, meta)
|
||||
|
||||
return pat, nil
|
||||
}
|
||||
|
||||
// DeletePAT deletes a specific PAT from a user
|
||||
func (am *DefaultAccountManager) DeletePAT(accountID string, executingUserID string, targetUserID string, tokenID string) error {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
if executingUserID != targetUserID {
|
||||
return status.Errorf(status.PermissionDenied, "no permission to delete PAT for this user")
|
||||
}
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
return status.Errorf(status.NotFound, "account not found: %s", err)
|
||||
}
|
||||
|
||||
user := account.Users[targetUserID]
|
||||
if user == nil {
|
||||
return status.Errorf(status.NotFound, "user not found")
|
||||
}
|
||||
|
||||
pat := user.PATs[tokenID]
|
||||
if pat == nil {
|
||||
return status.Errorf(status.NotFound, "PAT not found")
|
||||
}
|
||||
|
||||
err = am.Store.DeleteTokenID2UserIDIndex(pat.ID)
|
||||
if err != nil {
|
||||
return status.Errorf(status.Internal, "Failed to delete token id index: %s", err)
|
||||
}
|
||||
err = am.Store.DeleteHashedPAT2TokenIDIndex(pat.HashedToken)
|
||||
if err != nil {
|
||||
return status.Errorf(status.Internal, "Failed to delete hashed token index: %s", err)
|
||||
}
|
||||
|
||||
meta := map[string]any{"name": pat.Name}
|
||||
am.storeEvent(executingUserID, targetUserID, accountID, activity.PersonalAccessTokenDeleted, meta)
|
||||
|
||||
delete(user.PATs, tokenID)
|
||||
|
||||
err = am.Store.SaveAccount(account)
|
||||
if err != nil {
|
||||
return status.Errorf(status.Internal, "Failed to save account: %s", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetPAT returns a specific PAT from a user
|
||||
func (am *DefaultAccountManager) GetPAT(accountID string, executingUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
if executingUserID != targetUserID {
|
||||
return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this user")
|
||||
}
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(status.NotFound, "account not found: %s", err)
|
||||
}
|
||||
|
||||
user := account.Users[targetUserID]
|
||||
if user == nil {
|
||||
return nil, status.Errorf(status.NotFound, "user not found")
|
||||
}
|
||||
|
||||
pat := user.PATs[tokenID]
|
||||
if pat == nil {
|
||||
return nil, status.Errorf(status.NotFound, "PAT not found")
|
||||
}
|
||||
|
||||
return pat, nil
|
||||
}
|
||||
|
||||
// GetAllPATs returns all PATs for a user
|
||||
func (am *DefaultAccountManager) GetAllPATs(accountID string, executingUserID string, targetUserID string) ([]*PersonalAccessToken, error) {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
if executingUserID != targetUserID {
|
||||
return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this user")
|
||||
}
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(status.NotFound, "account not found: %s", err)
|
||||
}
|
||||
|
||||
user := account.Users[targetUserID]
|
||||
if user == nil {
|
||||
return nil, status.Errorf(status.NotFound, "user not found")
|
||||
}
|
||||
|
||||
var pats []*PersonalAccessToken
|
||||
for _, pat := range user.PATs {
|
||||
pats = append(pats, pat)
|
||||
}
|
||||
|
||||
return pats, nil
|
||||
}
|
||||
|
||||
// SaveUser saves updates a given user. If the user doesn't exit it will throw status.NotFound error.
|
||||
// Only User.AutoGroups field is allowed to be updated for now.
|
||||
func (am *DefaultAccountManager) SaveUser(accountID, userID string, update *User) (*UserInfo, error) {
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user