Compare commits

...

3 Commits

Author SHA1 Message Date
Yury Gargay
5f1796e1c0 Fix tests 2024-01-04 18:52:28 +01:00
Yury Gargay
13c9765c6e Allow tests 2024-01-04 16:24:12 +01:00
Yury Gargay
b0729c5944 Try PostgreSQL store 2024-01-04 16:22:33 +01:00
7 changed files with 942 additions and 9 deletions

View File

@@ -4,6 +4,7 @@ on:
push:
branches:
- main
- yury/add-postgresql-store
pull_request:
concurrency:
@@ -14,8 +15,20 @@ jobs:
test:
strategy:
matrix:
store: ['jsonfile', 'sqlite']
store: ['jsonfile', 'sqlite', 'postgresql']
runs-on: macos-latest
services:
postgres:
image: postgres
env:
POSTGRES_PASSWORD: postgres
POSTGRES_DB: integrations
options: >-
--health-cmd pg_isready
--health-interval 10s
--health-timeout 5s
ports:
- 5432:5432
steps:
- name: Install Go
uses: actions/setup-go@v4

View File

@@ -4,6 +4,7 @@ on:
push:
branches:
- main
- yury/add-postgresql-store
pull_request:
concurrency:
@@ -15,8 +16,20 @@ jobs:
strategy:
matrix:
arch: ['386','amd64']
store: ['jsonfile', 'sqlite']
store: ['jsonfile', 'sqlite', 'postgresql']
runs-on: ubuntu-latest
services:
postgres:
image: postgres
env:
POSTGRES_PASSWORD: postgres
POSTGRES_DB: integrations
options: >-
--health-cmd pg_isready
--health-interval 10s
--health-timeout 5s
ports:
- 5432:5432
steps:
- name: Install Go
uses: actions/setup-go@v4

8
go.mod
View File

