Files
netbird/management/server/http/testing/integration/accounts_handler_integration_test.go
2026-03-26 16:59:49 +01:00

239 lines
7.4 KiB
Go

//go:build integration
package integration
import (
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/http/api"
)
// Test_Accounts_GetAll sends GET /api/accounts on behalf of every fixture user
// listed below. For users that are expected to receive a response, it decodes
// the body as a list of accounts and checks the single returned account
// against the values asserted in the test body.
func Test_Accounts_GetAll(t *testing.T) {
	users := []struct {
		name           string
		userId         string
		expectResponse bool
	}{
		{"Regular user", testing_tools.TestUserId, false},
		{"Admin user", testing_tools.TestAdminId, true},
		{"Owner user", testing_tools.TestOwnerId, true},
		{"Regular service user", testing_tools.TestServiceUserId, true},
		{"Admin service user", testing_tools.TestServiceAdminId, true},
		{"Blocked user", testing_tools.BlockedUserId, false},
		{"Other user", testing_tools.OtherUserId, false},
		{"Invalid token", testing_tools.InvalidToken, false},
	}

	for _, user := range users {
		t.Run(user.name+" - Get all accounts", func(t *testing.T) {
			// NOTE(review): the trailing `true` and the returned `done` channel
			// presumably relate to the "peer should not receive update" check
			// named in the timeout message below — confirm against the helper.
			apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, true)

			req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/accounts", user.userId)
			recorder := httptest.NewRecorder()
			apiHandler.ServeHTTP(recorder, req)

			content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse)
			if !expectResponse {
				return
			}

			var got []api.Account
			if err := json.Unmarshal(content, &got); err != nil {
				t.Fatalf("Sent content is not in correct json format; %v", err)
			}

			// A non-fatal assertion here would let got[0] panic with an
			// index-out-of-range on an empty list, so stop the subtest instead.
			if len(got) != 1 {
				t.Fatalf("expected exactly 1 account, got %d", len(got))
			}

			account := got[0]
			assert.Equal(t, "test.com", account.Domain)
			assert.Equal(t, "private", account.DomainCategory)
			assert.Equal(t, true, account.Settings.PeerLoginExpirationEnabled)
			assert.Equal(t, 86400, account.Settings.PeerLoginExpiration)
			assert.Equal(t, false, account.Settings.RegularUsersViewBlocked)

			// Give the helper up to one second to signal completion.
			select {
			case <-done:
			case <-time.After(time.Second):
				t.Error("timeout waiting for peerShouldNotReceiveUpdate")
			}
		})
	}
}
// Test_Accounts_Update sends PUT /api/accounts/{accountId} for every
// combination of fixture user and settings-update scenario. For users that
// are expected to receive a response, it checks the returned account and then
// the account settings read back through the store.
func Test_Accounts_Update(t *testing.T) {
	users := []struct {
		name           string
		userId         string
		expectResponse bool
	}{
		{"Regular user", testing_tools.TestUserId, false},
		{"Admin user", testing_tools.TestAdminId, true},
		{"Owner user", testing_tools.TestOwnerId, true},
		{"Regular service user", testing_tools.TestServiceUserId, false},
		{"Admin service user", testing_tools.TestServiceAdminId, true},
		{"Blocked user", testing_tools.BlockedUserId, false},
		{"Other user", testing_tools.OtherUserId, false},
		{"Invalid token", testing_tools.InvalidToken, false},
	}

	// Addressable booleans for the optional (pointer) settings fields.
	trueVal := true
	falseVal := false

	tt := []struct {
		name           string
		expectedStatus int
		requestBody    *api.AccountRequest
		// verifyResponse checks the decoded HTTP response for the scenario.
		verifyResponse func(t *testing.T, account *api.Account)
		// verifyDB checks the account loaded from the store for the scenario.
		verifyDB func(t *testing.T, account *types.Account)
	}{
		{
			name: "Disable peer login expiration",
			requestBody: &api.AccountRequest{
				Settings: api.AccountSettings{
					PeerLoginExpirationEnabled: false,
					PeerLoginExpiration:        86400,
				},
			},
			expectedStatus: http.StatusOK,
			verifyResponse: func(t *testing.T, account *api.Account) {
				t.Helper()
				assert.Equal(t, false, account.Settings.PeerLoginExpirationEnabled)
			},
			verifyDB: func(t *testing.T, dbAccount *types.Account) {
				t.Helper()
				assert.Equal(t, false, dbAccount.Settings.PeerLoginExpirationEnabled)
			},
		},
		{
			name: "Update peer login expiration to 48h",
			requestBody: &api.AccountRequest{
				Settings: api.AccountSettings{
					PeerLoginExpirationEnabled: true,
					PeerLoginExpiration:        172800,
				},
			},
			expectedStatus: http.StatusOK,
			verifyResponse: func(t *testing.T, account *api.Account) {
				t.Helper()
				assert.Equal(t, 172800, account.Settings.PeerLoginExpiration)
			},
			verifyDB: func(t *testing.T, dbAccount *types.Account) {
				t.Helper()
				// The API value is compared as a plain number; the stored value
				// is compared as a time.Duration.
				assert.Equal(t, 172800*time.Second, dbAccount.Settings.PeerLoginExpiration)
			},
		},
		{
			name: "Enable regular users view blocked",
			requestBody: &api.AccountRequest{
				Settings: api.AccountSettings{
					PeerLoginExpirationEnabled: true,
					PeerLoginExpiration:        86400,
					RegularUsersViewBlocked:    true,
				},
			},
			expectedStatus: http.StatusOK,
			verifyResponse: func(t *testing.T, account *api.Account) {
				t.Helper()
				assert.Equal(t, true, account.Settings.RegularUsersViewBlocked)
			},
			verifyDB: func(t *testing.T, dbAccount *types.Account) {
				t.Helper()
				assert.Equal(t, true, dbAccount.Settings.RegularUsersViewBlocked)
			},
		},
		{
			name: "Enable groups propagation",
			requestBody: &api.AccountRequest{
				Settings: api.AccountSettings{
					PeerLoginExpirationEnabled: true,
					PeerLoginExpiration:        86400,
					GroupsPropagationEnabled:   &trueVal,
				},
			},
			expectedStatus: http.StatusOK,
			verifyResponse: func(t *testing.T, account *api.Account) {
				t.Helper()
				// assert.NotNil does not abort the test, so only dereference
				// when it succeeded to avoid a nil-pointer panic.
				if assert.NotNil(t, account.Settings.GroupsPropagationEnabled) {
					assert.Equal(t, true, *account.Settings.GroupsPropagationEnabled)
				}
			},
			verifyDB: func(t *testing.T, dbAccount *types.Account) {
				t.Helper()
				assert.Equal(t, true, dbAccount.Settings.GroupsPropagationEnabled)
			},
		},
		{
			name: "Enable JWT groups",
			requestBody: &api.AccountRequest{
				Settings: api.AccountSettings{
					PeerLoginExpirationEnabled: true,
					PeerLoginExpiration:        86400,
					GroupsPropagationEnabled:   &falseVal,
					JwtGroupsEnabled:           &trueVal,
					JwtGroupsClaimName:         stringPointer("groups"),
				},
			},
			expectedStatus: http.StatusOK,
			verifyResponse: func(t *testing.T, account *api.Account) {
				t.Helper()
				// Guard each dereference on the NotNil result so a missing
				// field is reported as a failure rather than a panic.
				if assert.NotNil(t, account.Settings.JwtGroupsEnabled) {
					assert.Equal(t, true, *account.Settings.JwtGroupsEnabled)
				}
				if assert.NotNil(t, account.Settings.JwtGroupsClaimName) {
					assert.Equal(t, "groups", *account.Settings.JwtGroupsClaimName)
				}
			},
			verifyDB: func(t *testing.T, dbAccount *types.Account) {
				t.Helper()
				assert.Equal(t, true, dbAccount.Settings.JWTGroupsEnabled)
				assert.Equal(t, "groups", dbAccount.Settings.JWTGroupsClaimName)
			},
		},
	}

	for _, tc := range tt {
		for _, user := range users {
			t.Run(user.name+" - "+tc.name, func(t *testing.T) {
				apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, false)

				body, err := json.Marshal(tc.requestBody)
				if err != nil {
					t.Fatalf("Failed to marshal request body: %v", err)
				}

				path := strings.Replace("/api/accounts/{accountId}", "{accountId}", testing_tools.TestAccountId, 1)
				req := testing_tools.BuildRequest(t, body, http.MethodPut, path, user.userId)
				recorder := httptest.NewRecorder()
				apiHandler.ServeHTTP(recorder, req)

				content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
				if !expectResponse {
					return
				}

				got := &api.Account{}
				if err := json.Unmarshal(content, got); err != nil {
					t.Fatalf("Sent content is not in correct json format; %v", err)
				}

				assert.Equal(t, testing_tools.TestAccountId, got.Id)
				assert.Equal(t, "test.com", got.Domain)
				tc.verifyResponse(t, got)

				// Read the account back through the store and run the
				// scenario's database-side checks.
				db := testing_tools.GetDB(t, am.GetStore())
				dbAccount := testing_tools.VerifyAccountSettings(t, db)
				tc.verifyDB(t, dbAccount)
			})
		}
	}
}
// stringPointer returns a pointer to a copy of s, for populating optional
// string fields that are declared as *string.
func stringPointer(s string) *string {
	value := s
	return &value
}