mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-21 23:51:55 -04:00
Compare commits
4 Commits
swap-dns-p
...
wasm-debug
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
063cbdc6d8 | ||
|
|
291e640b28 | ||
|
|
efb954b7d6 | ||
|
|
cac9326d3d |
@@ -10,6 +10,7 @@ import (
|
||||
"net/netip"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
wgnetstack "golang.zx2c4.com/wireguard/tun/netstack"
|
||||
@@ -30,6 +31,11 @@ var (
|
||||
ErrConfigNotInitialized = errors.New("config not initialized")
|
||||
)
|
||||
|
||||
const (
|
||||
defaultPeerConnectionTimeout = 60 * time.Second
|
||||
peerConnectionPollInterval = 500 * time.Millisecond
|
||||
)
|
||||
|
||||
// Client manages a netbird embedded client instance.
|
||||
type Client struct {
|
||||
deviceName string
|
||||
@@ -258,18 +264,40 @@ func (c *Client) GetConfig() (profilemanager.Config, error) {
|
||||
|
||||
// Dial dials a network address in the netbird network.
|
||||
// Not applicable if the userspace networking mode is disabled.
|
||||
// With lazy connections, the connection is established on first traffic.
|
||||
func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
logrus.Infof("embed.Dial called: network=%s, address=%s", network, address)
|
||||
|
||||
// Check context status upfront
|
||||
if ctx.Err() != nil {
|
||||
logrus.Warnf("embed.Dial: context already cancelled/expired: %v", ctx.Err())
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
|
||||
engine, err := c.getEngine()
|
||||
if err != nil {
|
||||
logrus.Errorf("embed.Dial: getEngine failed: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nsnet, err := engine.GetNet()
|
||||
if err != nil {
|
||||
logrus.Errorf("embed.Dial: GetNet failed: %v", err)
|
||||
return nil, fmt.Errorf("get net: %w", err)
|
||||
}
|
||||
|
||||
return nsnet.DialContext(ctx, network, address)
|
||||
// Note: Don't wait for peer connection here - lazy connection manager
|
||||
// will open the connection when DialContext is called. The netstack
|
||||
// dial triggers WireGuard traffic which activates the lazy connection.
|
||||
|
||||
logrus.Debugf("embed.Dial: calling nsnet.DialContext for %s", address)
|
||||
conn, err := nsnet.DialContext(ctx, network, address)
|
||||
if err != nil {
|
||||
logrus.Errorf("embed.Dial: nsnet.DialContext failed: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
logrus.Infof("embed.Dial: successfully connected to %s", address)
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// DialContext dials a network address in the netbird network with context
|
||||
@@ -368,6 +396,35 @@ func (c *Client) GetLatestSyncResponse() (*mgmProto.SyncResponse, error) {
|
||||
return syncResp, nil
|
||||
}
|
||||
|
||||
// WaitForPeerConnection waits for a peer with the given IP to be connected.
|
||||
func (c *Client) WaitForPeerConnection(ctx context.Context, peerIP string) error {
|
||||
logrus.Infof("Waiting for peer %s to be connected", peerIP)
|
||||
|
||||
ticker := time.NewTicker(peerConnectionPollInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return fmt.Errorf("timeout waiting for peer %s to connect: %w", peerIP, ctx.Err())
|
||||
case <-ticker.C:
|
||||
status, err := c.Status()
|
||||
if err != nil {
|
||||
logrus.Debugf("Error getting status while waiting for peer: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
for _, p := range status.Peers {
|
||||
if p.IP == peerIP && p.ConnStatus == peer.StatusConnected {
|
||||
logrus.Infof("Peer %s is now connected (relayed: %v)", peerIP, p.Relayed)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
logrus.Tracef("Peer %s not yet connected, waiting...", peerIP)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SetLogLevel sets the logging level for the client and its components.
|
||||
func (c *Client) SetLogLevel(levelStr string) error {
|
||||
level, err := logrus.ParseLevel(levelStr)
|
||||
@@ -381,9 +438,8 @@ func (c *Client) SetLogLevel(levelStr string) error {
|
||||
connect := c.connect
|
||||
c.mu.Unlock()
|
||||
|
||||
if connect != nil {
|
||||
connect.SetLogLevel(level)
|
||||
}
|
||||
// Note: ConnectClient doesn't have SetLogLevel method
|
||||
_ = connect
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -23,10 +23,10 @@ func NewNSDialer(net *netstack.Net) *NSDialer {
|
||||
}
|
||||
|
||||
func (d *NSDialer) Dial(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
log.Debugf("dialing %s %s", network, addr)
|
||||
log.Infof("NSDialer.Dial: network=%s, addr=%s", network, addr)
|
||||
conn, err := d.net.Dial(network, addr)
|
||||
if err != nil {
|
||||
log.Debugf("failed to deal connection: %s", err)
|
||||
log.Warnf("NSDialer.Dial failed: %s", err)
|
||||
}
|
||||
return conn, err
|
||||
}
|
||||
|
||||
@@ -16,8 +16,8 @@ import (
|
||||
|
||||
const (
|
||||
PriorityMgmtCache = 150
|
||||
PriorityLocal = 100
|
||||
PriorityDNSRoute = 75
|
||||
PriorityDNSRoute = 100
|
||||
PriorityLocal = 75
|
||||
PriorityUpstream = 50
|
||||
PriorityDefault = 1
|
||||
PriorityFallback = -100
|
||||
|
||||
@@ -185,7 +185,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
expectedLocalQs: []dns.Question{{Name: "peera.netbird.cloud.", Qtype: dns.TypeA, Qclass: dns.ClassINET}},
|
||||
},
|
||||
{
|
||||
name: "New Config Should Succeed",
|
||||
name: "New Config Should Succeed",
|
||||
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
|
||||
initUpstreamMap: registeredHandlerMap{
|
||||
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
|
||||
@@ -226,19 +226,19 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
expectedLocalQs: []dns.Question{{Name: zoneRecords[0].Name, Qtype: 1, Qclass: 1}},
|
||||
},
|
||||
{
|
||||
name: "Smaller Config Serial Should Be Skipped",
|
||||
initLocalZones: []nbdns.CustomZone{},
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initSerial: 2,
|
||||
inputSerial: 1,
|
||||
shouldFail: true,
|
||||
name: "Smaller Config Serial Should Be Skipped",
|
||||
initLocalZones: []nbdns.CustomZone{},
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initSerial: 2,
|
||||
inputSerial: 1,
|
||||
shouldFail: true,
|
||||
},
|
||||
{
|
||||
name: "Empty NS Group Domain Or Not Primary Element Should Fail",
|
||||
initLocalZones: []nbdns.CustomZone{},
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
name: "Empty NS Group Domain Or Not Primary Element Should Fail",
|
||||
initLocalZones: []nbdns.CustomZone{},
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
@@ -256,11 +256,11 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
shouldFail: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid NS Group Nameservers list Should Fail",
|
||||
initLocalZones: []nbdns.CustomZone{},
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
name: "Invalid NS Group Nameservers list Should Fail",
|
||||
initLocalZones: []nbdns.CustomZone{},
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
@@ -278,11 +278,11 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
shouldFail: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid Custom Zone Records list Should Skip",
|
||||
initLocalZones: []nbdns.CustomZone{},
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
name: "Invalid Custom Zone Records list Should Skip",
|
||||
initLocalZones: []nbdns.CustomZone{},
|
||||
initUpstreamMap: make(registeredHandlerMap),
|
||||
initSerial: 0,
|
||||
inputSerial: 1,
|
||||
inputUpdate: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
@@ -304,7 +304,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
}},
|
||||
},
|
||||
{
|
||||
name: "Empty Config Should Succeed and Clean Maps",
|
||||
name: "Empty Config Should Succeed and Clean Maps",
|
||||
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
|
||||
initUpstreamMap: registeredHandlerMap{
|
||||
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
|
||||
@@ -320,7 +320,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
expectedLocalQs: []dns.Question{},
|
||||
},
|
||||
{
|
||||
name: "Disabled Service Should clean map",
|
||||
name: "Disabled Service Should clean map",
|
||||
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
|
||||
initUpstreamMap: registeredHandlerMap{
|
||||
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
|
||||
@@ -2052,7 +2052,7 @@ func TestLocalResolverPriorityInServer(t *testing.T) {
|
||||
|
||||
func TestLocalResolverPriorityConstants(t *testing.T) {
|
||||
// Test that priority constants are ordered correctly
|
||||
assert.Greater(t, PriorityLocal, PriorityDNSRoute, "Local priority should be higher than DNS route")
|
||||
assert.Greater(t, PriorityDNSRoute, PriorityLocal, "DNS Route should be higher than Local priority")
|
||||
assert.Greater(t, PriorityLocal, PriorityUpstream, "Local priority should be higher than upstream")
|
||||
assert.Greater(t, PriorityUpstream, PriorityDefault, "Upstream priority should be higher than default")
|
||||
|
||||
|
||||
@@ -23,11 +23,13 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
clientStartTimeout = 30 * time.Second
|
||||
clientStopTimeout = 10 * time.Second
|
||||
pingTimeout = 10 * time.Second
|
||||
defaultLogLevel = "warn"
|
||||
defaultSSHDetectionTimeout = 20 * time.Second
|
||||
clientStartTimeout = 30 * time.Second
|
||||
clientStopTimeout = 10 * time.Second
|
||||
pingTimeout = 10 * time.Second
|
||||
defaultLogLevel = "warn"
|
||||
defaultSSHDetectionTimeout = 20 * time.Second
|
||||
defaultPeerConnectionTimeout = 60 * time.Second
|
||||
peerConnectionPollInterval = 500 * time.Millisecond
|
||||
|
||||
icmpEchoRequest = 8
|
||||
icmpCodeEcho = 0
|
||||
@@ -169,6 +171,9 @@ func createSSHMethod(client *netbird.Client) js.Func {
|
||||
}
|
||||
|
||||
return createPromise(func(resolve, reject js.Value) {
|
||||
// Note: Don't wait for peer connection here - lazy connection manager
|
||||
// will open the connection when Dial is called in ssh.Connect().
|
||||
|
||||
sshClient := ssh.NewClient(client)
|
||||
|
||||
if err := sshClient.Connect(host, port, username, jwtToken); err != nil {
|
||||
@@ -399,20 +404,79 @@ func createStatusSummaryMethod(client *netbird.Client) js.Func {
|
||||
// createStatusDetailMethod creates the statusDetail method
|
||||
func createStatusDetailMethod(client *netbird.Client) js.Func {
|
||||
return js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||
log.Info("statusDetail called")
|
||||
return createPromise(func(resolve, reject js.Value) {
|
||||
log.Info("statusDetail: getting overview")
|
||||
overview, err := getStatusOverview(client)
|
||||
if err != nil {
|
||||
log.Errorf("statusDetail: getStatusOverview failed: %v", err)
|
||||
reject.Invoke(js.ValueOf(err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
log.Info("statusDetail: generating full detail summary")
|
||||
detail := overview.FullDetailSummary()
|
||||
log.Infof("statusDetail: detail length=%d", len(detail))
|
||||
js.Global().Get("console").Call("log", detail)
|
||||
resolve.Invoke(js.Undefined())
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// createWaitForPeerConnectionMethod creates a method that waits for a peer to be connected
|
||||
func createWaitForPeerConnectionMethod(client *netbird.Client) js.Func {
|
||||
return js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||
return createPromise(func(resolve, reject js.Value) {
|
||||
if len(args) < 1 {
|
||||
reject.Invoke(js.ValueOf("peer IP address required"))
|
||||
return
|
||||
}
|
||||
|
||||
peerIP := args[0].String()
|
||||
timeoutMs := int(defaultPeerConnectionTimeout.Milliseconds())
|
||||
if len(args) > 1 && !args[1].IsUndefined() && !args[1].IsNull() {
|
||||
timeoutMs = args[1].Int()
|
||||
}
|
||||
|
||||
timeout := time.Duration(timeoutMs) * time.Millisecond
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
|
||||
log.Infof("Waiting for peer %s to be connected (timeout: %v)", peerIP, timeout)
|
||||
|
||||
ticker := time.NewTicker(peerConnectionPollInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
reject.Invoke(js.ValueOf(fmt.Sprintf("timeout waiting for peer %s to connect", peerIP)))
|
||||
return
|
||||
case <-ticker.C:
|
||||
overview, err := getStatusOverview(client)
|
||||
if err != nil {
|
||||
log.Debugf("Error getting status: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
for _, peer := range overview.Peers.Details {
|
||||
if peer.IP == peerIP && peer.Status == "Connected" {
|
||||
log.Infof("Peer %s is now connected (type: %s)", peerIP, peer.ConnType)
|
||||
resolve.Invoke(js.ValueOf(map[string]interface{}{
|
||||
"ip": peer.IP,
|
||||
"status": peer.Status,
|
||||
"connType": peer.ConnType,
|
||||
}))
|
||||
return
|
||||
}
|
||||
}
|
||||
log.Tracef("Peer %s not yet connected, waiting...", peerIP)
|
||||
}
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// createGetSyncResponseMethod creates the getSyncResponse method that returns the latest sync response as JSON
|
||||
func createGetSyncResponseMethod(client *netbird.Client) js.Func {
|
||||
return js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||
@@ -471,12 +535,51 @@ func createPromise(handler func(resolve, reject js.Value)) js.Value {
|
||||
resolve := promiseArgs[0]
|
||||
reject := promiseArgs[1]
|
||||
|
||||
go handler(resolve, reject)
|
||||
// Wrap reject to always log the error
|
||||
loggingReject := js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||
if len(args) > 0 {
|
||||
log.Errorf("Promise rejected: %v", args[0])
|
||||
js.Global().Get("console").Call("error", "WASM Promise rejected:", args[0])
|
||||
}
|
||||
reject.Invoke(args[0])
|
||||
return nil
|
||||
})
|
||||
|
||||
go handler(resolve, loggingReject.Value)
|
||||
|
||||
return nil
|
||||
}))
|
||||
}
|
||||
|
||||
// waitForPeerConnected waits for a peer with the given IP to be connected
|
||||
func waitForPeerConnected(ctx context.Context, client *netbird.Client, peerIP string) error {
|
||||
log.Infof("Waiting for peer %s to be connected before operation", peerIP)
|
||||
|
||||
ticker := time.NewTicker(peerConnectionPollInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return fmt.Errorf("timeout waiting for peer %s to connect", peerIP)
|
||||
case <-ticker.C:
|
||||
overview, err := getStatusOverview(client)
|
||||
if err != nil {
|
||||
log.Debugf("Error getting status while waiting for peer: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
for _, peer := range overview.Peers.Details {
|
||||
if peer.IP == peerIP && peer.Status == "Connected" {
|
||||
log.Infof("Peer %s is now connected (type: %s), proceeding with operation", peerIP, peer.ConnType)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
log.Tracef("Peer %s not yet connected, waiting...", peerIP)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// createDetectSSHServerMethod creates the SSH server detection method
|
||||
func createDetectSSHServerMethod(client *netbird.Client) js.Func {
|
||||
return js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||
@@ -499,6 +602,10 @@ func createDetectSSHServerMethod(client *netbird.Client) js.Func {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutMs)*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
// Note: Don't wait for peer connection here - lazy connection manager
|
||||
// will open the connection when Dial is called. Waiting would cause
|
||||
// a deadlock since lazy connections only open on traffic.
|
||||
|
||||
serverType, err := sshdetection.DetectSSHServerType(ctx, client, host, port)
|
||||
if err != nil {
|
||||
reject.Invoke(err.Error())
|
||||
@@ -525,6 +632,7 @@ func createClientObject(client *netbird.Client) js.Value {
|
||||
obj["status"] = createStatusMethod(client)
|
||||
obj["statusSummary"] = createStatusSummaryMethod(client)
|
||||
obj["statusDetail"] = createStatusDetailMethod(client)
|
||||
obj["waitForPeerConnection"] = createWaitForPeerConnectionMethod(client)
|
||||
obj["getSyncResponse"] = createGetSyncResponseMethod(client)
|
||||
obj["setLogLevel"] = createSetLogLevelMethod(client)
|
||||
|
||||
@@ -549,7 +657,10 @@ func netBirdClientConstructor(_ js.Value, args []js.Value) any {
|
||||
return
|
||||
}
|
||||
|
||||
if err := util.InitLog(options.LogLevel, util.LogConsole); err != nil {
|
||||
// Force trace logging for debugging - must be set BEFORE netbird.New
|
||||
// as New() will call logrus.SetLevel with options.LogLevel
|
||||
options.LogLevel = "trace"
|
||||
if err := util.InitLog("trace", util.LogConsole); err != nil {
|
||||
log.Warnf("Failed to initialize logging: %v", err)
|
||||
}
|
||||
|
||||
|
||||
25
management/server/cache/idp.go
vendored
25
management/server/cache/idp.go
vendored
@@ -26,6 +26,8 @@ type UserDataCache interface {
|
||||
Get(ctx context.Context, key string) (*idp.UserData, error)
|
||||
Set(ctx context.Context, key string, value *idp.UserData, expiration time.Duration) error
|
||||
Delete(ctx context.Context, key string) error
|
||||
GetUsers(ctx context.Context, key string) ([]*idp.UserData, error)
|
||||
SetUsers(ctx context.Context, key string, users []*idp.UserData, expiration time.Duration) error
|
||||
}
|
||||
|
||||
// UserDataCacheImpl is a struct that implements the UserDataCache interface.
|
||||
@@ -51,6 +53,29 @@ func (u *UserDataCacheImpl) Delete(ctx context.Context, key string) error {
|
||||
return u.cache.Delete(ctx, key)
|
||||
}
|
||||
|
||||
func (u *UserDataCacheImpl) GetUsers(ctx context.Context, key string) ([]*idp.UserData, error) {
|
||||
var users []*idp.UserData
|
||||
v, err := u.cache.Get(ctx, key, &users)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch v := v.(type) {
|
||||
case []*idp.UserData:
|
||||
return v, nil
|
||||
case *[]*idp.UserData:
|
||||
return *v, nil
|
||||
case []byte:
|
||||
return unmarshalUserData(v)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unexpected type: %T", v)
|
||||
}
|
||||
|
||||
func (u *UserDataCacheImpl) SetUsers(ctx context.Context, key string, users []*idp.UserData, expiration time.Duration) error {
|
||||
return u.cache.Set(ctx, key, users, store.WithExpiration(expiration))
|
||||
}
|
||||
|
||||
// NewUserDataCache creates a new UserDataCacheImpl object.
|
||||
func NewUserDataCache(store store.StoreInterface) *UserDataCacheImpl {
|
||||
simpleCache := cache.New[any](store)
|
||||
|
||||
@@ -178,7 +178,7 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []
|
||||
m.patUsageTracker.IncrementUsage(token)
|
||||
}
|
||||
|
||||
if m.rateLimiter != nil {
|
||||
if m.rateLimiter != nil && !isTerraformRequest(r) {
|
||||
if !m.rateLimiter.Allow(token) {
|
||||
return r, status.Errorf(status.TooManyRequests, "too many requests")
|
||||
}
|
||||
@@ -214,6 +214,11 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []
|
||||
return nbcontext.SetUserAuthInRequest(r, userAuth), nil
|
||||
}
|
||||
|
||||
func isTerraformRequest(r *http.Request) bool {
|
||||
ua := strings.ToLower(r.Header.Get("User-Agent"))
|
||||
return strings.Contains(ua, "terraform")
|
||||
}
|
||||
|
||||
// getTokenFromJWTRequest is a "TokenExtractor" that takes auth header parts and extracts
|
||||
// the JWT token from the Authorization header.
|
||||
func getTokenFromJWTRequest(authHeaderParts []string) (string, error) {
|
||||
|
||||
@@ -508,6 +508,103 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
handler.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request after cleanup should be rate limited again")
|
||||
})
|
||||
|
||||
t.Run("Terraform User Agent Not Rate Limited", func(t *testing.T) {
|
||||
rateLimitConfig := &RateLimiterConfig{
|
||||
RequestsPerMinute: 1,
|
||||
Burst: 1,
|
||||
CleanupInterval: 5 * time.Minute,
|
||||
LimiterTTL: 10 * time.Minute,
|
||||
}
|
||||
|
||||
authMiddleware := NewAuthMiddleware(
|
||||
mockAuth,
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
|
||||
return userAuth.AccountId, userAuth.UserId, nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) error {
|
||||
return nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
rateLimitConfig,
|
||||
nil,
|
||||
)
|
||||
|
||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
// Test various Terraform user agent formats
|
||||
terraformUserAgents := []string{
|
||||
"Terraform/1.5.0",
|
||||
"terraform/1.0.0",
|
||||
"Terraform-Provider/2.0.0",
|
||||
"Mozilla/5.0 (compatible; Terraform/1.3.0)",
|
||||
}
|
||||
|
||||
for _, userAgent := range terraformUserAgents {
|
||||
t.Run("UserAgent: "+userAgent, func(t *testing.T) {
|
||||
successCount := 0
|
||||
for i := 0; i < 10; i++ {
|
||||
req := httptest.NewRequest("GET", "http://testing/test", nil)
|
||||
req.Header.Set("Authorization", "Token "+PAT)
|
||||
req.Header.Set("User-Agent", userAgent)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(rec, req)
|
||||
if rec.Code == http.StatusOK {
|
||||
successCount++
|
||||
}
|
||||
}
|
||||
|
||||
assert.Equal(t, 10, successCount, "All Terraform user agent requests should succeed (not rate limited)")
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Non-Terraform User Agent With PAT Is Rate Limited", func(t *testing.T) {
|
||||
rateLimitConfig := &RateLimiterConfig{
|
||||
RequestsPerMinute: 1,
|
||||
Burst: 1,
|
||||
CleanupInterval: 5 * time.Minute,
|
||||
LimiterTTL: 10 * time.Minute,
|
||||
}
|
||||
|
||||
authMiddleware := NewAuthMiddleware(
|
||||
mockAuth,
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
|
||||
return userAuth.AccountId, userAuth.UserId, nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) error {
|
||||
return nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
rateLimitConfig,
|
||||
nil,
|
||||
)
|
||||
|
||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest("GET", "http://testing/test", nil)
|
||||
req.Header.Set("Authorization", "Token "+PAT)
|
||||
req.Header.Set("User-Agent", "curl/7.68.0")
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusOK, rec.Code, "First request should succeed")
|
||||
|
||||
req = httptest.NewRequest("GET", "http://testing/test", nil)
|
||||
req.Header.Set("Authorization", "Token "+PAT)
|
||||
req.Header.Set("User-Agent", "curl/7.68.0")
|
||||
rec = httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request should be rate limited")
|
||||
})
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_Handler_Child(t *testing.T) {
|
||||
|
||||
@@ -911,10 +911,12 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun
|
||||
accountUsers := []*types.User{}
|
||||
switch {
|
||||
case allowed:
|
||||
start := time.Now()
|
||||
accountUsers, err = am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
log.WithContext(ctx).Tracef("Got %d users from account %s after %s", len(accountUsers), accountID, time.Since(start))
|
||||
case user != nil && user.AccountID == accountID:
|
||||
accountUsers = append(accountUsers, user)
|
||||
default:
|
||||
@@ -933,23 +935,40 @@ func (am *DefaultAccountManager) BuildUserInfosForAccount(ctx context.Context, a
|
||||
if !isNil(am.idpManager) && !IsEmbeddedIdp(am.idpManager) {
|
||||
users := make(map[string]userLoggedInOnce, len(accountUsers))
|
||||
usersFromIntegration := make([]*idp.UserData, 0)
|
||||
filtered := make(map[string]*idp.UserData, len(accountUsers))
|
||||
log.WithContext(ctx).Tracef("Querying users from IDP for account %s", accountID)
|
||||
start := time.Now()
|
||||
|
||||
integrationKeys := make(map[string]struct{})
|
||||
for _, user := range accountUsers {
|
||||
if user.Issued == types.UserIssuedIntegration {
|
||||
key := user.IntegrationReference.CacheKey(accountID, user.Id)
|
||||
info, err := am.externalCacheManager.Get(am.ctx, key)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Infof("Get ExternalCache for key: %s, error: %s", key, err)
|
||||
users[user.Id] = true
|
||||
continue
|
||||
}
|
||||
usersFromIntegration = append(usersFromIntegration, info)
|
||||
integrationKeys[user.IntegrationReference.CacheKey(accountID)] = struct{}{}
|
||||
continue
|
||||
}
|
||||
if !user.IsServiceUser {
|
||||
users[user.Id] = userLoggedInOnce(!user.GetLastLogin().IsZero())
|
||||
}
|
||||
}
|
||||
|
||||
for key := range integrationKeys {
|
||||
usersData, err := am.externalCacheManager.GetUsers(am.ctx, key)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("GetUsers from ExternalCache for key: %s, error: %s", key, err)
|
||||
continue
|
||||
}
|
||||
for _, ud := range usersData {
|
||||
filtered[ud.ID] = ud
|
||||
}
|
||||
}
|
||||
|
||||
for _, ud := range filtered {
|
||||
usersFromIntegration = append(usersFromIntegration, ud)
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Tracef("Got user info from external cache after %s", time.Since(start))
|
||||
start = time.Now()
|
||||
queriedUsers, err = am.lookupCache(ctx, users, accountID)
|
||||
log.WithContext(ctx).Tracef("Got user info from cache for %d users after %s", len(queriedUsers), time.Since(start))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -1086,8 +1086,12 @@ func TestDefaultAccountManager_ExternalCache(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
|
||||
cacheManager := am.GetExternalCacheManager()
|
||||
cacheKey := externalUser.IntegrationReference.CacheKey(mockAccountID, externalUser.Id)
|
||||
err = cacheManager.Set(context.Background(), cacheKey, &idp.UserData{ID: externalUser.Id, Name: "Test User", Email: "user@example.com"}, time.Minute)
|
||||
tud := &idp.UserData{ID: externalUser.Id, Name: "Test User", Email: "user@example.com"}
|
||||
cacheKeyUser := externalUser.IntegrationReference.CacheKey(mockAccountID, externalUser.Id)
|
||||
err = cacheManager.Set(context.Background(), cacheKeyUser, tud, time.Minute)
|
||||
assert.NoError(t, err)
|
||||
cacheKeyAccount := externalUser.IntegrationReference.CacheKey(mockAccountID)
|
||||
err = cacheManager.SetUsers(context.Background(), cacheKeyAccount, []*idp.UserData{tud}, time.Minute)
|
||||
assert.NoError(t, err)
|
||||
|
||||
infos, err := am.GetUsersFromAccount(context.Background(), mockAccountID, mockUserID)
|
||||
|
||||
Reference in New Issue
Block a user