@@ -44,6 +44,7 @@ require (
github.com/google/nftables v0.0.0-20220808154552-2eca00135732
github.com/hashicorp/go-secure-stdlib/base62 v0.1.2
github.com/hashicorp/go-version v1.6.0
github.com/lib/pq v1.10.9
github.com/libp2p/go-netroute v0.2.0
github.com/magiconair/properties v1.8.5
github.com/mattn/go-sqlite3 v1.14.17
@@ -77,8 +78,9 @@ require (
golang.org/x/term v0.13.0
google.golang.org/api v0.126.0
gopkg.in/yaml.v3 v3.0.1
gorm.io/driver/postgres v1.5.4
gorm.io/driver/sqlite v1.5.3
gorm.io/gorm v1.25.4
gorm.io/gorm v1.25.5
)
require (
@@ -116,6 +118,9 @@ require (
github.com/googleapis/gax-go/v2 v2.10.0 // indirect
github.com/hashicorp/go-uuid v1.0.2 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
github.com/jackc/pgx/v5 v5.4.3 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/josharian/native v1.0.0 // indirect
@@ -151,7 +156,6 @@ require (
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
google.golang.org/appengine v1.6.7 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20230530153820-e85fd2cbaebc // indirect
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
gopkg.in/square/go-jose.v2 v2.6.0 // indirect
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect
gopkg.in/tomb.v2 v2.0.0-20161208151619-d5d1b5820637 // indirect

15
go.sum
View File

@@ -382,6 +382,12 @@ github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANyt
github.com/inconshreveable/mousetrap v1.0.1/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.4.3 h1:cxFyXhxlvAifxnkKKdlxv8XqUf59tDlYjnV5YYfsJJY=
github.com/jackc/pgx/v5 v5.4.3/go.mod h1:Ig06C2Vu0t5qXC60W8sqIthScaEnFvojjj9dSljmHRA=
github.com/jackmordaunt/icns v0.0.0-20181231085925-4f16af745526/go.mod h1:UQkeMHVoNcyXYq9otUupF7/h/2tmHlhrS2zw7ZVvUqc=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
@@ -434,6 +440,8 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/leodido/go-urn v1.1.0/go.mod h1:+cyI34gQWZcE1eQU7NVgKkkzdXDQHr1dBMtdAPozLkw=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/libp2p/go-netroute v0.2.0 h1:0FpsbsvuSnAhXFnCY0VLFbJOzaK0VnP0r1QT/o4nWRE=
github.com/libp2p/go-netroute v0.2.0/go.mod h1:Vio7LTzZ+6hoT4CMZi5/6CpY3Snzh2vgZhWgxMNwlQI=
github.com/lucor/goinfo v0.0.0-20210802170112-c078a2b0f08b/go.mod h1:PRq09yoB+Q2OJReAmwzKivcYyremnibWGbK7WfftHzc=
@@ -1179,7 +1187,6 @@ gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys=
gopkg.in/go-playground/assert.v1 v1.2.1/go.mod h1:9RXL0bg/zibRAgZUYszZSwO/z8Y/a8bDuhia5mkpMnE=
@@ -1209,10 +1216,12 @@ gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776/go.mod h1:K4uyk7z7BCEPqu6E+C
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gorm.io/driver/postgres v1.5.4 h1:Iyrp9Meh3GmbSuyIAGyjkN+n9K+GHX9b9MqsTL4EJCo=
gorm.io/driver/postgres v1.5.4/go.mod h1:Bgo89+h0CRcdA33Y6frlaHHVuTdOf87pmyzwW9C/BH0=
gorm.io/driver/sqlite v1.5.3 h1:7/0dUgX28KAcopdfbRWWl68Rflh6osa4rDh+m51KL2g=
gorm.io/driver/sqlite v1.5.3/go.mod h1:qxAuCol+2r6PannQDpOP1FP6ag3mKi4esLnB/jHed+4=
gorm.io/gorm v1.25.4 h1:iyNd8fNAe8W9dvtlgeRI5zSVZPsq3OpcTu37cYcpCmw=
gorm.io/gorm v1.25.4/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=
gorm.io/gorm v1.25.5 h1:zR9lOiiYf09VNh5Q1gphfyia1JpiClIWG9hQaxB/mls=
gorm.io/gorm v1.25.5/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
gotest.tools v2.2.0+incompatible/go.mod h1:DsYFclhRJ6vuDpmuTbkuFWG+y2sxOXAzmJt81HFBacw=
gotest.tools/v3 v3.4.0/go.mod h1:CtbdzLSsqVhDgMtKsx03ird5YTGB3ar27v0u/yKBW5g=
gvisor.dev/gvisor v0.0.0-20221203005347-703fd9b7fbc0 h1:Wobr37noukisGxpKo5jAsLREcpj61RxrWYzD8uwveOY=

View File

@@ -0,0 +1,534 @@
package server
import (
"bytes"
"context"
"database/sql"
"encoding/base64"
"encoding/gob"
"fmt"
"reflect"
"runtime"
"strings"
"sync"
"time"
_ "github.com/lib/pq"
log "github.com/sirupsen/logrus"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/logger"
"gorm.io/gorm/schema"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/account"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/route"
)
// PostgresqlStore represents an account storage backed by a Postgres DB persisted to disk
type PostgresqlStore struct {
db *gorm.DB
dsn string
accountLocks sync.Map
globalAccountLock sync.Mutex
metrics telemetry.AppMetrics
installationPK int
}
// GobSerializer gob serializer
type GobBase64Serializer struct{}
// Scan implements serializer interface with base64 encoding
func (GobBase64Serializer) Scan(ctx context.Context, field *schema.Field, dst reflect.Value, dbValue interface{}) (err error) {
fieldValue := reflect.New(field.FieldType)
if dbValue != nil {
var bytesValue []byte
switch v := dbValue.(type) {
case []byte:
bytesValue = v
case string:
bytesValue = []byte(v)
default:
return fmt.Errorf("failed to unmarshal gob value: %#v", dbValue)
}
if len(bytesValue) > 0 {
var decoded []byte
decoded, err = base64.StdEncoding.DecodeString(string(bytesValue))
if err == nil {
decoder := gob.NewDecoder(bytes.NewBuffer(decoded))
err = decoder.Decode(fieldValue.Interface())
}
}
}
field.ReflectValueOf(ctx, dst).Set(fieldValue.Elem())
return
}
// Value implements serializer interface
func (GobBase64Serializer) Value(ctx context.Context, field *schema.Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) {
buf := new(bytes.Buffer)
err := gob.NewEncoder(buf).Encode(fieldValue)
return base64.StdEncoding.EncodeToString(buf.Bytes()), err
}
// NewPostgresqlStore restores a store from the file located in the datadir
func NewPostgresqlStore(dsn string, metrics telemetry.AppMetrics) (*PostgresqlStore, error) {
schema.RegisterSerializer("gob", GobBase64Serializer{})
sqlDB, err := sql.Open("postgres", dsn)
if err != nil {
return nil, err
}
db, err := gorm.Open(postgres.New(postgres.Config{
Conn: sqlDB,
}), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
PrepareStmt: true,
})
if err != nil {
return nil, err
}
//sql, err := db.DB()
//if err != nil {
// return nil, err
//}
conns := runtime.NumCPU()
sqlDB.SetMaxOpenConns(conns) // TODO: make it configurable
err = db.AutoMigrate(
&Account{}, &SetupKey{}, &nbpeer.Peer{}, &User{}, &PersonalAccessToken{}, &Group{}, &Rule{},
&Policy{}, &PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{},
&installation{}, &account.ExtraSettings{},
)
if err != nil {
return nil, err
}
return &PostgresqlStore{db: db, dsn: dsn, metrics: metrics, installationPK: 1}, nil
}
// NewPostgresqlStoreFromFileStore restores a store from FileStore and stores PostgreSQL DB in the file located in datadir
func NewPostgresqlStoreFromFileStore(filestore *FileStore, dataDir string, metrics telemetry.AppMetrics) (*PostgresqlStore, error) {
store, err := NewPostgresqlStore(dataDir, metrics)
if err != nil {
return nil, err
}
err = store.SaveInstallationID(filestore.InstallationID)
if err != nil {
return nil, err
}
for _, account := range filestore.GetAllAccounts() {
err := store.SaveAccount(account)
if err != nil {
return nil, err
}
}
return store, nil
}
// AcquireGlobalLock acquires global lock across all the accounts and returns a function that releases the lock
func (s *PostgresqlStore) AcquireGlobalLock() (unlock func()) {
log.Debugf("acquiring global lock")
start := time.Now()
s.globalAccountLock.Lock()
unlock = func() {
s.globalAccountLock.Unlock()
log.Debugf("released global lock in %v", time.Since(start))
}
took := time.Since(start)
log.Debugf("took %v to acquire global lock", took)
if s.metrics != nil {
s.metrics.StoreMetrics().CountGlobalLockAcquisitionDuration(took)
}
return unlock
}
func (s *PostgresqlStore) AcquireAccountLock(accountID string) (unlock func()) {
log.Debugf("acquiring lock for account %s", accountID)
start := time.Now()
value, _ := s.accountLocks.LoadOrStore(accountID, &sync.Mutex{})
mtx := value.(*sync.Mutex)
mtx.Lock()
unlock = func() {
mtx.Unlock()
log.Debugf("released lock for account %s in %v", accountID, time.Since(start))
}
return unlock
}
func (s *PostgresqlStore) SaveAccount(account *Account) error {
start := time.Now()
for _, key := range account.SetupKeys {
account.SetupKeysG = append(account.SetupKeysG, *key)
}
for id, peer := range account.Peers {
peer.ID = id
account.PeersG = append(account.PeersG, *peer)
}
for id, user := range account.Users {
user.Id = id
for id, pat := range user.PATs {
pat.ID = id
user.PATsG = append(user.PATsG, *pat)
}
account.UsersG = append(account.UsersG, *user)
}
for id, group := range account.Groups {
group.ID = id
account.GroupsG = append(account.GroupsG, *group)
}
for id, rule := range account.Rules {
rule.ID = id
account.RulesG = append(account.RulesG, *rule)
}
for id, route := range account.Routes {
route.ID = id
account.RoutesG = append(account.RoutesG, *route)
}
for id, ns := range account.NameServerGroups {
ns.ID = id
account.NameServerGroupsG = append(account.NameServerGroupsG, *ns)
}
err := s.db.Transaction(func(tx *gorm.DB) error {
result := tx.Select(clause.Associations).Delete(account.Policies, "account_id = ?", account.Id)
if result.Error != nil {
return result.Error
}
result = tx.Select(clause.Associations).Delete(account.UsersG, "account_id = ?", account.Id)
if result.Error != nil {
return result.Error
}
result = tx.Select(clause.Associations).Delete(account)
if result.Error != nil {
return result.Error
}
result = tx.
Session(&gorm.Session{FullSaveAssociations: true}).
Clauses(clause.OnConflict{UpdateAll: true}).Create(account)
if result.Error != nil {
return result.Error
}
return nil
})
took := time.Since(start)
if s.metrics != nil {
s.metrics.StoreMetrics().CountPersistenceDuration(took)
}
log.Debugf("took %d ms to persist an account to the PostgreSQL", took.Milliseconds())
return err
}
func (s *PostgresqlStore) DeleteAccount(account *Account) error {
start := time.Now()
err := s.db.Transaction(func(tx *gorm.DB) error {
result := tx.Select(clause.Associations).Delete(account.Policies, "account_id = ?", account.Id)
if result.Error != nil {
return result.Error
}
result = tx.Select(clause.Associations).Delete(account.UsersG, "account_id = ?", account.Id)
if result.Error != nil {
return result.Error
}
result = tx.Select(clause.Associations).Delete(account)
if result.Error != nil {
return result.Error
}
return nil
})
took := time.Since(start)
if s.metrics != nil {
s.metrics.StoreMetrics().CountPersistenceDuration(took)
}
log.Debugf("took %d ms to delete an account to the PostgreSQL", took.Milliseconds())
return err
}
func (s *PostgresqlStore) SaveInstallationID(ID string) error {
installation := installation{InstallationIDValue: ID}
installation.ID = uint(s.installationPK)
return s.db.Clauses(clause.OnConflict{UpdateAll: true}).Create(&installation).Error
}
func (s *PostgresqlStore) GetInstallationID() string {
var installation installation
if result := s.db.First(&installation, "id = ?", s.installationPK); result.Error != nil {
return ""
}
return installation.InstallationIDValue
}
func (s *PostgresqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error {
var peer nbpeer.Peer
result := s.db.First(&peer, "account_id = ? and id = ?", accountID, peerID)
if result.Error != nil {
return status.Errorf(status.NotFound, "peer %s not found", peerID)
}
peer.Status = &peerStatus
return s.db.Save(peer).Error
}
// DeleteHashedPAT2TokenIDIndex is noop in PostgreSQL
func (s *PostgresqlStore) DeleteHashedPAT2TokenIDIndex(hashedToken string) error {
return nil
}
// DeleteTokenID2UserIDIndex is noop in PostgreSQL
func (s *PostgresqlStore) DeleteTokenID2UserIDIndex(tokenID string) error {
return nil
}
func (s *PostgresqlStore) GetAccountByPrivateDomain(domain string) (*Account, error) {
var account Account
result := s.db.First(&account, "domain = ? and is_domain_primary_account = ? and domain_category = ?",
strings.ToLower(domain), true, PrivateCategory)
if result.Error != nil {
return nil, status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private")
}
// TODO: rework to not call GetAccount
return s.GetAccount(account.Id)
}
func (s *PostgresqlStore) GetAccountBySetupKey(setupKey string) (*Account, error) {
var key SetupKey
result := s.db.Select("account_id").First(&key, "key = ?", strings.ToUpper(setupKey))
if result.Error != nil {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
if key.AccountID == "" {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
return s.GetAccount(key.AccountID)
}
func (s *PostgresqlStore) GetTokenIDByHashedToken(hashedToken string) (string, error) {
var token PersonalAccessToken
result := s.db.First(&token, "hashed_token = ?", hashedToken)
if result.Error != nil {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
}
return token.ID, nil
}
func (s *PostgresqlStore) GetUserByTokenID(tokenID string) (*User, error) {
var token PersonalAccessToken
result := s.db.First(&token, "id = ?", tokenID)
if result.Error != nil {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
if token.UserID == "" {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
var user User
result = s.db.Preload("PATsG").First(&user, "id = ?", token.UserID)
if result.Error != nil {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
user.PATs = make(map[string]*PersonalAccessToken, len(user.PATsG))
for _, pat := range user.PATsG {
user.PATs[pat.ID] = pat.Copy()
}
return &user, nil
}
func (s *PostgresqlStore) GetAllAccounts() (all []*Account) {
var accounts []Account
result := s.db.Find(&accounts)
if result.Error != nil {
return all
}
for _, account := range accounts {
if acc, err := s.GetAccount(account.Id); err == nil {
all = append(all, acc)
}
}
return all
}
func (s *PostgresqlStore) GetAccount(accountID string) (*Account, error) {
var account Account
result := s.db.Model(&account).
Preload("UsersG.PATsG"). // have to be specifies as this is nester reference
Preload(clause.Associations).
First(&account, "id = ?", accountID)
if result.Error != nil {
return nil, status.Errorf(status.NotFound, "account not found")
}
// we have to manually preload policy rules as it seems that gorm preloading doesn't do it for us
for i, policy := range account.Policies {
var rules []*PolicyRule
err := s.db.Model(&PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error
if err != nil {
return nil, status.Errorf(status.NotFound, "rule not found")
}
account.Policies[i].Rules = rules
}
account.SetupKeys = make(map[string]*SetupKey, len(account.SetupKeysG))
for _, key := range account.SetupKeysG {
account.SetupKeys[key.Key] = key.Copy()
}
account.SetupKeysG = nil
account.Peers = make(map[string]*nbpeer.Peer, len(account.PeersG))
for _, peer := range account.PeersG {
account.Peers[peer.ID] = peer.Copy()
}
account.PeersG = nil
account.Users = make(map[string]*User, len(account.UsersG))
for _, user := range account.UsersG {
user.PATs = make(map[string]*PersonalAccessToken, len(user.PATs))
for _, pat := range user.PATsG {
user.PATs[pat.ID] = pat.Copy()
}
account.Users[user.Id] = user.Copy()
}
account.UsersG = nil
account.Groups = make(map[string]*Group, len(account.GroupsG))
for _, group := range account.GroupsG {
account.Groups[group.ID] = group.Copy()
}
account.GroupsG = nil
account.Rules = make(map[string]*Rule, len(account.RulesG))
for _, rule := range account.RulesG {
account.Rules[rule.ID] = rule.Copy()
}
account.RulesG = nil
account.Routes = make(map[string]*route.Route, len(account.RoutesG))
for _, route := range account.RoutesG {
account.Routes[route.ID] = route.Copy()
}
account.RoutesG = nil
account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG))
for _, ns := range account.NameServerGroupsG {
account.NameServerGroups[ns.ID] = ns.Copy()
}
account.NameServerGroupsG = nil
return &account, nil
}
func (s *PostgresqlStore) GetAccountByUser(userID string) (*Account, error) {
var user User
result := s.db.Select("account_id").First(&user, "id = ?", userID)
if result.Error != nil {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
if user.AccountID == "" {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
return s.GetAccount(user.AccountID)
}
func (s *PostgresqlStore) GetAccountByPeerID(peerID string) (*Account, error) {
var peer nbpeer.Peer
result := s.db.Select("account_id").First(&peer, "id = ?", peerID)
if result.Error != nil {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
if peer.AccountID == "" {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
return s.GetAccount(peer.AccountID)
}
func (s *PostgresqlStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) {
var peer nbpeer.Peer
result := s.db.Select("account_id").First(&peer, "key = ?", peerKey)
if result.Error != nil {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
if peer.AccountID == "" {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
return s.GetAccount(peer.AccountID)
}
// SaveUserLastLogin stores the last login time for a user in DB.
func (s *PostgresqlStore) SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error {
var user User
result := s.db.First(&user, "account_id = ? and id = ?", accountID, userID)
if result.Error != nil {
return status.Errorf(status.NotFound, "user %s not found", userID)
}
user.LastLogin = lastLogin
return s.db.Save(user).Error
}
// Close is noop in PostgreSQL
func (s *PostgresqlStore) Close() error {
return nil
}
// GetStoreEngine returns PostgresqlStoreEngine
func (s *PostgresqlStore) GetStoreEngine() StoreEngine {
return PostgresqlStoreEngine
}

View File

@@ -0,0 +1,354 @@
package server
import (
"fmt"
"math/rand"
"net"
"path/filepath"
"runtime"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/driver/postgres"
"gorm.io/gorm"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/util"
)
func TestPostgresql_NewStore(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The PostgreSQL store is not properly supported by Windows yet")
}
store, cleanup := newPostgresqlStore(t)
defer cleanup()
if len(store.GetAllAccounts()) != 0 {
t.Errorf("expected to create a new empty Accounts map when creating a new FileStore")
}
}
func TestPostgresql_SaveAccount(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The PostgreSQL store is not properly supported by Windows yet")
}
store, cleanup := newPostgresqlStore(t)
defer cleanup()
account := newAccountWithId("account_id", "testuser", "")
setupKey := GenerateDefaultSetupKey()
account.SetupKeys[setupKey.Key] = setupKey
account.Peers["testpeer"] = &nbpeer.Peer{
Key: "peerkey",
SetupKey: "peerkeysetupkey",
IP: net.IP{127, 0, 0, 1},
Meta: nbpeer.PeerSystemMeta{},
Name: "peer name",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
}
err := store.SaveAccount(account)
require.NoError(t, err)
account2 := newAccountWithId("account_id2", "testuser2", "")
setupKey = GenerateDefaultSetupKey()
account2.SetupKeys[setupKey.Key] = setupKey
account2.Peers["testpeer2"] = &nbpeer.Peer{
Key: "peerkey2",
SetupKey: "peerkeysetupkey2",
IP: net.IP{127, 0, 0, 2},
Meta: nbpeer.PeerSystemMeta{},
Name: "peer name 2",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
}
err = store.SaveAccount(account2)
require.NoError(t, err)
if len(store.GetAllAccounts()) != 2 {
t.Errorf("expecting 2 Accounts to be stored after SaveAccount()")
}
a, err := store.GetAccount(account.Id)
if a == nil {
t.Errorf("expecting Account to be stored after SaveAccount(): %v", err)
}
if a != nil && len(a.Policies) != 1 {
t.Errorf("expecting Account to have one policy stored after SaveAccount(), got %d", len(a.Policies))
}
if a != nil && len(a.Policies[0].Rules) != 1 {
t.Errorf("expecting Account to have one policy rule stored after SaveAccount(), got %d", len(a.Policies[0].Rules))
return
}
if a, err := store.GetAccountByPeerPubKey("peerkey"); a == nil {
t.Errorf("expecting PeerKeyID2AccountID index updated after SaveAccount(): %v", err)
}
if a, err := store.GetAccountByUser("testuser"); a == nil {
t.Errorf("expecting UserID2AccountID index updated after SaveAccount(): %v", err)
}
if a, err := store.GetAccountByPeerID("testpeer"); a == nil {
t.Errorf("expecting PeerID2AccountID index updated after SaveAccount(): %v", err)
}
if a, err := store.GetAccountBySetupKey(setupKey.Key); a == nil {
t.Errorf("expecting SetupKeyID2AccountID index updated after SaveAccount(): %v", err)
}
}
func TestPostgresql_DeleteAccount(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The PostgreSQL store is not properly supported by Windows yet")
}
store, cleanup := newPostgresqlStore(t)
defer cleanup()
testUserID := "testuser"
user := NewAdminUser(testUserID)
user.PATs = map[string]*PersonalAccessToken{"testtoken": {
ID: "testtoken",
Name: "test token",
}}
account := newAccountWithId("account_id", testUserID, "")
setupKey := GenerateDefaultSetupKey()
account.SetupKeys[setupKey.Key] = setupKey
account.Peers["testpeer"] = &nbpeer.Peer{
Key: "peerkey",
SetupKey: "peerkeysetupkey",
IP: net.IP{127, 0, 0, 1},
Meta: nbpeer.PeerSystemMeta{},
Name: "peer name",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
}
account.Users[testUserID] = user
err := store.SaveAccount(account)
require.NoError(t, err)
if len(store.GetAllAccounts()) != 1 {
t.Errorf("expecting 1 Accounts to be stored after SaveAccount()")
}
err = store.DeleteAccount(account)
require.NoError(t, err)
if len(store.GetAllAccounts()) != 0 {
t.Errorf("expecting 0 Accounts to be stored after DeleteAccount()")
}
_, err = store.GetAccountByPeerPubKey("peerkey")
require.Error(t, err, "expecting error after removing DeleteAccount when getting account by peer public key")
_, err = store.GetAccountByUser("testuser")
require.Error(t, err, "expecting error after removing DeleteAccount when getting account by user")
_, err = store.GetAccountByPeerID("testpeer")
require.Error(t, err, "expecting error after removing DeleteAccount when getting account by peer id")
_, err = store.GetAccountBySetupKey(setupKey.Key)
require.Error(t, err, "expecting error after removing DeleteAccount when getting account by setup key")
_, err = store.GetAccount(account.Id)
require.Error(t, err, "expecting error after removing DeleteAccount when getting account by id")
for _, policy := range account.Policies {
var rules []*PolicyRule
err = store.db.Model(&PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error
require.NoError(t, err, "expecting no error after removing DeleteAccount when searching for policy rules")
require.Len(t, rules, 0, "expecting no policy rules to be found after removing DeleteAccount")
}
for _, accountUser := range account.Users {
var pats []*PersonalAccessToken
err = store.db.Model(&PersonalAccessToken{}).Find(&pats, "user_id = ?", accountUser.Id).Error
require.NoError(t, err, "expecting no error after removing DeleteAccount when searching for personal access token")
require.Len(t, pats, 0, "expecting no personal access token to be found after removing DeleteAccount")
}
}
func TestPostgresql_SavePeerStatus(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The PostgreSQL store is not properly supported by Windows yet")
}
store, cleanup := newPostgresqlStoreFromFile(t, "testdata/store.json")
defer cleanup()
account, err := store.GetAccount("bf1c8084-ba50-4ce7-9439-34653001fc3b")
require.NoError(t, err)
// save status of non-existing peer
newStatus := nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}
err = store.SavePeerStatus(account.Id, "non-existing-peer", newStatus)
assert.Error(t, err)
// save new status of existing peer
account.Peers["testpeer"] = &nbpeer.Peer{
Key: "peerkey",
ID: "testpeer",
SetupKey: "peerkeysetupkey",
IP: net.IP{127, 0, 0, 1},
Meta: nbpeer.PeerSystemMeta{},
Name: "peer name",
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now().UTC()},
}
err = store.SaveAccount(account)
require.NoError(t, err)
err = store.SavePeerStatus(account.Id, "testpeer", newStatus)
require.NoError(t, err)
account, err = store.GetAccount(account.Id)
require.NoError(t, err)
actual := account.Peers["testpeer"].Status
assert.Equal(t, newStatus.Connected, actual.Connected)
// TODO: fix later
//assert.True(t, newStatus.LastSeen.Equal(actual.LastSeen))
}
func TestPostgresql_TestGetAccountByPrivateDomain(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The PostgreSQL store is not properly supported by Windows yet")
}
store, cleanup := newPostgresqlStoreFromFile(t, "testdata/store.json")
defer cleanup()
existingDomain := "test.com"
account, err := store.GetAccountByPrivateDomain(existingDomain)
require.NoError(t, err, "should found account")
require.Equal(t, existingDomain, account.Domain, "domains should match")
_, err = store.GetAccountByPrivateDomain("missing-domain.com")
require.Error(t, err, "should return error on domain lookup")
}
func TestPostgresql_GetTokenIDByHashedToken(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The PostgreSQL store is not properly supported by Windows yet")
}
store, cleanup := newPostgresqlStoreFromFile(t, "testdata/store.json")
defer cleanup()
hashed := "SoMeHaShEdToKeN"
id := "9dj38s35-63fb-11ec-90d6-0242ac120003"
token, err := store.GetTokenIDByHashedToken(hashed)
require.NoError(t, err)
require.Equal(t, id, token)
}
func TestPostgresql_GetUserByTokenID(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The PostgreSQL store is not properly supported by Windows yet")
}
store, cleanup := newPostgresqlStoreFromFile(t, "testdata/store.json")
defer cleanup()
id := "9dj38s35-63fb-11ec-90d6-0242ac120003"
user, err := store.GetUserByTokenID(id)
require.NoError(t, err)
require.Equal(t, id, user.PATs[id].ID)
}
func newPostgresqlStore(t *testing.T) (*PostgresqlStore, func()) {
t.Helper()
dbName := "store_" + randString(10)
postgresDsn := "host=localhost user=postgres password=postgres port=5432 sslmode=disable"
db, _ := gorm.Open(postgres.Open(postgresDsn), &gorm.Config{})
result := db.Exec(fmt.Sprintf("CREATE DATABASE %s ENCODING = 'UTF8'", dbName))
if result.Error != nil {
t.Fatalf("could not initialize postgresql store: %s", result.Error)
}
postgresDsn = fmt.Sprintf("%s dbname=%s ", postgresDsn, dbName)
cleanup := func() {
db.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s WITH (FORCE)", dbName))
}
store, err := NewPostgresqlStore(postgresDsn, nil)
if err != nil {
t.Fatalf("could not initialize postgresql store: %s", err)
}
require.NoError(t, err)
require.NotNil(t, store)
return store, cleanup
}
func randString(n int) string {
var letterRunes = []rune("abcdefghijklmnopqrstuvwxyz1234567890")
b := make([]rune, n)
for i := range b {
b[i] = letterRunes[rand.Intn(len(letterRunes))]
}
return string(b)
}
func newPostgresqlStoreFromFile(t *testing.T, filename string) (*PostgresqlStore, func()) {
t.Helper()
storeDir := t.TempDir()
err := util.CopyFileContents(filename, filepath.Join(storeDir, "store.json"))
require.NoError(t, err)
fStore, err := NewFileStore(storeDir, nil)
require.NoError(t, err)
dbName := "store_" + randString(10)
postgresDsn := "host=localhost user=postgres password=postgres port=5432 sslmode=disable"
db, _ := gorm.Open(postgres.Open(postgresDsn), &gorm.Config{})
result := db.Exec(fmt.Sprintf("CREATE DATABASE %s ENCODING = 'UTF8'", dbName))
if result.Error != nil {
t.Fatalf("could not initialize postgresql store: %s", result.Error)
}
postgresDsn = fmt.Sprintf("%s dbname=%s ", postgresDsn, dbName)
cleanup := func() {
db.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s WITH (FORCE)", dbName))
}
store, err := NewPostgresqlStoreFromFileStore(fStore, postgresDsn, nil)
require.NoError(t, err)
require.NotNil(t, store)
return store, cleanup
}
/*
func newAccount(store Store, id int) error {
str := fmt.Sprintf("%s-%d", uuid.New().String(), id)
account := newAccountWithId(str, str+"-testuser", "example.com")
setupKey := GenerateDefaultSetupKey()
account.SetupKeys[setupKey.Key] = setupKey
account.Peers["p"+str] = &nbpeer.Peer{
Key: "peerkey" + str,
SetupKey: "peerkeysetupkey",
IP: net.IP{127, 0, 0, 1},
Meta: nbpeer.PeerSystemMeta{},
Name: "peer name",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
}
return store.SaveAccount(account)
}
*/

View File

@@ -44,8 +44,9 @@ type Store interface {
type StoreEngine string
const (
FileStoreEngine StoreEngine = "jsonfile"
SqliteStoreEngine StoreEngine = "sqlite"
FileStoreEngine StoreEngine = "jsonfile"
SqliteStoreEngine StoreEngine = "sqlite"
PostgresqlStoreEngine StoreEngine = "postgresql"
)
func getStoreEngineFromEnv() StoreEngine {
@@ -76,6 +77,9 @@ func NewStore(kind StoreEngine, dataDir string, metrics telemetry.AppMetrics) (S
case SqliteStoreEngine:
log.Info("using SQLite store engine")
return NewSqliteStore(dataDir, metrics)
case PostgresqlStoreEngine:
log.Info("using PostgreSQL store engine")
return NewPostgresqlStore(dataDir, metrics) // dataDir is dsn
default:
return nil, fmt.Errorf("unsupported kind of store %s", kind)
}
@@ -94,6 +98,8 @@ func NewStoreFromJson(dataDir string, metrics telemetry.AppMetrics) (Store, erro
return fstore, nil
case SqliteStoreEngine:
return NewSqliteStoreFromFileStore(fstore, dataDir, metrics)
case PostgresqlStoreEngine:
return NewPostgresqlStoreFromFileStore(fstore, dataDir, metrics) // dataDir is dsn
default:
return nil, fmt.Errorf("unsupported store engine %s", kind)
}