mirror of
https://github.com/netbirdio/netbird.git
synced 2026-03-31 14:34:54 -04:00
Compare commits
9 Commits
main
...
feat/byod-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
de3cb06067 | ||
|
|
4fdc39c8f8 | ||
|
|
94149a9441 | ||
|
|
38fd73fad6 | ||
|
|
9dd76b5a07 | ||
|
|
0b5380a7dc | ||
|
|
177171e437 | ||
|
|
da57b0f276 | ||
|
|
26ba03f08e |
@@ -31,6 +31,7 @@ type store interface {
|
||||
|
||||
type proxyManager interface {
|
||||
GetActiveClusterAddresses(ctx context.Context) ([]string, error)
|
||||
GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error)
|
||||
}
|
||||
|
||||
type clusterCapabilities interface {
|
||||
@@ -79,8 +80,8 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d
|
||||
var ret []*domain.Domain
|
||||
|
||||
// Add connected proxy clusters as free domains.
|
||||
// The cluster address itself is the free domain base (e.g., "eu.proxy.netbird.io").
|
||||
allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
|
||||
// For BYOP accounts, only their own cluster is returned; otherwise shared clusters.
|
||||
allowList, err := m.getClusterAllowList(ctx, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses: %v", err)
|
||||
return nil, err
|
||||
@@ -134,8 +135,8 @@ func (m Manager) CreateDomain(ctx context.Context, accountID, userID, domainName
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
// Verify the target cluster is in the available clusters
|
||||
allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
|
||||
// Verify the target cluster is in the available clusters for this account
|
||||
allowList, err := m.getClusterAllowList(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get active proxy cluster addresses: %w", err)
|
||||
}
|
||||
@@ -281,7 +282,7 @@ func (m Manager) GetClusterDomains() []string {
|
||||
// For free domains (those ending with a known cluster suffix), the cluster is extracted from the domain.
|
||||
// For custom domains, the cluster is determined by checking the registered custom domain's target cluster.
|
||||
func (m Manager) DeriveClusterFromDomain(ctx context.Context, accountID, domain string) (string, error) {
|
||||
allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
|
||||
allowList, err := m.getClusterAllowList(ctx, accountID)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get active proxy cluster addresses: %w", err)
|
||||
}
|
||||
@@ -306,6 +307,17 @@ func (m Manager) DeriveClusterFromDomain(ctx context.Context, accountID, domain
|
||||
return "", fmt.Errorf("domain %s does not match any available proxy cluster", domain)
|
||||
}
|
||||
|
||||
func (m Manager) getClusterAllowList(ctx context.Context, accountID string) ([]string, error) {
|
||||
byopAddresses, err := m.proxyManager.GetActiveClusterAddressesForAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get BYOP cluster addresses: %w", err)
|
||||
}
|
||||
if len(byopAddresses) > 0 {
|
||||
return byopAddresses, nil
|
||||
}
|
||||
return m.proxyManager.GetActiveClusterAddresses(ctx)
|
||||
}
|
||||
|
||||
func extractClusterFromCustomDomains(serviceDomain string, customDomains []*domain.Domain) (string, bool) {
|
||||
bestCluster := ""
|
||||
bestLen := -1
|
||||
|
||||
@@ -0,0 +1,98 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type mockProxyManager struct {
|
||||
getActiveClusterAddressesFunc func(ctx context.Context) ([]string, error)
|
||||
getActiveClusterAddressesForAccountFunc func(ctx context.Context, accountID string) ([]string, error)
|
||||
}
|
||||
|
||||
func (m *mockProxyManager) GetActiveClusterAddresses(ctx context.Context) ([]string, error) {
|
||||
if m.getActiveClusterAddressesFunc != nil {
|
||||
return m.getActiveClusterAddressesFunc(ctx)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockProxyManager) GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) {
|
||||
if m.getActiveClusterAddressesForAccountFunc != nil {
|
||||
return m.getActiveClusterAddressesForAccountFunc(ctx, accountID)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func TestGetClusterAllowList_BYOPProxy(t *testing.T) {
|
||||
pm := &mockProxyManager{
|
||||
getActiveClusterAddressesForAccountFunc: func(_ context.Context, accID string) ([]string, error) {
|
||||
assert.Equal(t, "acc-123", accID)
|
||||
return []string{"byop.example.com"}, nil
|
||||
},
|
||||
getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
|
||||
t.Fatal("should not call GetActiveClusterAddresses when BYOP addresses exist")
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
|
||||
mgr := Manager{proxyManager: pm}
|
||||
result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"byop.example.com"}, result)
|
||||
}
|
||||
|
||||
func TestGetClusterAllowList_NoBYOP_FallbackToShared(t *testing.T) {
|
||||
pm := &mockProxyManager{
|
||||
getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) {
|
||||
return nil, nil
|
||||
},
|
||||
getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
|
||||
return []string{"eu.proxy.netbird.io", "us.proxy.netbird.io"}, nil
|
||||
},
|
||||
}
|
||||
|
||||
mgr := Manager{proxyManager: pm}
|
||||
result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"eu.proxy.netbird.io", "us.proxy.netbird.io"}, result)
|
||||
}
|
||||
|
||||
func TestGetClusterAllowList_BYOPError_ReturnsError(t *testing.T) {
|
||||
pm := &mockProxyManager{
|
||||
getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) {
|
||||
return nil, errors.New("db error")
|
||||
},
|
||||
getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
|
||||
t.Fatal("should not call GetActiveClusterAddresses when BYOP lookup fails")
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
|
||||
mgr := Manager{proxyManager: pm}
|
||||
result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
assert.Contains(t, err.Error(), "BYOP cluster addresses")
|
||||
}
|
||||
|
||||
func TestGetClusterAllowList_BYOPEmptySlice_FallbackToShared(t *testing.T) {
|
||||
pm := &mockProxyManager{
|
||||
getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) {
|
||||
return []string{}, nil
|
||||
},
|
||||
getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
|
||||
return []string{"eu.proxy.netbird.io"}, nil
|
||||
},
|
||||
}
|
||||
|
||||
mgr := Manager{proxyManager: pm}
|
||||
result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"eu.proxy.netbird.io"}, result)
|
||||
}
|
||||
|
||||
@@ -11,12 +11,17 @@ import (
|
||||
|
||||
// Manager defines the interface for proxy operations
|
||||
type Manager interface {
|
||||
Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string) error
|
||||
Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, accountID *string) error
|
||||
Disconnect(ctx context.Context, proxyID string) error
|
||||
Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error
|
||||
GetActiveClusterAddresses(ctx context.Context) ([]string, error)
|
||||
GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error)
|
||||
GetActiveClusters(ctx context.Context) ([]Cluster, error)
|
||||
CleanupStale(ctx context.Context, inactivityDuration time.Duration) error
|
||||
GetAccountProxy(ctx context.Context, accountID string) (*Proxy, error)
|
||||
CountAccountProxies(ctx context.Context, accountID string) (int64, error)
|
||||
IsClusterAddressAvailable(ctx context.Context, clusterAddress, accountID string) (bool, error)
|
||||
DeleteProxy(ctx context.Context, proxyID string) error
|
||||
}
|
||||
|
||||
// OIDCValidationConfig contains the OIDC configuration needed for token validation.
|
||||
|
||||
@@ -13,10 +13,16 @@ import (
|
||||
// store defines the interface for proxy persistence operations
|
||||
type store interface {
|
||||
SaveProxy(ctx context.Context, p *proxy.Proxy) error
|
||||
DisconnectProxy(ctx context.Context, proxyID string) error
|
||||
UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error
|
||||
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
|
||||
GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error)
|
||||
GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error)
|
||||
CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error
|
||||
GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error)
|
||||
CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error)
|
||||
IsClusterAddressConflicting(ctx context.Context, clusterAddress, accountID string) (bool, error)
|
||||
DeleteProxy(ctx context.Context, proxyID string) error
|
||||
}
|
||||
|
||||
// Manager handles all proxy operations
|
||||
@@ -39,15 +45,16 @@ func NewManager(store store, meter metric.Meter) (*Manager, error) {
|
||||
}
|
||||
|
||||
// Connect registers a new proxy connection in the database
|
||||
func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
|
||||
func (m *Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, accountID *string) error {
|
||||
now := time.Now()
|
||||
p := &proxy.Proxy{
|
||||
ID: proxyID,
|
||||
ClusterAddress: clusterAddress,
|
||||
IPAddress: ipAddress,
|
||||
AccountID: accountID,
|
||||
LastSeen: now,
|
||||
ConnectedAt: &now,
|
||||
Status: "connected",
|
||||
Status: proxy.StatusConnected,
|
||||
}
|
||||
|
||||
if err := m.store.SaveProxy(ctx, p); err != nil {
|
||||
@@ -65,16 +72,8 @@ func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress
|
||||
}
|
||||
|
||||
// Disconnect marks a proxy as disconnected in the database
|
||||
func (m Manager) Disconnect(ctx context.Context, proxyID string) error {
|
||||
now := time.Now()
|
||||
p := &proxy.Proxy{
|
||||
ID: proxyID,
|
||||
Status: "disconnected",
|
||||
DisconnectedAt: &now,
|
||||
LastSeen: now,
|
||||
}
|
||||
|
||||
if err := m.store.SaveProxy(ctx, p); err != nil {
|
||||
func (m *Manager) Disconnect(ctx context.Context, proxyID string) error {
|
||||
if err := m.store.DisconnectProxy(ctx, proxyID); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to disconnect proxy %s: %v", proxyID, err)
|
||||
return err
|
||||
}
|
||||
@@ -87,7 +86,7 @@ func (m Manager) Disconnect(ctx context.Context, proxyID string) error {
|
||||
}
|
||||
|
||||
// Heartbeat updates the proxy's last seen timestamp
|
||||
func (m Manager) Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
|
||||
func (m *Manager) Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
|
||||
if err := m.store.UpdateProxyHeartbeat(ctx, proxyID, clusterAddress, ipAddress); err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to update proxy %s heartbeat: %v", proxyID, err)
|
||||
return err
|
||||
@@ -99,7 +98,7 @@ func (m Manager) Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddre
|
||||
}
|
||||
|
||||
// GetActiveClusterAddresses returns all unique cluster addresses for active proxies
|
||||
func (m Manager) GetActiveClusterAddresses(ctx context.Context) ([]string, error) {
|
||||
func (m *Manager) GetActiveClusterAddresses(ctx context.Context) ([]string, error) {
|
||||
addresses, err := m.store.GetActiveProxyClusterAddresses(ctx)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses: %v", err)
|
||||
@@ -119,10 +118,44 @@ func (m Manager) GetActiveClusters(ctx context.Context) ([]proxy.Cluster, error)
|
||||
}
|
||||
|
||||
// CleanupStale removes proxies that haven't sent heartbeat in the specified duration
|
||||
func (m Manager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error {
|
||||
func (m *Manager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error {
|
||||
if err := m.store.CleanupStaleProxies(ctx, inactivityDuration); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to cleanup stale proxies: %v", err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) {
|
||||
addresses, err := m.store.GetActiveProxyClusterAddressesForAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses for account %s: %v", accountID, err)
|
||||
return nil, err
|
||||
}
|
||||
return addresses, nil
|
||||
}
|
||||
|
||||
func (m *Manager) GetAccountProxy(ctx context.Context, accountID string) (*proxy.Proxy, error) {
|
||||
return m.store.GetProxyByAccountID(ctx, accountID)
|
||||
}
|
||||
|
||||
func (m *Manager) CountAccountProxies(ctx context.Context, accountID string) (int64, error) {
|
||||
return m.store.CountProxiesByAccountID(ctx, accountID)
|
||||
}
|
||||
|
||||
func (m *Manager) IsClusterAddressAvailable(ctx context.Context, clusterAddress, accountID string) (bool, error) {
|
||||
conflicting, err := m.store.IsClusterAddressConflicting(ctx, clusterAddress, accountID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return !conflicting, nil
|
||||
}
|
||||
|
||||
func (m *Manager) DeleteProxy(ctx context.Context, proxyID string) error {
|
||||
if err := m.store.DeleteProxy(ctx, proxyID); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete proxy %s: %v", proxyID, err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,325 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.opentelemetry.io/otel/metric/noop"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
)
|
||||
|
||||
type mockStore struct {
|
||||
saveProxyFunc func(ctx context.Context, p *proxy.Proxy) error
|
||||
disconnectProxyFunc func(ctx context.Context, proxyID string) error
|
||||
updateProxyHeartbeatFunc func(ctx context.Context, proxyID, clusterAddress, ipAddress string) error
|
||||
getActiveProxyClusterAddressesFunc func(ctx context.Context) ([]string, error)
|
||||
getActiveProxyClusterAddressesForAccFunc func(ctx context.Context, accountID string) ([]string, error)
|
||||
cleanupStaleProxiesFunc func(ctx context.Context, d time.Duration) error
|
||||
getProxyByAccountIDFunc func(ctx context.Context, accountID string) (*proxy.Proxy, error)
|
||||
countProxiesByAccountIDFunc func(ctx context.Context, accountID string) (int64, error)
|
||||
isClusterAddressConflictingFunc func(ctx context.Context, clusterAddress, accountID string) (bool, error)
|
||||
deleteProxyFunc func(ctx context.Context, proxyID string) error
|
||||
}
|
||||
|
||||
func (m *mockStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error {
|
||||
if m.saveProxyFunc != nil {
|
||||
return m.saveProxyFunc(ctx, p)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (m *mockStore) DisconnectProxy(ctx context.Context, proxyID string) error {
|
||||
if m.disconnectProxyFunc != nil {
|
||||
return m.disconnectProxyFunc(ctx, proxyID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (m *mockStore) UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
|
||||
if m.updateProxyHeartbeatFunc != nil {
|
||||
return m.updateProxyHeartbeatFunc(ctx, proxyID, clusterAddress, ipAddress)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (m *mockStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) {
|
||||
if m.getActiveProxyClusterAddressesFunc != nil {
|
||||
return m.getActiveProxyClusterAddressesFunc(ctx)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockStore) GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) {
|
||||
if m.getActiveProxyClusterAddressesForAccFunc != nil {
|
||||
return m.getActiveProxyClusterAddressesForAccFunc(ctx, accountID)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockStore) GetActiveProxyClusters(_ context.Context) ([]proxy.Cluster, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockStore) CleanupStaleProxies(ctx context.Context, d time.Duration) error {
|
||||
if m.cleanupStaleProxiesFunc != nil {
|
||||
return m.cleanupStaleProxiesFunc(ctx, d)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (m *mockStore) GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error) {
|
||||
if m.getProxyByAccountIDFunc != nil {
|
||||
return m.getProxyByAccountIDFunc(ctx, accountID)
|
||||
}
|
||||
return nil, fmt.Errorf("proxy not found for account %s", accountID)
|
||||
}
|
||||
func (m *mockStore) CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error) {
|
||||
if m.countProxiesByAccountIDFunc != nil {
|
||||
return m.countProxiesByAccountIDFunc(ctx, accountID)
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
func (m *mockStore) IsClusterAddressConflicting(ctx context.Context, clusterAddress, accountID string) (bool, error) {
|
||||
if m.isClusterAddressConflictingFunc != nil {
|
||||
return m.isClusterAddressConflictingFunc(ctx, clusterAddress, accountID)
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
func (m *mockStore) DeleteProxy(ctx context.Context, proxyID string) error {
|
||||
if m.deleteProxyFunc != nil {
|
||||
return m.deleteProxyFunc(ctx, proxyID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func newTestManager(s store) *Manager {
|
||||
meter := noop.NewMeterProvider().Meter("test")
|
||||
m, err := NewManager(s, meter)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func TestConnect_WithAccountID(t *testing.T) {
|
||||
accountID := "acc-123"
|
||||
|
||||
var savedProxy *proxy.Proxy
|
||||
s := &mockStore{
|
||||
saveProxyFunc: func(_ context.Context, p *proxy.Proxy) error {
|
||||
savedProxy = p
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
mgr := newTestManager(s)
|
||||
err := mgr.Connect(context.Background(), "proxy-1", "cluster.example.com", "10.0.0.1", &accountID)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NotNil(t, savedProxy)
|
||||
assert.Equal(t, "proxy-1", savedProxy.ID)
|
||||
assert.Equal(t, "cluster.example.com", savedProxy.ClusterAddress)
|
||||
assert.Equal(t, "10.0.0.1", savedProxy.IPAddress)
|
||||
assert.Equal(t, &accountID, savedProxy.AccountID)
|
||||
assert.Equal(t, proxy.StatusConnected, savedProxy.Status)
|
||||
assert.NotNil(t, savedProxy.ConnectedAt)
|
||||
}
|
||||
|
||||
func TestConnect_WithoutAccountID(t *testing.T) {
|
||||
var savedProxy *proxy.Proxy
|
||||
s := &mockStore{
|
||||
saveProxyFunc: func(_ context.Context, p *proxy.Proxy) error {
|
||||
savedProxy = p
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
mgr := newTestManager(s)
|
||||
err := mgr.Connect(context.Background(), "proxy-1", "eu.proxy.netbird.io", "10.0.0.1", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NotNil(t, savedProxy)
|
||||
assert.Nil(t, savedProxy.AccountID)
|
||||
assert.Equal(t, proxy.StatusConnected, savedProxy.Status)
|
||||
}
|
||||
|
||||
func TestConnect_StoreError(t *testing.T) {
|
||||
s := &mockStore{
|
||||
saveProxyFunc: func(_ context.Context, _ *proxy.Proxy) error {
|
||||
return errors.New("db error")
|
||||
},
|
||||
}
|
||||
|
||||
mgr := newTestManager(s)
|
||||
err := mgr.Connect(context.Background(), "proxy-1", "cluster.example.com", "10.0.0.1", nil)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestIsClusterAddressAvailable(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
conflicting bool
|
||||
storeErr error
|
||||
wantResult bool
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "available - no conflict",
|
||||
conflicting: false,
|
||||
wantResult: true,
|
||||
},
|
||||
{
|
||||
name: "not available - conflict exists",
|
||||
conflicting: true,
|
||||
wantResult: false,
|
||||
},
|
||||
{
|
||||
name: "store error",
|
||||
storeErr: errors.New("db error"),
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := &mockStore{
|
||||
isClusterAddressConflictingFunc: func(_ context.Context, _, _ string) (bool, error) {
|
||||
return tt.conflicting, tt.storeErr
|
||||
},
|
||||
}
|
||||
|
||||
mgr := newTestManager(s)
|
||||
result, err := mgr.IsClusterAddressAvailable(context.Background(), "cluster.example.com", "acc-123")
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.wantResult, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCountAccountProxies(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
count int64
|
||||
storeErr error
|
||||
wantCount int64
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "no proxies",
|
||||
count: 0,
|
||||
wantCount: 0,
|
||||
},
|
||||
{
|
||||
name: "one proxy",
|
||||
count: 1,
|
||||
wantCount: 1,
|
||||
},
|
||||
{
|
||||
name: "store error",
|
||||
storeErr: errors.New("db error"),
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := &mockStore{
|
||||
countProxiesByAccountIDFunc: func(_ context.Context, _ string) (int64, error) {
|
||||
return tt.count, tt.storeErr
|
||||
},
|
||||
}
|
||||
|
||||
mgr := newTestManager(s)
|
||||
count, err := mgr.CountAccountProxies(context.Background(), "acc-123")
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.wantCount, count)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAccountProxy(t *testing.T) {
|
||||
accountID := "acc-123"
|
||||
|
||||
t.Run("found", func(t *testing.T) {
|
||||
expected := &proxy.Proxy{
|
||||
ID: "proxy-1",
|
||||
ClusterAddress: "byop.example.com",
|
||||
AccountID: &accountID,
|
||||
Status: proxy.StatusConnected,
|
||||
}
|
||||
s := &mockStore{
|
||||
getProxyByAccountIDFunc: func(_ context.Context, accID string) (*proxy.Proxy, error) {
|
||||
assert.Equal(t, accountID, accID)
|
||||
return expected, nil
|
||||
},
|
||||
}
|
||||
|
||||
mgr := newTestManager(s)
|
||||
p, err := mgr.GetAccountProxy(context.Background(), accountID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, expected, p)
|
||||
})
|
||||
|
||||
t.Run("not found", func(t *testing.T) {
|
||||
s := &mockStore{
|
||||
getProxyByAccountIDFunc: func(_ context.Context, _ string) (*proxy.Proxy, error) {
|
||||
return nil, errors.New("not found")
|
||||
},
|
||||
}
|
||||
|
||||
mgr := newTestManager(s)
|
||||
_, err := mgr.GetAccountProxy(context.Background(), accountID)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDeleteProxy(t *testing.T) {
|
||||
t.Run("success", func(t *testing.T) {
|
||||
var deletedID string
|
||||
s := &mockStore{
|
||||
deleteProxyFunc: func(_ context.Context, proxyID string) error {
|
||||
deletedID = proxyID
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
mgr := newTestManager(s)
|
||||
err := mgr.DeleteProxy(context.Background(), "proxy-1")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "proxy-1", deletedID)
|
||||
})
|
||||
|
||||
t.Run("store error", func(t *testing.T) {
|
||||
s := &mockStore{
|
||||
deleteProxyFunc: func(_ context.Context, _ string) error {
|
||||
return errors.New("db error")
|
||||
},
|
||||
}
|
||||
|
||||
mgr := newTestManager(s)
|
||||
err := mgr.DeleteProxy(context.Background(), "proxy-1")
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetActiveClusterAddressesForAccount(t *testing.T) {
|
||||
expected := []string{"byop.example.com"}
|
||||
s := &mockStore{
|
||||
getActiveProxyClusterAddressesForAccFunc: func(_ context.Context, accID string) ([]string, error) {
|
||||
assert.Equal(t, "acc-123", accID)
|
||||
return expected, nil
|
||||
},
|
||||
}
|
||||
|
||||
mgr := newTestManager(s)
|
||||
result, err := mgr.GetActiveClusterAddressesForAccount(context.Background(), "acc-123")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, expected, result)
|
||||
}
|
||||
@@ -51,17 +51,17 @@ func (mr *MockManagerMockRecorder) CleanupStale(ctx, inactivityDuration interfac
|
||||
}
|
||||
|
||||
// Connect mocks base method.
|
||||
func (m *MockManager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
|
||||
func (m *MockManager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, accountID *string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Connect", ctx, proxyID, clusterAddress, ipAddress)
|
||||
ret := m.ctrl.Call(m, "Connect", ctx, proxyID, clusterAddress, ipAddress, accountID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Connect indicates an expected call of Connect.
|
||||
func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, clusterAddress, ipAddress interface{}) *gomock.Call {
|
||||
func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, clusterAddress, ipAddress, accountID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, clusterAddress, ipAddress)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, clusterAddress, ipAddress, accountID)
|
||||
}
|
||||
|
||||
// Disconnect mocks base method.
|
||||
@@ -93,7 +93,19 @@ func (mr *MockManagerMockRecorder) GetActiveClusterAddresses(ctx interface{}) *g
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusterAddresses", reflect.TypeOf((*MockManager)(nil).GetActiveClusterAddresses), ctx)
|
||||
}
|
||||
|
||||
// GetActiveClusters mocks base method.
|
||||
func (m *MockManager) GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetActiveClusterAddressesForAccount", ctx, accountID)
|
||||
ret0, _ := ret[0].([]string)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
func (mr *MockManagerMockRecorder) GetActiveClusterAddressesForAccount(ctx, accountID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusterAddressesForAccount", reflect.TypeOf((*MockManager)(nil).GetActiveClusterAddressesForAccount), ctx, accountID)
|
||||
}
|
||||
|
||||
func (m *MockManager) GetActiveClusters(ctx context.Context) ([]Cluster, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetActiveClusters", ctx)
|
||||
@@ -122,6 +134,65 @@ func (mr *MockManagerMockRecorder) Heartbeat(ctx, proxyID, clusterAddress, ipAdd
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Heartbeat", reflect.TypeOf((*MockManager)(nil).Heartbeat), ctx, proxyID, clusterAddress, ipAddress)
|
||||
}
|
||||
|
||||
// GetAccountProxy mocks base method.
|
||||
func (m *MockManager) GetAccountProxy(ctx context.Context, accountID string) (*Proxy, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetAccountProxy", ctx, accountID)
|
||||
ret0, _ := ret[0].(*Proxy)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetAccountProxy indicates an expected call of GetAccountProxy.
|
||||
func (mr *MockManagerMockRecorder) GetAccountProxy(ctx, accountID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountProxy", reflect.TypeOf((*MockManager)(nil).GetAccountProxy), ctx, accountID)
|
||||
}
|
||||
|
||||
// CountAccountProxies mocks base method.
|
||||
func (m *MockManager) CountAccountProxies(ctx context.Context, accountID string) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "CountAccountProxies", ctx, accountID)
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// CountAccountProxies indicates an expected call of CountAccountProxies.
|
||||
func (mr *MockManagerMockRecorder) CountAccountProxies(ctx, accountID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountAccountProxies", reflect.TypeOf((*MockManager)(nil).CountAccountProxies), ctx, accountID)
|
||||
}
|
||||
|
||||
// IsClusterAddressAvailable mocks base method.
|
||||
func (m *MockManager) IsClusterAddressAvailable(ctx context.Context, clusterAddress, accountID string) (bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "IsClusterAddressAvailable", ctx, clusterAddress, accountID)
|
||||
ret0, _ := ret[0].(bool)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// IsClusterAddressAvailable indicates an expected call of IsClusterAddressAvailable.
|
||||
func (mr *MockManagerMockRecorder) IsClusterAddressAvailable(ctx, clusterAddress, accountID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsClusterAddressAvailable", reflect.TypeOf((*MockManager)(nil).IsClusterAddressAvailable), ctx, clusterAddress, accountID)
|
||||
}
|
||||
|
||||
// DeleteProxy mocks base method.
|
||||
func (m *MockManager) DeleteProxy(ctx context.Context, proxyID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteProxy", ctx, proxyID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteProxy indicates an expected call of DeleteProxy.
|
||||
func (mr *MockManagerMockRecorder) DeleteProxy(ctx, proxyID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteProxy", reflect.TypeOf((*MockManager)(nil).DeleteProxy), ctx, proxyID)
|
||||
}
|
||||
|
||||
// MockController is a mock of Controller interface.
|
||||
type MockController struct {
|
||||
ctrl *gomock.Controller
|
||||
|
||||
@@ -1,12 +1,23 @@
|
||||
package proxy
|
||||
|
||||
import "time"
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
)
|
||||
|
||||
var ErrAccountProxyAlreadyExists = errors.New("account already has a registered proxy")
|
||||
|
||||
const (
|
||||
StatusConnected = "connected"
|
||||
StatusDisconnected = "disconnected"
|
||||
)
|
||||
|
||||
// Proxy represents a reverse proxy instance
|
||||
type Proxy struct {
|
||||
ID string `gorm:"primaryKey;type:varchar(255)"`
|
||||
ClusterAddress string `gorm:"type:varchar(255);not null;index:idx_proxy_cluster_status"`
|
||||
IPAddress string `gorm:"type:varchar(45)"`
|
||||
AccountID *string `gorm:"type:varchar(255);uniqueIndex:idx_proxy_account_id_unique"`
|
||||
LastSeen time.Time `gorm:"not null;index:idx_proxy_last_seen"`
|
||||
ConnectedAt *time.Time
|
||||
DisconnectedAt *time.Time
|
||||
@@ -21,6 +32,7 @@ func (Proxy) TableName() string {
|
||||
|
||||
// Cluster represents a group of proxy nodes serving the same address.
|
||||
type Cluster struct {
|
||||
ID string
|
||||
Address string
|
||||
ConnectedProxies int
|
||||
}
|
||||
|
||||
195
management/internals/modules/reverseproxy/proxytoken/handler.go
Normal file
195
management/internals/modules/reverseproxy/proxytoken/handler.go
Normal file
@@ -0,0 +1,195 @@
|
||||
package proxytoken
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
type handler struct {
|
||||
store store.Store
|
||||
permissionsManager permissions.Manager
|
||||
}
|
||||
|
||||
func RegisterEndpoints(s store.Store, permissionsManager permissions.Manager, router *mux.Router) {
|
||||
h := &handler{store: s, permissionsManager: permissionsManager}
|
||||
router.HandleFunc("/reverse-proxies/proxy-tokens", h.listTokens).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/reverse-proxies/proxy-tokens", h.createToken).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/reverse-proxies/proxy-tokens/{tokenId}", h.revokeToken).Methods("DELETE", "OPTIONS")
|
||||
}
|
||||
|
||||
func (h *handler) createToken(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
ok, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Services, operations.Create)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("failed to validate permissions", http.StatusInternalServerError, w)
|
||||
return
|
||||
}
|
||||
if !ok {
|
||||
util.WriteErrorResponse("permission denied", http.StatusForbidden, w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.ProxyTokenRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
if req.Name == "" || len(req.Name) > 255 {
|
||||
util.WriteErrorResponse("name is required and must be at most 255 characters", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
var expiresIn time.Duration
|
||||
if req.ExpiresIn != nil {
|
||||
if *req.ExpiresIn < 0 {
|
||||
util.WriteErrorResponse("expires_in must be non-negative", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
if *req.ExpiresIn > 0 {
|
||||
expiresIn = time.Duration(*req.ExpiresIn) * time.Second
|
||||
}
|
||||
}
|
||||
|
||||
accountID := userAuth.AccountId
|
||||
generated, err := types.CreateNewProxyAccessToken(req.Name, expiresIn, &accountID, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("failed to generate token", http.StatusInternalServerError, w)
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.store.SaveProxyAccessToken(r.Context(), &generated.ProxyAccessToken); err != nil {
|
||||
util.WriteErrorResponse("failed to save token", http.StatusInternalServerError, w)
|
||||
return
|
||||
}
|
||||
|
||||
resp := toProxyTokenCreatedResponse(generated)
|
||||
util.WriteJSONObject(r.Context(), w, resp)
|
||||
}
|
||||
|
||||
func (h *handler) listTokens(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
ok, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Services, operations.Read)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("failed to validate permissions", http.StatusInternalServerError, w)
|
||||
return
|
||||
}
|
||||
if !ok {
|
||||
util.WriteErrorResponse("permission denied", http.StatusForbidden, w)
|
||||
return
|
||||
}
|
||||
|
||||
tokens, err := h.store.GetProxyAccessTokensByAccountID(r.Context(), store.LockingStrengthNone, userAuth.AccountId)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("failed to list tokens", http.StatusInternalServerError, w)
|
||||
return
|
||||
}
|
||||
|
||||
resp := make([]api.ProxyToken, 0, len(tokens))
|
||||
for _, token := range tokens {
|
||||
resp = append(resp, toProxyTokenResponse(token))
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, resp)
|
||||
}
|
||||
|
||||
func (h *handler) revokeToken(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
ok, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Services, operations.Delete)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("failed to validate permissions", http.StatusInternalServerError, w)
|
||||
return
|
||||
}
|
||||
if !ok {
|
||||
util.WriteErrorResponse("permission denied", http.StatusForbidden, w)
|
||||
return
|
||||
}
|
||||
|
||||
tokenID := mux.Vars(r)["tokenId"]
|
||||
if tokenID == "" {
|
||||
util.WriteErrorResponse("token ID is required", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
token, err := h.store.GetProxyAccessTokenByID(r.Context(), store.LockingStrengthNone, tokenID)
|
||||
if err != nil {
|
||||
if s, ok := status.FromError(err); ok && s.ErrorType == status.NotFound {
|
||||
util.WriteErrorResponse("token not found", http.StatusNotFound, w)
|
||||
} else {
|
||||
util.WriteErrorResponse("failed to retrieve token", http.StatusInternalServerError, w)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if token.AccountID == nil || *token.AccountID != userAuth.AccountId {
|
||||
util.WriteErrorResponse("token not found", http.StatusNotFound, w)
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.store.RevokeProxyAccessToken(r.Context(), tokenID); err != nil {
|
||||
util.WriteErrorResponse("failed to revoke token", http.StatusInternalServerError, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||
}
|
||||
|
||||
func toProxyTokenResponse(token *types.ProxyAccessToken) api.ProxyToken {
|
||||
resp := api.ProxyToken{
|
||||
Id: token.ID,
|
||||
Name: token.Name,
|
||||
Revoked: token.Revoked,
|
||||
}
|
||||
if !token.CreatedAt.IsZero() {
|
||||
resp.CreatedAt = token.CreatedAt
|
||||
}
|
||||
if token.ExpiresAt != nil {
|
||||
resp.ExpiresAt = token.ExpiresAt
|
||||
}
|
||||
if token.LastUsed != nil {
|
||||
resp.LastUsed = token.LastUsed
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
func toProxyTokenCreatedResponse(generated *types.ProxyAccessTokenGenerated) api.ProxyTokenCreated {
|
||||
base := toProxyTokenResponse(&generated.ProxyAccessToken)
|
||||
plainToken := string(generated.PlainToken)
|
||||
return api.ProxyTokenCreated{
|
||||
Id: base.Id,
|
||||
Name: base.Name,
|
||||
CreatedAt: base.CreatedAt,
|
||||
ExpiresAt: base.ExpiresAt,
|
||||
LastUsed: base.LastUsed,
|
||||
Revoked: base.Revoked,
|
||||
PlainToken: plainToken,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,275 @@
|
||||
package proxytoken
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
)
|
||||
|
||||
func authContext(accountID, userID string) context.Context {
|
||||
return nbcontext.SetUserAuthInContext(context.Background(), auth.UserAuth{
|
||||
AccountId: accountID,
|
||||
UserId: userID,
|
||||
})
|
||||
}
|
||||
|
||||
func TestCreateToken_AccountScoped(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
accountID := "acc-123"
|
||||
var savedToken *types.ProxyAccessToken
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
mockStore.EXPECT().SaveProxyAccessToken(gomock.Any(), gomock.Any()).DoAndReturn(
|
||||
func(_ context.Context, token *types.ProxyAccessToken) error {
|
||||
savedToken = token
|
||||
return nil
|
||||
},
|
||||
)
|
||||
|
||||
permsMgr := permissions.NewMockManager(ctrl)
|
||||
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), accountID, "user-1", modules.Services, operations.Create).Return(true, nil)
|
||||
|
||||
h := &handler{
|
||||
store: mockStore,
|
||||
permissionsManager: permsMgr,
|
||||
}
|
||||
|
||||
body := `{"name": "my-token"}`
|
||||
req := httptest.NewRequest("POST", "/reverse-proxies/proxy-tokens", bytes.NewBufferString(body))
|
||||
req = req.WithContext(authContext(accountID, "user-1"))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.createToken(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var resp api.ProxyTokenCreated
|
||||
require.NoError(t, json.NewDecoder(w.Body).Decode(&resp))
|
||||
|
||||
assert.NotEmpty(t, resp.PlainToken)
|
||||
assert.Equal(t, "my-token", resp.Name)
|
||||
assert.False(t, resp.Revoked)
|
||||
|
||||
require.NotNil(t, savedToken)
|
||||
require.NotNil(t, savedToken.AccountID)
|
||||
assert.Equal(t, accountID, *savedToken.AccountID)
|
||||
assert.Equal(t, "user-1", savedToken.CreatedBy)
|
||||
}
|
||||
|
||||
func TestCreateToken_WithExpiration(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
var savedToken *types.ProxyAccessToken
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
mockStore.EXPECT().SaveProxyAccessToken(gomock.Any(), gomock.Any()).DoAndReturn(
|
||||
func(_ context.Context, token *types.ProxyAccessToken) error {
|
||||
savedToken = token
|
||||
return nil
|
||||
},
|
||||
)
|
||||
|
||||
permsMgr := permissions.NewMockManager(ctrl)
|
||||
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Create).Return(true, nil)
|
||||
|
||||
h := &handler{
|
||||
store: mockStore,
|
||||
permissionsManager: permsMgr,
|
||||
}
|
||||
|
||||
body := `{"name": "expiring-token", "expires_in": 3600}`
|
||||
req := httptest.NewRequest("POST", "/reverse-proxies/proxy-tokens", bytes.NewBufferString(body))
|
||||
req = req.WithContext(authContext("acc-123", "user-1"))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.createToken(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
require.NotNil(t, savedToken)
|
||||
require.NotNil(t, savedToken.ExpiresAt)
|
||||
assert.True(t, savedToken.ExpiresAt.After(time.Now()))
|
||||
}
|
||||
|
||||
func TestCreateToken_EmptyName(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
permsMgr := permissions.NewMockManager(ctrl)
|
||||
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Create).Return(true, nil)
|
||||
|
||||
h := &handler{
|
||||
permissionsManager: permsMgr,
|
||||
}
|
||||
|
||||
body := `{"name": ""}`
|
||||
req := httptest.NewRequest("POST", "/reverse-proxies/proxy-tokens", bytes.NewBufferString(body))
|
||||
req = req.WithContext(authContext("acc-123", "user-1"))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.createToken(w, req)
|
||||
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||
}
|
||||
|
||||
func TestCreateToken_PermissionDenied(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
permsMgr := permissions.NewMockManager(ctrl)
|
||||
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Create).Return(false, nil)
|
||||
|
||||
h := &handler{
|
||||
permissionsManager: permsMgr,
|
||||
}
|
||||
|
||||
body := `{"name": "test"}`
|
||||
req := httptest.NewRequest("POST", "/reverse-proxies/proxy-tokens", bytes.NewBufferString(body))
|
||||
req = req.WithContext(authContext("acc-123", "user-1"))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.createToken(w, req)
|
||||
assert.Equal(t, http.StatusForbidden, w.Code)
|
||||
}
|
||||
|
||||
func TestListTokens(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
accountID := "acc-123"
|
||||
now := time.Now()
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
mockStore.EXPECT().GetProxyAccessTokensByAccountID(gomock.Any(), store.LockingStrengthNone, accountID).Return([]*types.ProxyAccessToken{
|
||||
{ID: "tok-1", Name: "token-1", AccountID: &accountID, CreatedAt: now, Revoked: false},
|
||||
{ID: "tok-2", Name: "token-2", AccountID: &accountID, CreatedAt: now, Revoked: true},
|
||||
}, nil)
|
||||
|
||||
permsMgr := permissions.NewMockManager(ctrl)
|
||||
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), accountID, "user-1", modules.Services, operations.Read).Return(true, nil)
|
||||
|
||||
h := &handler{
|
||||
store: mockStore,
|
||||
permissionsManager: permsMgr,
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "/reverse-proxies/proxy-tokens", nil)
|
||||
req = req.WithContext(authContext(accountID, "user-1"))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.listTokens(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var resp []api.ProxyToken
|
||||
require.NoError(t, json.NewDecoder(w.Body).Decode(&resp))
|
||||
require.Len(t, resp, 2)
|
||||
assert.Equal(t, "tok-1", resp[0].Id)
|
||||
assert.False(t, resp[0].Revoked)
|
||||
assert.Equal(t, "tok-2", resp[1].Id)
|
||||
assert.True(t, resp[1].Revoked)
|
||||
}
|
||||
|
||||
func TestRevokeToken_Success(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
accountID := "acc-123"
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
mockStore.EXPECT().GetProxyAccessTokenByID(gomock.Any(), store.LockingStrengthNone, "tok-1").Return(&types.ProxyAccessToken{
|
||||
ID: "tok-1",
|
||||
Name: "test-token",
|
||||
AccountID: &accountID,
|
||||
}, nil)
|
||||
mockStore.EXPECT().RevokeProxyAccessToken(gomock.Any(), "tok-1").Return(nil)
|
||||
|
||||
permsMgr := permissions.NewMockManager(ctrl)
|
||||
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), accountID, "user-1", modules.Services, operations.Delete).Return(true, nil)
|
||||
|
||||
h := &handler{
|
||||
store: mockStore,
|
||||
permissionsManager: permsMgr,
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("DELETE", "/reverse-proxies/proxy-tokens/tok-1", nil)
|
||||
req = req.WithContext(authContext(accountID, "user-1"))
|
||||
req = mux.SetURLVars(req, map[string]string{"tokenId": "tok-1"})
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.revokeToken(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
func TestRevokeToken_WrongAccount(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
otherAccount := "acc-other"
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
mockStore.EXPECT().GetProxyAccessTokenByID(gomock.Any(), store.LockingStrengthNone, "tok-1").Return(&types.ProxyAccessToken{
|
||||
ID: "tok-1",
|
||||
AccountID: &otherAccount,
|
||||
}, nil)
|
||||
|
||||
permsMgr := permissions.NewMockManager(ctrl)
|
||||
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Delete).Return(true, nil)
|
||||
|
||||
h := &handler{
|
||||
store: mockStore,
|
||||
permissionsManager: permsMgr,
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("DELETE", "/reverse-proxies/proxy-tokens/tok-1", nil)
|
||||
req = req.WithContext(authContext("acc-123", "user-1"))
|
||||
req = mux.SetURLVars(req, map[string]string{"tokenId": "tok-1"})
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.revokeToken(w, req)
|
||||
assert.Equal(t, http.StatusNotFound, w.Code)
|
||||
}
|
||||
|
||||
func TestRevokeToken_ManagementWideToken(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
mockStore.EXPECT().GetProxyAccessTokenByID(gomock.Any(), store.LockingStrengthNone, "tok-1").Return(&types.ProxyAccessToken{
|
||||
ID: "tok-1",
|
||||
AccountID: nil,
|
||||
}, nil)
|
||||
|
||||
permsMgr := permissions.NewMockManager(ctrl)
|
||||
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Delete).Return(true, nil)
|
||||
|
||||
h := &handler{
|
||||
store: mockStore,
|
||||
permissionsManager: permsMgr,
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("DELETE", "/reverse-proxies/proxy-tokens/tok-1", nil)
|
||||
req = req.WithContext(authContext("acc-123", "user-1"))
|
||||
req = mux.SetURLVars(req, map[string]string{"tokenId": "tok-1"})
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.revokeToken(w, req)
|
||||
assert.Equal(t, http.StatusNotFound, w.Code)
|
||||
}
|
||||
@@ -28,4 +28,5 @@ type Manager interface {
|
||||
RenewServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error
|
||||
StopServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error
|
||||
StartExposeReaper(ctx context.Context)
|
||||
GetServiceByDomain(ctx context.Context, domain string) (*Service, error)
|
||||
}
|
||||
|
||||
@@ -138,6 +138,21 @@ func (mr *MockManagerMockRecorder) GetAllServices(ctx, accountID, userID interfa
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllServices", reflect.TypeOf((*MockManager)(nil).GetAllServices), ctx, accountID, userID)
|
||||
}
|
||||
|
||||
// GetServiceByDomain mocks base method.
|
||||
func (m *MockManager) GetServiceByDomain(ctx context.Context, domain string) (*Service, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetServiceByDomain", ctx, domain)
|
||||
ret0, _ := ret[0].(*Service)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetServiceByDomain indicates an expected call of GetServiceByDomain.
|
||||
func (mr *MockManagerMockRecorder) GetServiceByDomain(ctx, domain interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceByDomain", reflect.TypeOf((*MockManager)(nil).GetServiceByDomain), ctx, domain)
|
||||
}
|
||||
|
||||
// GetGlobalServices mocks base method.
|
||||
func (m *MockManager) GetGlobalServices(ctx context.Context) ([]*Service, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
@@ -195,6 +195,7 @@ func (h *handler) getClusters(w http.ResponseWriter, r *http.Request) {
|
||||
apiClusters := make([]api.ProxyCluster, 0, len(clusters))
|
||||
for _, c := range clusters {
|
||||
apiClusters = append(apiClusters, api.ProxyCluster{
|
||||
Id: c.ID,
|
||||
Address: c.Address,
|
||||
ConnectedProxies: c.ConnectedProxies,
|
||||
})
|
||||
|
||||
@@ -924,6 +924,10 @@ func (m *Manager) GetAccountServices(ctx context.Context, accountID string) ([]*
|
||||
return services, nil
|
||||
}
|
||||
|
||||
func (m *Manager) GetServiceByDomain(ctx context.Context, domain string) (*service.Service, error) {
|
||||
return m.store.GetServiceByDomain(ctx, domain)
|
||||
}
|
||||
|
||||
func (m *Manager) GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error) {
|
||||
target, err := m.store.GetServiceTargetByTargetID(ctx, store.LockingStrengthNone, accountID, resourceID)
|
||||
if err != nil {
|
||||
|
||||
@@ -426,7 +426,7 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
pkceStore, err := nbgrpc.NewPKCEVerifierStore(context.Background(), 10*time.Minute, 10*time.Minute, 100)
|
||||
require.NoError(t, err)
|
||||
srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
|
||||
srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil)
|
||||
return srv
|
||||
}
|
||||
|
||||
@@ -707,7 +707,7 @@ func setupIntegrationTest(t *testing.T) (*Manager, store.Store) {
|
||||
require.NoError(t, err)
|
||||
pkceStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
|
||||
require.NoError(t, err)
|
||||
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
|
||||
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil)
|
||||
|
||||
proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
|
||||
require.NoError(t, err)
|
||||
@@ -1132,7 +1132,7 @@ func TestDeleteService_DeletesTargets(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
pkceStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
|
||||
require.NoError(t, err)
|
||||
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
|
||||
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil)
|
||||
|
||||
proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -169,7 +169,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
|
||||
|
||||
func (s *BaseServer) ReverseProxyGRPCServer() *nbgrpc.ProxyServiceServer {
|
||||
return Create(s, func() *nbgrpc.ProxyServiceServer {
|
||||
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.PKCEVerifierStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.ProxyManager())
|
||||
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.PKCEVerifierStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.ProxyManager(), s.Store())
|
||||
s.AfterInit(func(s *BaseServer) {
|
||||
proxyService.SetServiceManager(s.ServiceManager())
|
||||
proxyService.SetProxyController(s.ServiceProxyController())
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
@@ -47,6 +48,11 @@ type ProxyOIDCConfig struct {
|
||||
KeysLocation string
|
||||
}
|
||||
|
||||
// ProxyTokenChecker checks whether a proxy access token is still valid.
|
||||
type ProxyTokenChecker interface {
|
||||
IsProxyAccessTokenValid(ctx context.Context, tokenID string) (bool, error)
|
||||
}
|
||||
|
||||
// ProxyServiceServer implements the ProxyService gRPC server
|
||||
type ProxyServiceServer struct {
|
||||
proto.UnimplementedProxyServiceServer
|
||||
@@ -75,6 +81,9 @@ type ProxyServiceServer struct {
|
||||
// Store for one-time authentication tokens
|
||||
tokenStore *OneTimeTokenStore
|
||||
|
||||
// Checker for proxy access token validity
|
||||
tokenChecker ProxyTokenChecker
|
||||
|
||||
// OIDC configuration for proxy authentication
|
||||
oidcConfig ProxyOIDCConfig
|
||||
|
||||
@@ -90,6 +99,8 @@ const pkceVerifierTTL = 10 * time.Minute
|
||||
type proxyConnection struct {
|
||||
proxyID string
|
||||
address string
|
||||
accountID *string
|
||||
tokenID string
|
||||
capabilities *proto.ProxyCapabilities
|
||||
stream proto.ProxyService_GetMappingUpdateServer
|
||||
sendChan chan *proto.GetMappingUpdateResponse
|
||||
@@ -97,8 +108,19 @@ type proxyConnection struct {
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
func enforceAccountScope(ctx context.Context, requestAccountID string) error {
|
||||
token := GetProxyTokenFromContext(ctx)
|
||||
if token == nil || token.AccountID == nil {
|
||||
return nil
|
||||
}
|
||||
if requestAccountID == "" || *token.AccountID != requestAccountID {
|
||||
return status.Errorf(codes.PermissionDenied, "account-scoped token cannot access account %s", requestAccountID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewProxyServiceServer creates a new proxy service server.
|
||||
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, pkceStore *PKCEVerifierStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager) *ProxyServiceServer {
|
||||
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, pkceStore *PKCEVerifierStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager, tokenChecker ProxyTokenChecker) *ProxyServiceServer {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
s := &ProxyServiceServer{
|
||||
accessLogManager: accessLogMgr,
|
||||
@@ -108,6 +130,7 @@ func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeT
|
||||
peersManager: peersManager,
|
||||
usersManager: usersManager,
|
||||
proxyManager: proxyMgr,
|
||||
tokenChecker: tokenChecker,
|
||||
cancel: cancel,
|
||||
}
|
||||
go s.cleanupStaleProxies(ctx)
|
||||
@@ -166,10 +189,48 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
||||
return status.Errorf(codes.InvalidArgument, "proxy address is invalid")
|
||||
}
|
||||
|
||||
var accountID *string
|
||||
token := GetProxyTokenFromContext(ctx)
|
||||
if token != nil && token.AccountID != nil {
|
||||
accountID = token.AccountID
|
||||
|
||||
existingProxy, err := s.proxyManager.GetAccountProxy(ctx, *accountID)
|
||||
if err != nil {
|
||||
if s, ok := nbstatus.FromError(err); ok && s.ErrorType == nbstatus.NotFound {
|
||||
log.WithContext(ctx).Debugf("no existing BYOP proxy for account %s", *accountID)
|
||||
} else {
|
||||
return status.Errorf(codes.Internal, "failed to check existing proxy: %v", err)
|
||||
}
|
||||
}
|
||||
if existingProxy != nil && existingProxy.ID != proxyID {
|
||||
if existingProxy.Status == proxy.StatusConnected {
|
||||
return status.Errorf(codes.ResourceExhausted, "limit of 1 self-hosted proxy per account")
|
||||
}
|
||||
if err := s.proxyManager.DeleteProxy(ctx, existingProxy.ID); err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to cleanup disconnected proxy %s: %v", existingProxy.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
available, err := s.proxyManager.IsClusterAddressAvailable(ctx, proxyAddress, *accountID)
|
||||
if err != nil {
|
||||
return status.Errorf(codes.Internal, "check cluster address: %v", err)
|
||||
}
|
||||
if !available {
|
||||
return status.Errorf(codes.AlreadyExists, "cluster address %s is already in use", proxyAddress)
|
||||
}
|
||||
}
|
||||
|
||||
var tokenID string
|
||||
if token != nil {
|
||||
tokenID = token.ID
|
||||
}
|
||||
|
||||
connCtx, cancel := context.WithCancel(ctx)
|
||||
conn := &proxyConnection{
|
||||
proxyID: proxyID,
|
||||
address: proxyAddress,
|
||||
accountID: accountID,
|
||||
tokenID: tokenID,
|
||||
capabilities: req.GetCapabilities(),
|
||||
stream: stream,
|
||||
sendChan: make(chan *proto.GetMappingUpdateResponse, 100),
|
||||
@@ -177,20 +238,27 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
if err := s.proxyManager.Connect(ctx, proxyID, proxyAddress, peerInfo, accountID); err != nil {
|
||||
if accountID != nil {
|
||||
cancel()
|
||||
if errors.Is(err, proxy.ErrAccountProxyAlreadyExists) {
|
||||
return status.Errorf(codes.ResourceExhausted, "limit of 1 self-hosted proxy per account")
|
||||
}
|
||||
return status.Errorf(codes.Internal, "failed to register BYOP proxy: %v", err)
|
||||
}
|
||||
log.WithContext(ctx).Warnf("Failed to register proxy %s in database: %v", proxyID, err)
|
||||
}
|
||||
|
||||
s.connectedProxies.Store(proxyID, conn)
|
||||
if err := s.proxyController.RegisterProxyToCluster(ctx, conn.address, proxyID); err != nil {
|
||||
log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", proxyID, err)
|
||||
}
|
||||
|
||||
// Register proxy in database
|
||||
if err := s.proxyManager.Connect(ctx, proxyID, proxyAddress, peerInfo); err != nil {
|
||||
log.WithContext(ctx).Warnf("Failed to register proxy %s in database: %v", proxyID, err)
|
||||
}
|
||||
|
||||
log.WithFields(log.Fields{
|
||||
"proxy_id": proxyID,
|
||||
"address": proxyAddress,
|
||||
"cluster_addr": proxyAddress,
|
||||
"account_id": accountID,
|
||||
"total_proxies": len(s.GetConnectedProxies()),
|
||||
}).Info("Proxy registered in cluster")
|
||||
defer func() {
|
||||
@@ -215,7 +283,7 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
||||
go s.sender(conn, errChan)
|
||||
|
||||
// Start heartbeat goroutine
|
||||
go s.heartbeat(connCtx, proxyID, proxyAddress, peerInfo)
|
||||
go s.heartbeat(connCtx, conn, peerInfo)
|
||||
|
||||
select {
|
||||
case err := <-errChan:
|
||||
@@ -225,16 +293,28 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
||||
}
|
||||
}
|
||||
|
||||
// heartbeat updates the proxy's last_seen timestamp every minute
|
||||
func (s *ProxyServiceServer) heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) {
|
||||
func (s *ProxyServiceServer) heartbeat(ctx context.Context, conn *proxyConnection, ipAddress string) {
|
||||
ticker := time.NewTicker(1 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
if err := s.proxyManager.Heartbeat(ctx, proxyID, clusterAddress, ipAddress); err != nil {
|
||||
log.WithContext(ctx).Debugf("Failed to update proxy %s heartbeat: %v", proxyID, err)
|
||||
if err := s.proxyManager.Heartbeat(ctx, conn.proxyID, conn.address, ipAddress); err != nil {
|
||||
log.WithContext(ctx).Debugf("Failed to update proxy %s heartbeat: %v", conn.proxyID, err)
|
||||
}
|
||||
|
||||
if conn.tokenID != "" && s.tokenChecker != nil {
|
||||
valid, err := s.tokenChecker.IsProxyAccessTokenValid(ctx, conn.tokenID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to check token validity for proxy %s: %v", conn.proxyID, err)
|
||||
continue
|
||||
}
|
||||
if !valid {
|
||||
log.WithContext(ctx).Warnf("proxy %s token revoked or expired, disconnecting", conn.proxyID)
|
||||
conn.cancel()
|
||||
return
|
||||
}
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return
|
||||
@@ -242,8 +322,6 @@ func (s *ProxyServiceServer) heartbeat(ctx context.Context, proxyID, clusterAddr
|
||||
}
|
||||
}
|
||||
|
||||
// sendSnapshot sends the initial snapshot of services to the connecting proxy.
|
||||
// Only entries matching the proxy's cluster address are sent.
|
||||
func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnection) error {
|
||||
if !isProxyAddressValid(conn.address) {
|
||||
return fmt.Errorf("proxy address is invalid")
|
||||
@@ -276,7 +354,13 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) snapshotServiceMappings(ctx context.Context, conn *proxyConnection) ([]*proto.ProxyMapping, error) {
|
||||
services, err := s.serviceManager.GetGlobalServices(ctx)
|
||||
var services []*rpservice.Service
|
||||
var err error
|
||||
if conn.accountID != nil {
|
||||
services, err = s.serviceManager.GetAccountServices(ctx, *conn.accountID)
|
||||
} else {
|
||||
services, err = s.serviceManager.GetGlobalServices(ctx)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get services from store: %w", err)
|
||||
}
|
||||
@@ -302,8 +386,14 @@ func (s *ProxyServiceServer) snapshotServiceMappings(ctx context.Context, conn *
|
||||
return mappings, nil
|
||||
}
|
||||
|
||||
// isProxyAddressValid validates a proxy address
|
||||
// isProxyAddressValid validates a proxy address (domain name or IP address)
|
||||
func isProxyAddressValid(addr string) bool {
|
||||
if addr == "" {
|
||||
return false
|
||||
}
|
||||
if net.ParseIP(addr) != nil {
|
||||
return true
|
||||
}
|
||||
_, err := domain.ValidateDomains([]string{addr})
|
||||
return err == nil
|
||||
}
|
||||
@@ -327,6 +417,10 @@ func (s *ProxyServiceServer) sender(conn *proxyConnection, errChan chan<- error)
|
||||
func (s *ProxyServiceServer) SendAccessLog(ctx context.Context, req *proto.SendAccessLogRequest) (*proto.SendAccessLogResponse, error) {
|
||||
accessLog := req.GetLog()
|
||||
|
||||
if err := enforceAccountScope(ctx, accessLog.GetAccountId()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fields := log.Fields{
|
||||
"service_id": accessLog.GetServiceId(),
|
||||
"account_id": accessLog.GetAccountId(),
|
||||
@@ -364,11 +458,32 @@ func (s *ProxyServiceServer) SendAccessLog(ctx context.Context, req *proto.SendA
|
||||
// Management should call this when services are created/updated/removed.
|
||||
// For create/update operations a unique one-time auth token is generated per
|
||||
// proxy so that every replica can independently authenticate with management.
|
||||
// BYOP proxies only receive updates for their own account's services.
|
||||
func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateResponse) {
|
||||
log.Debugf("Broadcasting service update to all connected proxy servers")
|
||||
updateAccountIDs := make(map[string]struct{})
|
||||
for _, m := range update.Mapping {
|
||||
if m.AccountId != "" {
|
||||
updateAccountIDs[m.AccountId] = struct{}{}
|
||||
}
|
||||
}
|
||||
s.connectedProxies.Range(func(key, value interface{}) bool {
|
||||
conn := value.(*proxyConnection)
|
||||
resp := s.perProxyMessage(update, conn.proxyID)
|
||||
connUpdate := update
|
||||
if conn.accountID != nil && len(updateAccountIDs) > 0 {
|
||||
if _, ok := updateAccountIDs[*conn.accountID]; !ok {
|
||||
return true
|
||||
}
|
||||
filtered := filterMappingsForAccount(update.Mapping, *conn.accountID)
|
||||
if len(filtered) == 0 {
|
||||
return true
|
||||
}
|
||||
connUpdate = &proto.GetMappingUpdateResponse{
|
||||
Mapping: filtered,
|
||||
InitialSyncComplete: update.InitialSyncComplete,
|
||||
}
|
||||
}
|
||||
resp := s.perProxyMessage(connUpdate, conn.proxyID)
|
||||
if resp == nil {
|
||||
return true
|
||||
}
|
||||
@@ -382,6 +497,26 @@ func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateRes
|
||||
})
|
||||
}
|
||||
|
||||
// ForceDisconnect cancels the gRPC stream for a connected proxy, causing it to disconnect.
|
||||
func (s *ProxyServiceServer) ForceDisconnect(proxyID string) {
|
||||
if connVal, ok := s.connectedProxies.Load(proxyID); ok {
|
||||
conn := connVal.(*proxyConnection)
|
||||
conn.cancel()
|
||||
s.connectedProxies.Delete(proxyID)
|
||||
log.WithFields(log.Fields{"proxyID": proxyID}).Info("force disconnected proxy")
|
||||
}
|
||||
}
|
||||
|
||||
func filterMappingsForAccount(mappings []*proto.ProxyMapping, accountID string) []*proto.ProxyMapping {
|
||||
var filtered []*proto.ProxyMapping
|
||||
for _, m := range mappings {
|
||||
if m.AccountId == accountID {
|
||||
filtered = append(filtered, m)
|
||||
}
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
// GetConnectedProxies returns a list of connected proxy IDs
|
||||
func (s *ProxyServiceServer) GetConnectedProxies() []string {
|
||||
var proxies []string
|
||||
@@ -447,6 +582,9 @@ func (s *ProxyServiceServer) SendServiceUpdateToCluster(ctx context.Context, upd
|
||||
for _, proxyID := range proxyIDs {
|
||||
if connVal, ok := s.connectedProxies.Load(proxyID); ok {
|
||||
conn := connVal.(*proxyConnection)
|
||||
if conn.accountID != nil && update.AccountId != "" && *conn.accountID != update.AccountId {
|
||||
continue
|
||||
}
|
||||
msg := s.perProxyMessage(updateResponse, proxyID)
|
||||
if msg == nil {
|
||||
continue
|
||||
@@ -567,6 +705,10 @@ func (s *ProxyServiceServer) ClusterRequireSubdomain(clusterAddr string) *bool {
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) {
|
||||
if err := enforceAccountScope(ctx, req.GetAccountId()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
service, err := s.serviceManager.GetServiceByID(ctx, req.GetAccountId(), req.GetId())
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to get service from store: %v", err)
|
||||
@@ -686,6 +828,10 @@ func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authentic
|
||||
|
||||
// SendStatusUpdate handles status updates from proxy clients.
|
||||
func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.SendStatusUpdateRequest) (*proto.SendStatusUpdateResponse, error) {
|
||||
if err := enforceAccountScope(ctx, req.GetAccountId()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
accountID := req.GetAccountId()
|
||||
serviceID := req.GetServiceId()
|
||||
protoStatus := req.GetStatus()
|
||||
@@ -756,6 +902,10 @@ func protoStatusToInternal(protoStatus proto.ProxyStatus) rpservice.Status {
|
||||
|
||||
// CreateProxyPeer handles proxy peer creation with one-time token authentication
|
||||
func (s *ProxyServiceServer) CreateProxyPeer(ctx context.Context, req *proto.CreateProxyPeerRequest) (*proto.CreateProxyPeerResponse, error) {
|
||||
if err := enforceAccountScope(ctx, req.GetAccountId()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
serviceID := req.GetServiceId()
|
||||
accountID := req.GetAccountId()
|
||||
token := req.GetToken()
|
||||
@@ -810,6 +960,10 @@ func strPtr(s string) *string {
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) GetOIDCURL(ctx context.Context, req *proto.GetOIDCURLRequest) (*proto.GetOIDCURLResponse, error) {
|
||||
if err := enforceAccountScope(ctx, req.GetAccountId()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
redirectURL, err := url.Parse(req.GetRedirectUrl())
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "parse redirect url: %v", err)
|
||||
@@ -938,21 +1092,9 @@ func (s *ProxyServiceServer) ValidateState(state string) (verifier, redirectURL
|
||||
|
||||
// GenerateSessionToken creates a signed session JWT for the given domain and user.
|
||||
func (s *ProxyServiceServer) GenerateSessionToken(ctx context.Context, domain, userID string, method proxyauth.Method) (string, error) {
|
||||
// Find the service by domain to get its signing key
|
||||
services, err := s.serviceManager.GetGlobalServices(ctx)
|
||||
service, err := s.getServiceByDomain(ctx, domain)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("get services: %w", err)
|
||||
}
|
||||
|
||||
var service *rpservice.Service
|
||||
for _, svc := range services {
|
||||
if svc.Domain == domain {
|
||||
service = svc
|
||||
break
|
||||
}
|
||||
}
|
||||
if service == nil {
|
||||
return "", fmt.Errorf("service not found for domain: %s", domain)
|
||||
return "", fmt.Errorf("service not found for domain %s: %w", domain, err)
|
||||
}
|
||||
|
||||
if service.SessionPrivateKey == "" {
|
||||
@@ -1050,6 +1192,10 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
|
||||
}, nil
|
||||
}
|
||||
|
||||
if err := enforceAccountScope(ctx, service.AccountID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pubKeyBytes, err := base64.StdEncoding.DecodeString(service.SessionPublicKey)
|
||||
if err != nil {
|
||||
log.WithFields(log.Fields{
|
||||
@@ -1133,18 +1279,7 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) getServiceByDomain(ctx context.Context, domain string) (*rpservice.Service, error) {
|
||||
services, err := s.serviceManager.GetGlobalServices(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get services: %w", err)
|
||||
}
|
||||
|
||||
for _, service := range services {
|
||||
if service.Domain == domain {
|
||||
return service, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("service not found for domain: %s", domain)
|
||||
return s.serviceManager.GetServiceByDomain(ctx, domain)
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) checkGroupAccess(service *rpservice.Service, user *types.User) error {
|
||||
|
||||
29
management/internals/shared/grpc/proxy_address_test.go
Normal file
29
management/internals/shared/grpc/proxy_address_test.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestIsProxyAddressValid(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
addr string
|
||||
valid bool
|
||||
}{
|
||||
{name: "valid domain", addr: "eu.proxy.netbird.io", valid: true},
|
||||
{name: "valid subdomain", addr: "byop.proxy.example.com", valid: true},
|
||||
{name: "valid IPv4", addr: "10.0.0.1", valid: true},
|
||||
{name: "valid IPv4 public", addr: "203.0.113.10", valid: true},
|
||||
{name: "valid IPv6", addr: "::1", valid: true},
|
||||
{name: "valid IPv6 full", addr: "2001:db8::1", valid: true},
|
||||
{name: "empty string", addr: "", valid: false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equal(t, tt.valid, isProxyAddressValid(tt.addr))
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -153,9 +153,6 @@ func (i *proxyAuthInterceptor) doValidateProxyToken(ctx context.Context) (*types
|
||||
return nil, status.Errorf(codes.Unauthenticated, "invalid token")
|
||||
}
|
||||
|
||||
// TODO: Enforce AccountID scope for "bring your own proxy" feature.
|
||||
// Currently tokens are management-wide; AccountID field is reserved for future use.
|
||||
|
||||
if !token.IsValid() {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "token expired or revoked")
|
||||
}
|
||||
|
||||
@@ -91,6 +91,20 @@ func (m *mockReverseProxyManager) StopServiceFromPeer(_ context.Context, _, _, _
|
||||
|
||||
func (m *mockReverseProxyManager) StartExposeReaper(_ context.Context) {}
|
||||
|
||||
func (m *mockReverseProxyManager) GetServiceByDomain(_ context.Context, domain string) (*service.Service, error) {
|
||||
if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
for _, services := range m.proxiesByAccount {
|
||||
for _, svc := range services {
|
||||
if svc.Domain == domain {
|
||||
return svc, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil, errors.New("service not found for domain: " + domain)
|
||||
}
|
||||
|
||||
func (m *mockReverseProxyManager) GetActiveClusters(_ context.Context, _, _ string) ([]proxy.Cluster, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
@@ -11,8 +11,11 @@ import (
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc/codes"
|
||||
grpcstatus "google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
@@ -323,6 +326,58 @@ func TestValidateState_RejectsOldTwoPartFormat(t *testing.T) {
|
||||
assert.Contains(t, err.Error(), "invalid state format")
|
||||
}
|
||||
|
||||
func scopedCtx(accountID string) context.Context {
|
||||
token := &types.ProxyAccessToken{
|
||||
ID: "token-1",
|
||||
AccountID: &accountID,
|
||||
}
|
||||
return context.WithValue(context.Background(), ProxyTokenContextKey, token)
|
||||
}
|
||||
|
||||
func globalCtx() context.Context {
|
||||
token := &types.ProxyAccessToken{
|
||||
ID: "token-global",
|
||||
}
|
||||
return context.WithValue(context.Background(), ProxyTokenContextKey, token)
|
||||
}
|
||||
|
||||
func TestEnforceAccountScope_AllowsMatchingAccount(t *testing.T) {
|
||||
err := enforceAccountScope(scopedCtx("acc-1"), "acc-1")
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestEnforceAccountScope_BlocksMismatchedAccount(t *testing.T) {
|
||||
err := enforceAccountScope(scopedCtx("acc-1"), "acc-2")
|
||||
require.Error(t, err)
|
||||
st, ok := grpcstatus.FromError(err)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, codes.PermissionDenied, st.Code())
|
||||
}
|
||||
|
||||
func TestEnforceAccountScope_BlocksEmptyRequestAccountID(t *testing.T) {
|
||||
err := enforceAccountScope(scopedCtx("acc-1"), "")
|
||||
require.Error(t, err)
|
||||
st, ok := grpcstatus.FromError(err)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, codes.PermissionDenied, st.Code())
|
||||
}
|
||||
|
||||
func TestEnforceAccountScope_AllowsGlobalToken(t *testing.T) {
|
||||
err := enforceAccountScope(globalCtx(), "acc-1")
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = enforceAccountScope(globalCtx(), "acc-2")
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = enforceAccountScope(globalCtx(), "")
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestEnforceAccountScope_AllowsNoTokenInContext(t *testing.T) {
|
||||
err := enforceAccountScope(context.Background(), "acc-1")
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestValidateState_RejectsInvalidHMAC(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
|
||||
|
||||
@@ -45,7 +45,7 @@ func setupValidateSessionTest(t *testing.T) *validateSessionTestSetup {
|
||||
pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
|
||||
require.NoError(t, err)
|
||||
|
||||
proxyService := NewProxyServiceServer(nil, tokenStore, pkceStore, ProxyOIDCConfig{}, nil, usersManager, proxyManager)
|
||||
proxyService := NewProxyServiceServer(nil, tokenStore, pkceStore, ProxyOIDCConfig{}, nil, usersManager, proxyManager, nil)
|
||||
proxyService.SetServiceManager(serviceManager)
|
||||
|
||||
createTestProxies(t, ctx, testStore)
|
||||
@@ -321,13 +321,17 @@ func (m *testValidateSessionServiceManager) StopServiceFromPeer(_ context.Contex
|
||||
|
||||
func (m *testValidateSessionServiceManager) StartExposeReaper(_ context.Context) {}
|
||||
|
||||
func (m *testValidateSessionServiceManager) GetServiceByDomain(ctx context.Context, domain string) (*service.Service, error) {
|
||||
return m.store.GetServiceByDomain(ctx, domain)
|
||||
}
|
||||
|
||||
func (m *testValidateSessionServiceManager) GetActiveClusters(_ context.Context, _, _ string) ([]proxy.Cluster, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
type testValidateSessionProxyManager struct{}
|
||||
|
||||
func (m *testValidateSessionProxyManager) Connect(_ context.Context, _, _, _ string) error {
|
||||
func (m *testValidateSessionProxyManager) Connect(_ context.Context, _, _, _ string, _ *string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -335,7 +339,7 @@ func (m *testValidateSessionProxyManager) Disconnect(_ context.Context, _ string
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) Heartbeat(_ context.Context, _ string) error {
|
||||
func (m *testValidateSessionProxyManager) Heartbeat(_ context.Context, _, _, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -343,6 +347,10 @@ func (m *testValidateSessionProxyManager) GetActiveClusterAddresses(_ context.Co
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) GetActiveClusterAddressesForAccount(_ context.Context, _ string) ([]string, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) GetActiveClusters(_ context.Context) ([]proxy.Cluster, error) {
|
||||
return nil, nil
|
||||
}
|
||||
@@ -351,6 +359,22 @@ func (m *testValidateSessionProxyManager) CleanupStale(_ context.Context, _ time
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) GetAccountProxy(_ context.Context, _ string) (*proxy.Proxy, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) CountAccountProxies(_ context.Context, _ string) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) IsClusterAddressAvailable(_ context.Context, _, _ string) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) DeleteProxy(_ context.Context, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type testValidateSessionUsersManager struct {
|
||||
store store.Store
|
||||
}
|
||||
|
||||
@@ -3133,7 +3133,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil, proxyManager)
|
||||
proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil, proxyManager, nil)
|
||||
proxyController, err := proxymanager.NewGRPCController(proxyGrpcServer, noop.Meter{})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxytoken"
|
||||
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager"
|
||||
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
@@ -176,6 +177,9 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
|
||||
if serviceManager != nil && reverseProxyDomainManager != nil {
|
||||
reverseproxymanager.RegisterEndpoints(serviceManager, *reverseProxyDomainManager, reverseProxyAccessLogsManager, permissionsManager, router)
|
||||
}
|
||||
|
||||
proxytoken.RegisterEndpoints(accountManager.GetStore(), permissionsManager, router)
|
||||
|
||||
// Register OAuth callback handler for proxy authentication
|
||||
if proxyGRPCServer != nil {
|
||||
oauthHandler := proxy.NewAuthCallbackHandler(proxyGRPCServer, trustedHTTPProxies)
|
||||
|
||||
@@ -215,6 +215,7 @@ func setupAuthCallbackTest(t *testing.T) *testSetup {
|
||||
nil,
|
||||
usersManager,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
|
||||
proxyService.SetServiceManager(&testServiceManager{store: testStore})
|
||||
@@ -434,6 +435,10 @@ func (m *testServiceManager) StopServiceFromPeer(_ context.Context, _, _, _ stri
|
||||
|
||||
func (m *testServiceManager) StartExposeReaper(_ context.Context) {}
|
||||
|
||||
func (m *testServiceManager) GetServiceByDomain(ctx context.Context, domain string) (*service.Service, error) {
|
||||
return m.store.GetServiceByDomain(ctx, domain)
|
||||
}
|
||||
|
||||
func (m *testServiceManager) GetActiveClusters(_ context.Context, _, _ string) ([]nbproxy.Cluster, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
@@ -108,7 +108,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create proxy manager: %v", err)
|
||||
}
|
||||
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr)
|
||||
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr, nil)
|
||||
domainManager := manager.NewManager(store, proxyMgr, permissionsManager, am)
|
||||
serviceProxyController, err := proxymanager.NewGRPCController(proxyServiceServer, noopMeter)
|
||||
if err != nil {
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -4465,6 +4466,47 @@ func (s *SqlStore) RevokeProxyAccessToken(ctx context.Context, tokenID string) e
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetProxyAccessTokensByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.ProxyAccessToken, error) {
|
||||
tx := s.db.WithContext(ctx)
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var tokens []*types.ProxyAccessToken
|
||||
result := tx.Where("account_id = ?", accountID).Find(&tokens)
|
||||
if result.Error != nil {
|
||||
return nil, status.Errorf(status.Internal, "get proxy access tokens by account: %v", result.Error)
|
||||
}
|
||||
|
||||
return tokens, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) IsProxyAccessTokenValid(ctx context.Context, tokenID string) (bool, error) {
|
||||
token, err := s.GetProxyAccessTokenByID(ctx, LockingStrengthNone, tokenID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return token.IsValid(), nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetProxyAccessTokenByID(ctx context.Context, lockStrength LockingStrength, tokenID string) (*types.ProxyAccessToken, error) {
|
||||
tx := s.db.WithContext(ctx)
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var token types.ProxyAccessToken
|
||||
result := tx.Take(&token, idQueryCondition, tokenID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "proxy access token not found")
|
||||
}
|
||||
return nil, status.Errorf(status.Internal, "get proxy access token by ID: %v", result.Error)
|
||||
}
|
||||
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
// MarkProxyAccessTokenUsed updates the last used timestamp for a proxy access token.
|
||||
func (s *SqlStore) MarkProxyAccessTokenUsed(ctx context.Context, tokenID string) error {
|
||||
result := s.db.WithContext(ctx).Model(&types.ProxyAccessToken{}).
|
||||
@@ -5402,18 +5444,49 @@ func (s *SqlStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error {
|
||||
result := s.db.WithContext(ctx).Save(p)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to save proxy: %v", result.Error)
|
||||
if isUniqueConstraintError(result.Error) {
|
||||
return proxy.ErrAccountProxyAlreadyExists
|
||||
}
|
||||
return status.Errorf(status.Internal, "failed to save proxy")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateProxyHeartbeat updates the last_seen timestamp for a proxy or creates a new entry if it doesn't exist
|
||||
func isUniqueConstraintError(err error) bool {
|
||||
var pgErr *pgconn.PgError
|
||||
if errors.As(err, &pgErr) && pgErr.Code == "23505" {
|
||||
return true
|
||||
}
|
||||
errStr := err.Error()
|
||||
return strings.Contains(errStr, "UNIQUE constraint") ||
|
||||
strings.Contains(errStr, "duplicate key") ||
|
||||
strings.Contains(errStr, "Duplicate entry") ||
|
||||
strings.Contains(errStr, "Error 1062")
|
||||
}
|
||||
|
||||
func (s *SqlStore) DisconnectProxy(ctx context.Context, proxyID string) error {
|
||||
now := time.Now()
|
||||
result := s.db.WithContext(ctx).
|
||||
Model(&proxy.Proxy{}).
|
||||
Where("id = ?", proxyID).
|
||||
Updates(map[string]interface{}{
|
||||
"status": proxy.StatusDisconnected,
|
||||
"disconnected_at": now,
|
||||
"last_seen": now,
|
||||
})
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to disconnect proxy: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to disconnect proxy")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
|
||||
now := time.Now()
|
||||
|
||||
result := s.db.WithContext(ctx).
|
||||
Model(&proxy.Proxy{}).
|
||||
Where("id = ? AND status = ?", proxyID, "connected").
|
||||
Where("id = ? AND status = ?", proxyID, proxy.StatusConnected).
|
||||
Update("last_seen", now)
|
||||
|
||||
if result.Error != nil {
|
||||
@@ -5445,7 +5518,7 @@ func (s *SqlStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string
|
||||
|
||||
result := s.db.WithContext(ctx).
|
||||
Model(&proxy.Proxy{}).
|
||||
Where("status = ? AND last_seen > ?", "connected", time.Now().Add(-2*time.Minute)).
|
||||
Where("status = ? AND last_seen > ?", proxy.StatusConnected, time.Now().Add(-2*time.Minute)).
|
||||
Distinct("cluster_address").
|
||||
Pluck("cluster_address", &addresses)
|
||||
|
||||
@@ -5457,13 +5530,72 @@ func (s *SqlStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string
|
||||
return addresses, nil
|
||||
}
|
||||
|
||||
// GetActiveProxyClusters returns all active proxy clusters with their connected proxy count.
|
||||
func (s *SqlStore) GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) {
|
||||
var addresses []string
|
||||
|
||||
result := s.db.WithContext(ctx).
|
||||
Model(&proxy.Proxy{}).
|
||||
Where("account_id = ? AND status = ? AND last_seen > ?", accountID, proxy.StatusConnected, time.Now().Add(-2*time.Minute)).
|
||||
Distinct("cluster_address").
|
||||
Pluck("cluster_address", &addresses)
|
||||
|
||||
if result.Error != nil {
|
||||
return nil, status.Errorf(status.Internal, "failed to get active proxy cluster addresses for account")
|
||||
}
|
||||
|
||||
return addresses, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error) {
|
||||
var p proxy.Proxy
|
||||
result := s.db.WithContext(ctx).Where("account_id = ?", accountID).Take(&p)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "proxy not found for account")
|
||||
}
|
||||
return nil, status.Errorf(status.Internal, "get proxy by account ID: %v", result.Error)
|
||||
}
|
||||
return &p, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error) {
|
||||
var count int64
|
||||
result := s.db.WithContext(ctx).Model(&proxy.Proxy{}).Where("account_id = ?", accountID).Count(&count)
|
||||
if result.Error != nil {
|
||||
return 0, status.Errorf(status.Internal, "count proxies by account ID: %v", result.Error)
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) IsClusterAddressConflicting(ctx context.Context, clusterAddress, accountID string) (bool, error) {
|
||||
var count int64
|
||||
result := s.db.WithContext(ctx).
|
||||
Model(&proxy.Proxy{}).
|
||||
Where("cluster_address = ? AND (account_id IS NULL OR account_id != ?)", clusterAddress, accountID).
|
||||
Count(&count)
|
||||
if result.Error != nil {
|
||||
return false, status.Errorf(status.Internal, "check cluster address conflict: %v", result.Error)
|
||||
}
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) DeleteProxy(ctx context.Context, proxyID string) error {
|
||||
result := s.db.WithContext(ctx).Where(idQueryCondition, proxyID).Delete(&proxy.Proxy{})
|
||||
if result.Error != nil {
|
||||
return status.Errorf(status.Internal, "delete proxy: %v", result.Error)
|
||||
}
|
||||
if result.RowsAffected == 0 {
|
||||
return status.Errorf(status.NotFound, "proxy not found")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error) {
|
||||
var clusters []proxy.Cluster
|
||||
|
||||
result := s.db.Model(&proxy.Proxy{}).
|
||||
Select("cluster_address as address, COUNT(*) as connected_proxies").
|
||||
Where("status = ? AND last_seen > ?", "connected", time.Now().Add(-2*time.Minute)).
|
||||
Select("MIN(id) as id, cluster_address as address, COUNT(*) as connected_proxies").
|
||||
Where("status = ? AND last_seen > ?", proxy.StatusConnected, time.Now().Add(-2*time.Minute)).
|
||||
Group("cluster_address").
|
||||
Scan(&clusters)
|
||||
|
||||
|
||||
@@ -114,6 +114,9 @@ type Store interface {
|
||||
|
||||
GetProxyAccessTokenByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken types.HashedProxyToken) (*types.ProxyAccessToken, error)
|
||||
GetAllProxyAccessTokens(ctx context.Context, lockStrength LockingStrength) ([]*types.ProxyAccessToken, error)
|
||||
GetProxyAccessTokensByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.ProxyAccessToken, error)
|
||||
GetProxyAccessTokenByID(ctx context.Context, lockStrength LockingStrength, tokenID string) (*types.ProxyAccessToken, error)
|
||||
IsProxyAccessTokenValid(ctx context.Context, tokenID string) (bool, error)
|
||||
SaveProxyAccessToken(ctx context.Context, token *types.ProxyAccessToken) error
|
||||
RevokeProxyAccessToken(ctx context.Context, tokenID string) error
|
||||
MarkProxyAccessTokenUsed(ctx context.Context, tokenID string) error
|
||||
@@ -284,10 +287,16 @@ type Store interface {
|
||||
DeleteServiceTargets(ctx context.Context, accountID string, serviceID string) error
|
||||
|
||||
SaveProxy(ctx context.Context, proxy *proxy.Proxy) error
|
||||
DisconnectProxy(ctx context.Context, proxyID string) error
|
||||
UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error
|
||||
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
|
||||
GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error)
|
||||
GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error)
|
||||
CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error
|
||||
GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error)
|
||||
CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error)
|
||||
IsClusterAddressConflicting(ctx context.Context, clusterAddress, accountID string) (bool, error)
|
||||
DeleteProxy(ctx context.Context, proxyID string) error
|
||||
|
||||
GetCustomDomainsCounts(ctx context.Context) (total int64, validated int64, err error)
|
||||
}
|
||||
|
||||
@@ -165,6 +165,21 @@ func (mr *MockStoreMockRecorder) CleanupStaleProxies(ctx, inactivityDuration int
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanupStaleProxies", reflect.TypeOf((*MockStore)(nil).CleanupStaleProxies), ctx, inactivityDuration)
|
||||
}
|
||||
|
||||
// CountProxiesByAccountID mocks base method.
|
||||
func (m *MockStore) CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "CountProxiesByAccountID", ctx, accountID)
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// CountProxiesByAccountID indicates an expected call of CountProxiesByAccountID.
|
||||
func (mr *MockStoreMockRecorder) CountProxiesByAccountID(ctx, accountID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountProxiesByAccountID", reflect.TypeOf((*MockStore)(nil).CountProxiesByAccountID), ctx, accountID)
|
||||
}
|
||||
|
||||
// Close mocks base method.
|
||||
func (m *MockStore) Close(ctx context.Context) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -1287,7 +1302,19 @@ func (mr *MockStoreMockRecorder) GetActiveProxyClusterAddresses(ctx interface{})
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveProxyClusterAddresses", reflect.TypeOf((*MockStore)(nil).GetActiveProxyClusterAddresses), ctx)
|
||||
}
|
||||
|
||||
// GetActiveProxyClusters mocks base method.
|
||||
func (m *MockStore) GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetActiveProxyClusterAddressesForAccount", ctx, accountID)
|
||||
ret0, _ := ret[0].([]string)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
func (mr *MockStoreMockRecorder) GetActiveProxyClusterAddressesForAccount(ctx, accountID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveProxyClusterAddressesForAccount", reflect.TypeOf((*MockStore)(nil).GetActiveProxyClusterAddressesForAccount), ctx, accountID)
|
||||
}
|
||||
|
||||
func (m *MockStore) GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetActiveProxyClusters", ctx)
|
||||
@@ -1296,7 +1323,6 @@ func (m *MockStore) GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetActiveProxyClusters indicates an expected call of GetActiveProxyClusters.
|
||||
func (mr *MockStoreMockRecorder) GetActiveProxyClusters(ctx interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveProxyClusters", reflect.TypeOf((*MockStore)(nil).GetActiveProxyClusters), ctx)
|
||||
@@ -1346,6 +1372,51 @@ func (mr *MockStoreMockRecorder) GetAllProxyAccessTokens(ctx, lockStrength inter
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllProxyAccessTokens", reflect.TypeOf((*MockStore)(nil).GetAllProxyAccessTokens), ctx, lockStrength)
|
||||
}
|
||||
|
||||
// GetProxyAccessTokensByAccountID mocks base method.
|
||||
func (m *MockStore) GetProxyAccessTokensByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types2.ProxyAccessToken, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetProxyAccessTokensByAccountID", ctx, lockStrength, accountID)
|
||||
ret0, _ := ret[0].([]*types2.ProxyAccessToken)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetProxyAccessTokensByAccountID indicates an expected call of GetProxyAccessTokensByAccountID.
|
||||
func (mr *MockStoreMockRecorder) GetProxyAccessTokensByAccountID(ctx, lockStrength, accountID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxyAccessTokensByAccountID", reflect.TypeOf((*MockStore)(nil).GetProxyAccessTokensByAccountID), ctx, lockStrength, accountID)
|
||||
}
|
||||
|
||||
// GetProxyAccessTokenByID mocks base method.
|
||||
func (m *MockStore) GetProxyAccessTokenByID(ctx context.Context, lockStrength LockingStrength, tokenID string) (*types2.ProxyAccessToken, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetProxyAccessTokenByID", ctx, lockStrength, tokenID)
|
||||
ret0, _ := ret[0].(*types2.ProxyAccessToken)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetProxyAccessTokenByID indicates an expected call of GetProxyAccessTokenByID.
|
||||
func (mr *MockStoreMockRecorder) GetProxyAccessTokenByID(ctx, lockStrength, tokenID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxyAccessTokenByID", reflect.TypeOf((*MockStore)(nil).GetProxyAccessTokenByID), ctx, lockStrength, tokenID)
|
||||
}
|
||||
|
||||
// IsProxyAccessTokenValid mocks base method.
|
||||
func (m *MockStore) IsProxyAccessTokenValid(ctx context.Context, tokenID string) (bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "IsProxyAccessTokenValid", ctx, tokenID)
|
||||
ret0, _ := ret[0].(bool)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// IsProxyAccessTokenValid indicates an expected call of IsProxyAccessTokenValid.
|
||||
func (mr *MockStoreMockRecorder) IsProxyAccessTokenValid(ctx, tokenID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsProxyAccessTokenValid", reflect.TypeOf((*MockStore)(nil).IsProxyAccessTokenValid), ctx, tokenID)
|
||||
}
|
||||
|
||||
// GetAnyAccountID mocks base method.
|
||||
func (m *MockStore) GetAnyAccountID(ctx context.Context) (string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -1916,6 +1987,50 @@ func (mr *MockStoreMockRecorder) GetProxyAccessTokenByHashedToken(ctx, lockStren
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxyAccessTokenByHashedToken", reflect.TypeOf((*MockStore)(nil).GetProxyAccessTokenByHashedToken), ctx, lockStrength, hashedToken)
|
||||
}
|
||||
|
||||
// GetProxyByAccountID mocks base method.
|
||||
func (m *MockStore) GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetProxyByAccountID", ctx, accountID)
|
||||
ret0, _ := ret[0].(*proxy.Proxy)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetProxyByAccountID indicates an expected call of GetProxyByAccountID.
|
||||
func (mr *MockStoreMockRecorder) GetProxyByAccountID(ctx, accountID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxyByAccountID", reflect.TypeOf((*MockStore)(nil).GetProxyByAccountID), ctx, accountID)
|
||||
}
|
||||
|
||||
// IsClusterAddressConflicting mocks base method.
|
||||
func (m *MockStore) IsClusterAddressConflicting(ctx context.Context, clusterAddress, accountID string) (bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "IsClusterAddressConflicting", ctx, clusterAddress, accountID)
|
||||
ret0, _ := ret[0].(bool)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// IsClusterAddressConflicting indicates an expected call of IsClusterAddressConflicting.
|
||||
func (mr *MockStoreMockRecorder) IsClusterAddressConflicting(ctx, clusterAddress, accountID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsClusterAddressConflicting", reflect.TypeOf((*MockStore)(nil).IsClusterAddressConflicting), ctx, clusterAddress, accountID)
|
||||
}
|
||||
|
||||
// DeleteProxy mocks base method.
|
||||
func (m *MockStore) DeleteProxy(ctx context.Context, proxyID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteProxy", ctx, proxyID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteProxy indicates an expected call of DeleteProxy.
|
||||
func (mr *MockStoreMockRecorder) DeleteProxy(ctx, proxyID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteProxy", reflect.TypeOf((*MockStore)(nil).DeleteProxy), ctx, proxyID)
|
||||
}
|
||||
|
||||
// GetResourceGroups mocks base method.
|
||||
func (m *MockStore) GetResourceGroups(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) ([]*types2.Group, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -2743,6 +2858,20 @@ func (mr *MockStoreMockRecorder) SaveProxy(ctx, proxy interface{}) *gomock.Call
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveProxy", reflect.TypeOf((*MockStore)(nil).SaveProxy), ctx, proxy)
|
||||
}
|
||||
|
||||
// DisconnectProxy mocks base method.
|
||||
func (m *MockStore) DisconnectProxy(ctx context.Context, proxyID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DisconnectProxy", ctx, proxyID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DisconnectProxy indicates an expected call of DisconnectProxy.
|
||||
func (mr *MockStoreMockRecorder) DisconnectProxy(ctx, proxyID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisconnectProxy", reflect.TypeOf((*MockStore)(nil).DisconnectProxy), ctx, proxyID)
|
||||
}
|
||||
|
||||
// SaveProxyAccessToken mocks base method.
|
||||
func (m *MockStore) SaveProxyAccessToken(ctx context.Context, token *types2.ProxyAccessToken) error {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
408
proxy/management_byop_integration_test.go
Normal file
408
proxy/management_byop_integration_test.go
Normal file
@@ -0,0 +1,408 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.opentelemetry.io/otel/metric/noop"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/grpc/metadata"
|
||||
grpcstatus "google.golang.org/grpc/status"
|
||||
|
||||
proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/management/server/users"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
type byopTestSetup struct {
|
||||
store store.Store
|
||||
proxyService *nbgrpc.ProxyServiceServer
|
||||
grpcServer *grpc.Server
|
||||
grpcAddr string
|
||||
cleanup func()
|
||||
|
||||
accountA string
|
||||
accountB string
|
||||
accountAToken types.PlainProxyToken
|
||||
accountBToken types.PlainProxyToken
|
||||
accountACluster string
|
||||
accountBCluster string
|
||||
}
|
||||
|
||||
func setupBYOPIntegrationTest(t *testing.T) *byopTestSetup {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
|
||||
testStore, storeCleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
|
||||
require.NoError(t, err)
|
||||
|
||||
accountAID := "byop-account-a"
|
||||
accountBID := "byop-account-b"
|
||||
|
||||
for _, acc := range []*types.Account{
|
||||
{Id: accountAID, Domain: "a.test.com", DomainCategory: "private", IsDomainPrimaryAccount: true, CreatedAt: time.Now()},
|
||||
{Id: accountBID, Domain: "b.test.com", DomainCategory: "private", IsDomainPrimaryAccount: true, CreatedAt: time.Now()},
|
||||
} {
|
||||
require.NoError(t, testStore.SaveAccount(ctx, acc))
|
||||
}
|
||||
|
||||
pub, priv, err := ed25519.GenerateKey(rand.Reader)
|
||||
require.NoError(t, err)
|
||||
pubKey := base64.StdEncoding.EncodeToString(pub)
|
||||
privKey := base64.StdEncoding.EncodeToString(priv)
|
||||
|
||||
clusterA := "byop-a.proxy.test"
|
||||
clusterB := "byop-b.proxy.test"
|
||||
|
||||
services := []*service.Service{
|
||||
{
|
||||
ID: "svc-a1", AccountID: accountAID, Name: "App A1",
|
||||
Domain: "app1." + clusterA, ProxyCluster: clusterA, Enabled: true,
|
||||
SessionPrivateKey: privKey, SessionPublicKey: pubKey,
|
||||
Targets: []*service.Target{{Path: strPtr("/"), Host: "10.0.0.1", Port: 8080, Protocol: "http", TargetId: "peer-a1", TargetType: "peer", Enabled: true}},
|
||||
},
|
||||
{
|
||||
ID: "svc-a2", AccountID: accountAID, Name: "App A2",
|
||||
Domain: "app2." + clusterA, ProxyCluster: clusterA, Enabled: true,
|
||||
SessionPrivateKey: privKey, SessionPublicKey: pubKey,
|
||||
Targets: []*service.Target{{Path: strPtr("/"), Host: "10.0.0.2", Port: 8080, Protocol: "http", TargetId: "peer-a2", TargetType: "peer", Enabled: true}},
|
||||
},
|
||||
{
|
||||
ID: "svc-b1", AccountID: accountBID, Name: "App B1",
|
||||
Domain: "app1." + clusterB, ProxyCluster: clusterB, Enabled: true,
|
||||
SessionPrivateKey: privKey, SessionPublicKey: pubKey,
|
||||
Targets: []*service.Target{{Path: strPtr("/"), Host: "10.0.0.3", Port: 8080, Protocol: "http", TargetId: "peer-b1", TargetType: "peer", Enabled: true}},
|
||||
},
|
||||
}
|
||||
for _, svc := range services {
|
||||
require.NoError(t, testStore.CreateService(ctx, svc))
|
||||
}
|
||||
|
||||
tokenA, err := types.CreateNewProxyAccessToken("byop-token-a", 0, &accountAID, "admin-a")
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, testStore.SaveProxyAccessToken(ctx, &tokenA.ProxyAccessToken))
|
||||
|
||||
tokenB, err := types.CreateNewProxyAccessToken("byop-token-b", 0, &accountBID, "admin-b")
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, testStore.SaveProxyAccessToken(ctx, &tokenB.ProxyAccessToken))
|
||||
|
||||
tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 5*time.Minute, 10*time.Minute, 100)
|
||||
require.NoError(t, err)
|
||||
pkceStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
|
||||
require.NoError(t, err)
|
||||
|
||||
meter := noop.NewMeterProvider().Meter("test")
|
||||
realProxyManager, err := proxymanager.NewManager(testStore, meter)
|
||||
require.NoError(t, err)
|
||||
|
||||
oidcConfig := nbgrpc.ProxyOIDCConfig{
|
||||
Issuer: "https://fake-issuer.example.com",
|
||||
ClientID: "test-client",
|
||||
HMACKey: []byte("test-hmac-key"),
|
||||
}
|
||||
|
||||
usersManager := users.NewManager(testStore)
|
||||
|
||||
proxyService := nbgrpc.NewProxyServiceServer(
|
||||
&testAccessLogManager{},
|
||||
tokenStore,
|
||||
pkceStore,
|
||||
oidcConfig,
|
||||
nil,
|
||||
usersManager,
|
||||
realProxyManager,
|
||||
nil,
|
||||
)
|
||||
|
||||
svcMgr := &storeBackedServiceManager{store: testStore, tokenStore: tokenStore}
|
||||
proxyService.SetServiceManager(svcMgr)
|
||||
|
||||
proxyController := &testProxyController{}
|
||||
proxyService.SetProxyController(proxyController)
|
||||
|
||||
_, streamInterceptor, authClose := nbgrpc.NewProxyAuthInterceptors(testStore)
|
||||
|
||||
lis, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
grpcServer := grpc.NewServer(grpc.StreamInterceptor(streamInterceptor))
|
||||
proto.RegisterProxyServiceServer(grpcServer, proxyService)
|
||||
|
||||
go func() {
|
||||
if err := grpcServer.Serve(lis); err != nil {
|
||||
t.Logf("gRPC server error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return &byopTestSetup{
|
||||
store: testStore,
|
||||
proxyService: proxyService,
|
||||
grpcServer: grpcServer,
|
||||
grpcAddr: lis.Addr().String(),
|
||||
cleanup: func() {
|
||||
grpcServer.GracefulStop()
|
||||
authClose()
|
||||
storeCleanup()
|
||||
},
|
||||
accountA: accountAID,
|
||||
accountB: accountBID,
|
||||
accountAToken: tokenA.PlainToken,
|
||||
accountBToken: tokenB.PlainToken,
|
||||
accountACluster: clusterA,
|
||||
accountBCluster: clusterB,
|
||||
}
|
||||
}
|
||||
|
||||
func byopContext(ctx context.Context, token types.PlainProxyToken) context.Context {
|
||||
md := metadata.Pairs("authorization", "Bearer "+string(token))
|
||||
return metadata.NewOutgoingContext(ctx, md)
|
||||
}
|
||||
|
||||
func receiveBYOPMappings(t *testing.T, stream proto.ProxyService_GetMappingUpdateClient) []*proto.ProxyMapping {
|
||||
t.Helper()
|
||||
var mappings []*proto.ProxyMapping
|
||||
for {
|
||||
msg, err := stream.Recv()
|
||||
require.NoError(t, err)
|
||||
mappings = append(mappings, msg.GetMapping()...)
|
||||
if msg.GetInitialSyncComplete() {
|
||||
break
|
||||
}
|
||||
}
|
||||
return mappings
|
||||
}
|
||||
|
||||
func TestIntegration_BYOPProxy_ReceivesOnlyAccountServices(t *testing.T) {
|
||||
setup := setupBYOPIntegrationTest(t)
|
||||
defer setup.cleanup()
|
||||
|
||||
conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
client := proto.NewProxyServiceClient(conn)
|
||||
|
||||
ctx, cancel := context.WithTimeout(byopContext(context.Background(), setup.accountAToken), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
stream, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{
|
||||
ProxyId: "byop-proxy-a",
|
||||
Version: "test-v1",
|
||||
Address: setup.accountACluster,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
mappings := receiveBYOPMappings(t, stream)
|
||||
|
||||
assert.Len(t, mappings, 2, "BYOP proxy should receive only account A's 2 services")
|
||||
for _, m := range mappings {
|
||||
assert.Equal(t, setup.accountA, m.GetAccountId(), "all mappings should belong to account A")
|
||||
t.Logf("received mapping: id=%s domain=%s account=%s", m.GetId(), m.GetDomain(), m.GetAccountId())
|
||||
}
|
||||
|
||||
ids := map[string]bool{}
|
||||
for _, m := range mappings {
|
||||
ids[m.GetId()] = true
|
||||
}
|
||||
assert.True(t, ids["svc-a1"], "should contain svc-a1")
|
||||
assert.True(t, ids["svc-a2"], "should contain svc-a2")
|
||||
assert.False(t, ids["svc-b1"], "should NOT contain account B's svc-b1")
|
||||
}
|
||||
|
||||
func TestIntegration_BYOPProxy_AccountBReceivesOnlyItsServices(t *testing.T) {
|
||||
setup := setupBYOPIntegrationTest(t)
|
||||
defer setup.cleanup()
|
||||
|
||||
conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
client := proto.NewProxyServiceClient(conn)
|
||||
|
||||
ctx, cancel := context.WithTimeout(byopContext(context.Background(), setup.accountBToken), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
stream, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{
|
||||
ProxyId: "byop-proxy-b",
|
||||
Version: "test-v1",
|
||||
Address: setup.accountBCluster,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
mappings := receiveBYOPMappings(t, stream)
|
||||
|
||||
assert.Len(t, mappings, 1, "BYOP proxy B should receive only 1 service")
|
||||
assert.Equal(t, "svc-b1", mappings[0].GetId())
|
||||
assert.Equal(t, setup.accountB, mappings[0].GetAccountId())
|
||||
}
|
||||
|
||||
func TestIntegration_BYOPProxy_LimitOnePerAccount(t *testing.T) {
|
||||
setup := setupBYOPIntegrationTest(t)
|
||||
defer setup.cleanup()
|
||||
|
||||
conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
client := proto.NewProxyServiceClient(conn)
|
||||
|
||||
ctx1, cancel1 := context.WithTimeout(byopContext(context.Background(), setup.accountAToken), 5*time.Second)
|
||||
defer cancel1()
|
||||
|
||||
stream1, err := client.GetMappingUpdate(ctx1, &proto.GetMappingUpdateRequest{
|
||||
ProxyId: "byop-proxy-a-first",
|
||||
Version: "test-v1",
|
||||
Address: setup.accountACluster,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_ = receiveBYOPMappings(t, stream1)
|
||||
|
||||
ctx2, cancel2 := context.WithTimeout(byopContext(context.Background(), setup.accountAToken), 5*time.Second)
|
||||
defer cancel2()
|
||||
|
||||
stream2, err := client.GetMappingUpdate(ctx2, &proto.GetMappingUpdateRequest{
|
||||
ProxyId: "byop-proxy-a-second",
|
||||
Version: "test-v1",
|
||||
Address: setup.accountACluster,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = stream2.Recv()
|
||||
require.Error(t, err)
|
||||
|
||||
st, ok := grpcstatus.FromError(err)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, codes.ResourceExhausted, st.Code(), "second BYOP proxy should be rejected with ResourceExhausted")
|
||||
t.Logf("expected rejection: %s", st.Message())
|
||||
}
|
||||
|
||||
func TestIntegration_BYOPProxy_ClusterAddressConflict(t *testing.T) {
|
||||
setup := setupBYOPIntegrationTest(t)
|
||||
defer setup.cleanup()
|
||||
|
||||
conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
client := proto.NewProxyServiceClient(conn)
|
||||
|
||||
ctx1, cancel1 := context.WithTimeout(byopContext(context.Background(), setup.accountAToken), 5*time.Second)
|
||||
defer cancel1()
|
||||
|
||||
stream1, err := client.GetMappingUpdate(ctx1, &proto.GetMappingUpdateRequest{
|
||||
ProxyId: "byop-proxy-a-cluster",
|
||||
Version: "test-v1",
|
||||
Address: setup.accountACluster,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_ = receiveBYOPMappings(t, stream1)
|
||||
|
||||
ctx2, cancel2 := context.WithTimeout(byopContext(context.Background(), setup.accountBToken), 5*time.Second)
|
||||
defer cancel2()
|
||||
|
||||
stream2, err := client.GetMappingUpdate(ctx2, &proto.GetMappingUpdateRequest{
|
||||
ProxyId: "byop-proxy-b-conflict",
|
||||
Version: "test-v1",
|
||||
Address: setup.accountACluster,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = stream2.Recv()
|
||||
require.Error(t, err)
|
||||
|
||||
st, ok := grpcstatus.FromError(err)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, codes.AlreadyExists, st.Code(), "cluster address conflict should return AlreadyExists")
|
||||
t.Logf("expected rejection: %s", st.Message())
|
||||
}
|
||||
|
||||
func TestIntegration_BYOPProxy_SameProxyReconnects(t *testing.T) {
|
||||
setup := setupBYOPIntegrationTest(t)
|
||||
defer setup.cleanup()
|
||||
|
||||
conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
client := proto.NewProxyServiceClient(conn)
|
||||
|
||||
proxyID := "byop-proxy-reconnect"
|
||||
|
||||
ctx1, cancel1 := context.WithTimeout(byopContext(context.Background(), setup.accountAToken), 5*time.Second)
|
||||
stream1, err := client.GetMappingUpdate(ctx1, &proto.GetMappingUpdateRequest{
|
||||
ProxyId: proxyID,
|
||||
Version: "test-v1",
|
||||
Address: setup.accountACluster,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
firstMappings := receiveBYOPMappings(t, stream1)
|
||||
cancel1()
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
ctx2, cancel2 := context.WithTimeout(byopContext(context.Background(), setup.accountAToken), 5*time.Second)
|
||||
defer cancel2()
|
||||
|
||||
stream2, err := client.GetMappingUpdate(ctx2, &proto.GetMappingUpdateRequest{
|
||||
ProxyId: proxyID,
|
||||
Version: "test-v1",
|
||||
Address: setup.accountACluster,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
secondMappings := receiveBYOPMappings(t, stream2)
|
||||
|
||||
assert.Equal(t, len(firstMappings), len(secondMappings), "reconnect should receive same mappings")
|
||||
|
||||
firstIDs := map[string]bool{}
|
||||
for _, m := range firstMappings {
|
||||
firstIDs[m.GetId()] = true
|
||||
}
|
||||
for _, m := range secondMappings {
|
||||
assert.True(t, firstIDs[m.GetId()], "mapping %s should be present on reconnect", m.GetId())
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegration_BYOPProxy_UnauthenticatedRejected(t *testing.T) {
|
||||
setup := setupBYOPIntegrationTest(t)
|
||||
defer setup.cleanup()
|
||||
|
||||
conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
client := proto.NewProxyServiceClient(conn)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
stream, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{
|
||||
ProxyId: "no-auth-proxy",
|
||||
Version: "test-v1",
|
||||
Address: "some.cluster.io",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = stream.Recv()
|
||||
require.Error(t, err)
|
||||
|
||||
st, ok := grpcstatus.FromError(err)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, codes.Unauthenticated, st.Code())
|
||||
}
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
@@ -139,6 +140,7 @@ func setupIntegrationTest(t *testing.T) *integrationTestSetup {
|
||||
nil,
|
||||
usersManager,
|
||||
proxyManager,
|
||||
nil,
|
||||
)
|
||||
|
||||
// Use store-backed service manager
|
||||
@@ -200,7 +202,7 @@ func (m *testAccessLogManager) GetAllAccessLogs(_ context.Context, _, _ string,
|
||||
// testProxyManager is a mock implementation of proxy.Manager for testing.
|
||||
type testProxyManager struct{}
|
||||
|
||||
func (m *testProxyManager) Connect(_ context.Context, _, _, _ string) error {
|
||||
func (m *testProxyManager) Connect(_ context.Context, _, _, _ string, _ *string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -216,6 +218,10 @@ func (m *testProxyManager) GetActiveClusterAddresses(_ context.Context) ([]strin
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *testProxyManager) GetActiveClusterAddressesForAccount(_ context.Context, _ string) ([]string, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *testProxyManager) GetActiveClusters(_ context.Context) ([]nbproxy.Cluster, error) {
|
||||
return nil, nil
|
||||
}
|
||||
@@ -224,6 +230,22 @@ func (m *testProxyManager) CleanupStale(_ context.Context, _ time.Duration) erro
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testProxyManager) GetAccountProxy(_ context.Context, accountID string) (*nbproxy.Proxy, error) {
|
||||
return nil, fmt.Errorf("proxy not found for account %s", accountID)
|
||||
}
|
||||
|
||||
func (m *testProxyManager) CountAccountProxies(_ context.Context, _ string) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *testProxyManager) IsClusterAddressAvailable(_ context.Context, _, _ string) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (m *testProxyManager) DeleteProxy(_ context.Context, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// testProxyController is a mock implementation of rpservice.ProxyController for testing.
|
||||
type testProxyController struct{}
|
||||
|
||||
@@ -331,6 +353,10 @@ func (m *storeBackedServiceManager) StopServiceFromPeer(_ context.Context, _, _,
|
||||
|
||||
func (m *storeBackedServiceManager) StartExposeReaper(_ context.Context) {}
|
||||
|
||||
func (m *storeBackedServiceManager) GetServiceByDomain(ctx context.Context, domain string) (*service.Service, error) {
|
||||
return m.store.GetServiceByDomain(ctx, domain)
|
||||
}
|
||||
|
||||
func (m *storeBackedServiceManager) GetActiveClusters(_ context.Context, _, _ string) ([]nbproxy.Cluster, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
@@ -3289,10 +3289,64 @@ components:
|
||||
example: false
|
||||
required:
|
||||
- enabled
|
||||
ProxyTokenRequest:
|
||||
type: object
|
||||
properties:
|
||||
name:
|
||||
type: string
|
||||
description: Human-readable token name
|
||||
example: "my-proxy-token"
|
||||
expires_in:
|
||||
type: integer
|
||||
minimum: 0
|
||||
description: Token expiration in seconds (0 = never expires)
|
||||
example: 0
|
||||
required:
|
||||
- name
|
||||
ProxyToken:
|
||||
type: object
|
||||
properties:
|
||||
id:
|
||||
type: string
|
||||
name:
|
||||
type: string
|
||||
expires_at:
|
||||
type: string
|
||||
format: date-time
|
||||
created_at:
|
||||
type: string
|
||||
format: date-time
|
||||
last_used:
|
||||
type: string
|
||||
format: date-time
|
||||
revoked:
|
||||
type: boolean
|
||||
required:
|
||||
- id
|
||||
- name
|
||||
- created_at
|
||||
- revoked
|
||||
ProxyTokenCreated:
|
||||
type: object
|
||||
description: Returned on creation — plain_token is shown only once
|
||||
allOf:
|
||||
- $ref: '#/components/schemas/ProxyToken'
|
||||
- type: object
|
||||
properties:
|
||||
plain_token:
|
||||
type: string
|
||||
description: The plain text token (shown only once)
|
||||
example: "nbx_abc123..."
|
||||
required:
|
||||
- plain_token
|
||||
ProxyCluster:
|
||||
type: object
|
||||
description: A proxy cluster represents a group of proxy nodes serving the same address
|
||||
properties:
|
||||
id:
|
||||
type: string
|
||||
description: Unique identifier of a proxy in this cluster
|
||||
example: "chlfq4q5r8kc73b0qjpg"
|
||||
address:
|
||||
type: string
|
||||
description: Cluster address used for CNAME targets
|
||||
@@ -3301,9 +3355,15 @@ components:
|
||||
type: integer
|
||||
description: Number of proxy nodes connected in this cluster
|
||||
example: 3
|
||||
self_hosted:
|
||||
type: boolean
|
||||
description: Whether this cluster is a self-hosted (BYOP) proxy managed by the account owner
|
||||
example: false
|
||||
required:
|
||||
- id
|
||||
- address
|
||||
- connected_proxies
|
||||
- self_hosted
|
||||
ReverseProxyDomainType:
|
||||
type: string
|
||||
description: Type of Reverse Proxy Domain
|
||||
@@ -9798,6 +9858,111 @@ paths:
|
||||
"$ref": "#/components/responses/forbidden"
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
/api/reverse-proxies/clusters/{clusterId}:
|
||||
delete:
|
||||
summary: Delete a self-hosted proxy cluster
|
||||
description: Removes a self-hosted (BYOP) proxy cluster and disconnects it. Only self-hosted clusters can be deleted.
|
||||
tags: [ Services ]
|
||||
security:
|
||||
- BearerAuth: [ ]
|
||||
- TokenAuth: [ ]
|
||||
parameters:
|
||||
- in: path
|
||||
name: clusterId
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
description: The unique identifier of the proxy cluster
|
||||
responses:
|
||||
'200':
|
||||
description: Proxy cluster deleted successfully
|
||||
content: { }
|
||||
'400':
|
||||
"$ref": "#/components/responses/bad_request"
|
||||
'401':
|
||||
"$ref": "#/components/responses/requires_authentication"
|
||||
'403':
|
||||
"$ref": "#/components/responses/forbidden"
|
||||
'404':
|
||||
"$ref": "#/components/responses/not_found"
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
/api/reverse-proxies/proxy-tokens:
|
||||
get:
|
||||
summary: List Proxy Tokens
|
||||
description: Returns all proxy access tokens for the account
|
||||
tags: [ Self-Hosted Proxies ]
|
||||
security:
|
||||
- BearerAuth: [ ]
|
||||
- TokenAuth: [ ]
|
||||
responses:
|
||||
'200':
|
||||
description: A JSON Array of proxy tokens
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/ProxyToken'
|
||||
'401':
|
||||
"$ref": "#/components/responses/requires_authentication"
|
||||
'403':
|
||||
"$ref": "#/components/responses/forbidden"
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
post:
|
||||
summary: Create a Proxy Token
|
||||
description: Generate an account-scoped proxy access token for self-hosted proxy registration
|
||||
tags: [ Self-Hosted Proxies ]
|
||||
security:
|
||||
- BearerAuth: [ ]
|
||||
- TokenAuth: [ ]
|
||||
requestBody:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ProxyTokenRequest'
|
||||
responses:
|
||||
'200':
|
||||
description: Proxy token created (plain token shown once)
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ProxyTokenCreated'
|
||||
'400':
|
||||
"$ref": "#/components/responses/bad_request"
|
||||
'401':
|
||||
"$ref": "#/components/responses/requires_authentication"
|
||||
'403':
|
||||
"$ref": "#/components/responses/forbidden"
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
/api/reverse-proxies/proxy-tokens/{tokenId}:
|
||||
delete:
|
||||
summary: Revoke a Proxy Token
|
||||
description: Revoke an account-scoped proxy access token
|
||||
tags: [ Self-Hosted Proxies ]
|
||||
security:
|
||||
- BearerAuth: [ ]
|
||||
- TokenAuth: [ ]
|
||||
parameters:
|
||||
- in: path
|
||||
name: tokenId
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
description: The unique identifier of the proxy token
|
||||
responses:
|
||||
'200':
|
||||
description: Token revoked
|
||||
'401':
|
||||
"$ref": "#/components/responses/requires_authentication"
|
||||
'403':
|
||||
"$ref": "#/components/responses/forbidden"
|
||||
'404':
|
||||
"$ref": "#/components/responses/not_found"
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
/api/reverse-proxies/services:
|
||||
get:
|
||||
summary: List all Services
|
||||
|
||||
@@ -3381,11 +3381,49 @@ type ProxyAccessLogsResponse struct {
|
||||
|
||||
// ProxyCluster A proxy cluster represents a group of proxy nodes serving the same address
|
||||
type ProxyCluster struct {
|
||||
// Id Unique identifier of a proxy in this cluster
|
||||
Id string `json:"id"`
|
||||
|
||||
// Address Cluster address used for CNAME targets
|
||||
Address string `json:"address"`
|
||||
|
||||
// ConnectedProxies Number of proxy nodes connected in this cluster
|
||||
ConnectedProxies int `json:"connected_proxies"`
|
||||
|
||||
// SelfHosted Whether this cluster is a self-hosted (BYOP) proxy managed by the account owner
|
||||
SelfHosted bool `json:"self_hosted"`
|
||||
}
|
||||
|
||||
// ProxyToken defines model for ProxyToken.
|
||||
type ProxyToken struct {
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
ExpiresAt *time.Time `json:"expires_at,omitempty"`
|
||||
Id string `json:"id"`
|
||||
LastUsed *time.Time `json:"last_used,omitempty"`
|
||||
Name string `json:"name"`
|
||||
Revoked bool `json:"revoked"`
|
||||
}
|
||||
|
||||
// ProxyTokenCreated defines model for ProxyTokenCreated.
|
||||
type ProxyTokenCreated struct {
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
ExpiresAt *time.Time `json:"expires_at,omitempty"`
|
||||
Id string `json:"id"`
|
||||
LastUsed *time.Time `json:"last_used,omitempty"`
|
||||
Name string `json:"name"`
|
||||
|
||||
// PlainToken The plain text token (shown only once)
|
||||
PlainToken string `json:"plain_token"`
|
||||
Revoked bool `json:"revoked"`
|
||||
}
|
||||
|
||||
// ProxyTokenRequest defines model for ProxyTokenRequest.
|
||||
type ProxyTokenRequest struct {
|
||||
// ExpiresIn Token expiration in seconds (0 = never expires)
|
||||
ExpiresIn *int `json:"expires_in,omitempty"`
|
||||
|
||||
// Name Human-readable token name
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
// Resource defines model for Resource.
|
||||
@@ -4618,6 +4656,9 @@ type PutApiPostureChecksPostureCheckIdJSONRequestBody = PostureCheckUpdate
|
||||
// PostApiReverseProxiesDomainsJSONRequestBody defines body for PostApiReverseProxiesDomains for application/json ContentType.
|
||||
type PostApiReverseProxiesDomainsJSONRequestBody = ReverseProxyDomainRequest
|
||||
|
||||
// PostApiReverseProxiesProxyTokensJSONRequestBody defines body for PostApiReverseProxiesProxyTokens for application/json ContentType.
|
||||
type PostApiReverseProxiesProxyTokensJSONRequestBody = ProxyTokenRequest
|
||||
|
||||
// PostApiReverseProxiesServicesJSONRequestBody defines body for PostApiReverseProxiesServices for application/json ContentType.
|
||||
type PostApiReverseProxiesServicesJSONRequestBody = ServiceRequest
|
||||
|
||||
|
||||
Reference in New Issue
Block a user